feat: add quotas (#4779)

adds possibilities to cap authenticated requests and execution seconds of actions on a defined intervall
This commit is contained in:
Elio Bischof
2023-02-15 02:52:11 +01:00
committed by GitHub
parent 45f6a4436e
commit 681541f41b
117 changed files with 4652 additions and 510 deletions

View File

@@ -16,6 +16,7 @@ import (
http_util "github.com/zitadel/zitadel/internal/api/http"
"github.com/zitadel/zitadel/internal/api/ui/login"
"github.com/zitadel/zitadel/internal/errors"
"github.com/zitadel/zitadel/internal/logstore"
"github.com/zitadel/zitadel/internal/query"
"github.com/zitadel/zitadel/internal/telemetry/metrics"
"github.com/zitadel/zitadel/internal/telemetry/tracing"
@@ -36,7 +37,18 @@ type health interface {
Instance(ctx context.Context, shouldTriggerBulk bool) (*query.Instance, error)
}
func New(port uint16, router *mux.Router, queries *query.Queries, verifier *internal_authz.TokenVerifier, authZ internal_authz.Config, externalSecure bool, tlsConfig *tls.Config, http2HostName, http1HostName string) *API {
func New(
port uint16,
router *mux.Router,
queries *query.Queries,
verifier *internal_authz.TokenVerifier,
authZ internal_authz.Config,
externalSecure bool,
tlsConfig *tls.Config,
http2HostName,
http1HostName string,
accessSvc *logstore.Service,
) *API {
api := &API{
port: port,
verifier: verifier,
@@ -45,7 +57,8 @@ func New(port uint16, router *mux.Router, queries *query.Queries, verifier *inte
externalSecure: externalSecure,
http1HostName: http1HostName,
}
api.grpcServer = server.CreateServer(api.verifier, authZ, queries, http2HostName, tlsConfig)
api.grpcServer = server.CreateServer(api.verifier, authZ, queries, http2HostName, tlsConfig, accessSvc)
api.routeGRPC()
api.RegisterHandler("/debug", api.healthHandler())

View File

@@ -82,7 +82,7 @@ func DefaultErrorHandler(w http.ResponseWriter, r *http.Request, err error, code
http.Error(w, err.Error(), code)
}
func NewHandler(commands *command.Commands, verifier *authz.TokenVerifier, authConfig authz.Config, idGenerator id.Generator, storage static.Storage, queries *query.Queries, instanceInterceptor, assetCacheInterceptor func(handler http.Handler) http.Handler) http.Handler {
func NewHandler(commands *command.Commands, verifier *authz.TokenVerifier, authConfig authz.Config, idGenerator id.Generator, storage static.Storage, queries *query.Queries, instanceInterceptor, assetCacheInterceptor, accessInterceptor func(handler http.Handler) http.Handler) http.Handler {
h := &Handler{
commands: commands,
errorHandler: DefaultErrorHandler,
@@ -94,7 +94,7 @@ func NewHandler(commands *command.Commands, verifier *authz.TokenVerifier, authC
verifier.RegisterServer("Assets-API", "assets", AssetsService_AuthMethods)
router := mux.NewRouter()
router.Use(instanceInterceptor, assetCacheInterceptor)
router.Use(instanceInterceptor, assetCacheInterceptor, accessInterceptor)
RegisterRoutes(router, h)
router.PathPrefix("/{owner}").Methods("GET").HandlerFunc(DownloadHandleFunc(h, h.GetFile()))
return http_util.CopyHeadersToContext(http_mw.CORSInterceptor(router))

View File

@@ -23,7 +23,8 @@ type Instance interface {
}
type InstanceVerifier interface {
InstanceByHost(context.Context, string) (Instance, error)
InstanceByHost(ctx context.Context, host string) (Instance, error)
InstanceByID(ctx context.Context) (Instance, error)
}
type instance struct {

View File

@@ -55,6 +55,8 @@ func ExtractCaosError(err error) (c codes.Code, msg, id string, ok bool) {
return codes.Unavailable, caosErr.GetMessage(), caosErr.GetID(), true
case *caos_errs.UnimplementedError:
return codes.Unimplemented, caosErr.GetMessage(), caosErr.GetID(), true
case *caos_errs.ResourceExhaustedError:
return codes.ResourceExhausted, caosErr.GetMessage(), caosErr.GetID(), true
default:
return codes.Unknown, err.Error(), "", false
}

View File

@@ -136,6 +136,14 @@ func Test_Extract(t *testing.T) {
"id",
true,
},
{
"exhausted",
args{caos_errs.ThrowResourceExhausted(nil, "id", "exhausted")},
codes.ResourceExhausted,
"exhausted",
"id",
true,
},
{
"unknown",
args{errors.New("unknown")},

View File

@@ -0,0 +1,55 @@
package middleware
import (
"context"
"time"
"google.golang.org/grpc"
"google.golang.org/grpc/metadata"
"google.golang.org/grpc/status"
"github.com/zitadel/zitadel/internal/api/authz"
"github.com/zitadel/zitadel/internal/logstore"
"github.com/zitadel/zitadel/internal/logstore/emitters/access"
"github.com/zitadel/zitadel/internal/telemetry/tracing"
)
func AccessStorageInterceptor(svc *logstore.Service) grpc.UnaryServerInterceptor {
return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (_ interface{}, err error) {
if !svc.Enabled() {
return handler(ctx, req)
}
reqMd, _ := metadata.FromIncomingContext(ctx)
resp, handlerErr := handler(ctx, req)
interceptorCtx, span := tracing.NewServerInterceptorSpan(ctx)
defer func() { span.EndWithError(err) }()
var respStatus uint32
grpcStatus, ok := status.FromError(handlerErr)
if ok {
respStatus = uint32(grpcStatus.Code())
}
resMd, _ := metadata.FromOutgoingContext(ctx)
instance := authz.GetInstance(ctx)
record := &access.Record{
LogDate: time.Now(),
Protocol: access.GRPC,
RequestURL: info.FullMethod,
ResponseStatus: respStatus,
RequestHeaders: reqMd,
ResponseHeaders: resMd,
InstanceID: instance.InstanceID(),
ProjectID: instance.ProjectID(),
RequestedDomain: instance.RequestedDomain(),
RequestedHost: instance.RequestedHost(),
}
svc.Handle(interceptorCtx, record)
return resp, handlerErr
}
}

View File

@@ -2,7 +2,7 @@ package middleware
import (
"context"
"errors"
errs "errors"
"fmt"
"strings"
@@ -14,7 +14,7 @@ import (
"google.golang.org/grpc/status"
"github.com/zitadel/zitadel/internal/api/authz"
caos_errors "github.com/zitadel/zitadel/internal/errors"
"github.com/zitadel/zitadel/internal/errors"
"github.com/zitadel/zitadel/internal/i18n"
"github.com/zitadel/zitadel/internal/telemetry/tracing"
)
@@ -23,27 +23,36 @@ const (
HTTP1Host = "x-zitadel-http1-host"
)
type InstanceVerifier interface {
GetInstance(ctx context.Context)
}
func InstanceInterceptor(verifier authz.InstanceVerifier, headerName string, ignoredServices ...string) grpc.UnaryServerInterceptor {
func InstanceInterceptor(verifier authz.InstanceVerifier, headerName string, explicitInstanceIdServices ...string) grpc.UnaryServerInterceptor {
translator, err := newZitadelTranslator(language.English)
logging.OnError(err).Panic("unable to get translator")
return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
return setInstance(ctx, req, info, handler, verifier, headerName, translator, ignoredServices...)
return setInstance(ctx, req, info, handler, verifier, headerName, translator, explicitInstanceIdServices...)
}
}
func setInstance(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler, verifier authz.InstanceVerifier, headerName string, translator *i18n.Translator, ignoredServices ...string) (_ interface{}, err error) {
func setInstance(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler, verifier authz.InstanceVerifier, headerName string, translator *i18n.Translator, idFromRequestsServices ...string) (_ interface{}, err error) {
interceptorCtx, span := tracing.NewServerInterceptorSpan(ctx)
defer func() { span.EndWithError(err) }()
for _, service := range ignoredServices {
for _, service := range idFromRequestsServices {
if !strings.HasPrefix(service, "/") {
service = "/" + service
}
if strings.HasPrefix(info.FullMethod, service) {
return handler(ctx, req)
withInstanceIDProperty, ok := req.(interface{ GetInstanceId() string })
if !ok {
return handler(ctx, req)
}
ctx = authz.WithInstanceID(ctx, withInstanceIDProperty.GetInstanceId())
instance, err := verifier.InstanceByID(ctx)
if err != nil {
notFoundErr := new(errors.NotFoundError)
if errs.As(err, &notFoundErr) {
notFoundErr.Message = translator.LocalizeFromCtx(ctx, notFoundErr.GetMessage(), nil)
}
return nil, status.Error(codes.NotFound, err.Error())
}
return handler(authz.WithInstance(ctx, instance), req)
}
}
@@ -53,9 +62,9 @@ func setInstance(ctx context.Context, req interface{}, info *grpc.UnaryServerInf
}
instance, err := verifier.InstanceByHost(interceptorCtx, host)
if err != nil {
caosErr := new(caos_errors.NotFoundError)
if errors.As(err, &caosErr) {
caosErr.Message = translator.LocalizeFromCtx(ctx, caosErr.GetMessage(), nil)
notFoundErr := new(errors.NotFoundError)
if errs.As(err, &notFoundErr) {
notFoundErr.Message = translator.LocalizeFromCtx(ctx, notFoundErr.GetMessage(), nil)
}
return nil, status.Error(codes.NotFound, err.Error())
}

View File

@@ -153,13 +153,15 @@ type mockInstanceVerifier struct {
host string
}
func (m *mockInstanceVerifier) InstanceByHost(ctx context.Context, host string) (authz.Instance, error) {
func (m *mockInstanceVerifier) InstanceByHost(_ context.Context, host string) (authz.Instance, error) {
if host != m.host {
return nil, fmt.Errorf("invalid host")
}
return &mockInstance{}, nil
}
func (m *mockInstanceVerifier) InstanceByID(context.Context) (authz.Instance, error) { return nil, nil }
type mockInstance struct{}
func (m *mockInstance) InstanceID() string {

View File

@@ -0,0 +1,46 @@
package middleware
import (
"context"
"strings"
"google.golang.org/grpc"
"github.com/zitadel/zitadel/internal/api/authz"
"github.com/zitadel/zitadel/internal/errors"
"github.com/zitadel/zitadel/internal/logstore"
"github.com/zitadel/zitadel/internal/telemetry/tracing"
)
func QuotaExhaustedInterceptor(svc *logstore.Service, ignoreService ...string) grpc.UnaryServerInterceptor {
prunedIgnoredServices := make([]string, len(ignoreService))
for idx, service := range ignoreService {
if !strings.HasPrefix(service, "/") {
service = "/" + service
}
prunedIgnoredServices[idx] = service
}
return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (_ interface{}, err error) {
if !svc.Enabled() {
return handler(ctx, req)
}
interceptorCtx, span := tracing.NewServerInterceptorSpan(ctx)
defer func() { span.EndWithError(err) }()
for _, service := range prunedIgnoredServices {
if strings.HasPrefix(info.FullMethod, service) {
return handler(ctx, req)
}
}
instance := authz.GetInstance(ctx)
remaining := svc.Limit(interceptorCtx, instance.InstanceID())
if remaining != nil && *remaining == 0 {
return nil, errors.ThrowResourceExhausted(nil, "QUOTA-vjAy8", "Quota.Access.Exhausted")
}
span.End()
return handler(ctx, req)
}
}

View File

@@ -2,7 +2,6 @@ package server
import (
"crypto/tls"
grpc_middleware "github.com/grpc-ecosystem/go-grpc-middleware"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials"
@@ -10,6 +9,7 @@ import (
"github.com/zitadel/zitadel/internal/api/authz"
grpc_api "github.com/zitadel/zitadel/internal/api/grpc"
"github.com/zitadel/zitadel/internal/api/grpc/server/middleware"
"github.com/zitadel/zitadel/internal/logstore"
"github.com/zitadel/zitadel/internal/query"
"github.com/zitadel/zitadel/internal/telemetry/metrics"
system_pb "github.com/zitadel/zitadel/pkg/grpc/system"
@@ -23,7 +23,14 @@ type Server interface {
AuthMethods() authz.MethodMapping
}
func CreateServer(verifier *authz.TokenVerifier, authConfig authz.Config, queries *query.Queries, hostHeaderName string, tlsConfig *tls.Config) *grpc.Server {
func CreateServer(
verifier *authz.TokenVerifier,
authConfig authz.Config,
queries *query.Queries,
hostHeaderName string,
tlsConfig *tls.Config,
accessSvc *logstore.Service,
) *grpc.Server {
metricTypes := []metrics.MetricType{metrics.MetricTypeTotalCount, metrics.MetricTypeRequestCount, metrics.MetricTypeStatusCode}
serverOptions := []grpc.ServerOption{
grpc.UnaryInterceptor(
@@ -33,10 +40,12 @@ func CreateServer(verifier *authz.TokenVerifier, authConfig authz.Config, querie
middleware.NoCacheInterceptor(),
middleware.ErrorHandler(),
middleware.InstanceInterceptor(queries, hostHeaderName, system_pb.SystemService_MethodPrefix),
middleware.AccessStorageInterceptor(accessSvc),
middleware.AuthorizationInterceptor(verifier, authConfig),
middleware.TranslationHandler(),
middleware.ValidationHandler(),
middleware.ServiceHandler(),
middleware.QuotaExhaustedInterceptor(accessSvc, system_pb.SystemService_MethodPrefix),
),
),
}

View File

@@ -31,7 +31,6 @@ func (s *Server) ListInstances(ctx context.Context, req *system_pb.ListInstances
}
func (s *Server) GetInstance(ctx context.Context, req *system_pb.GetInstanceRequest) (*system_pb.GetInstanceResponse, error) {
ctx = authz.WithInstanceID(ctx, req.InstanceId)
instance, err := s.query.Instance(ctx, true)
if err != nil {
return nil, err
@@ -53,7 +52,6 @@ func (s *Server) AddInstance(ctx context.Context, req *system_pb.AddInstanceRequ
}
func (s *Server) UpdateInstance(ctx context.Context, req *system_pb.UpdateInstanceRequest) (*system_pb.UpdateInstanceResponse, error) {
ctx = authz.WithInstanceID(ctx, req.InstanceId)
details, err := s.command.UpdateInstance(ctx, req.InstanceName)
if err != nil {
return nil, err
@@ -86,7 +84,6 @@ func (s *Server) CreateInstance(ctx context.Context, req *system_pb.CreateInstan
}
func (s *Server) RemoveInstance(ctx context.Context, req *system_pb.RemoveInstanceRequest) (*system_pb.RemoveInstanceResponse, error) {
ctx = authz.WithInstanceID(ctx, req.InstanceId)
details, err := s.command.RemoveInstance(ctx, req.InstanceId)
if err != nil {
return nil, err
@@ -97,7 +94,6 @@ func (s *Server) RemoveInstance(ctx context.Context, req *system_pb.RemoveInstan
}
func (s *Server) ListIAMMembers(ctx context.Context, req *system_pb.ListIAMMembersRequest) (*system_pb.ListIAMMembersResponse, error) {
ctx = authz.WithInstanceID(ctx, req.InstanceId)
queries, err := ListIAMMembersRequestToQuery(req)
if err != nil {
return nil, err
@@ -139,7 +135,6 @@ func (s *Server) ExistsDomain(ctx context.Context, req *system_pb.ExistsDomainRe
}
func (s *Server) ListDomains(ctx context.Context, req *system_pb.ListDomainsRequest) (*system_pb.ListDomainsResponse, error) {
ctx = authz.WithInstanceID(ctx, req.InstanceId)
queries, err := ListInstanceDomainsRequestToModel(req)
if err != nil {
return nil, err
@@ -156,8 +151,6 @@ func (s *Server) ListDomains(ctx context.Context, req *system_pb.ListDomainsRequ
}
func (s *Server) AddDomain(ctx context.Context, req *system_pb.AddDomainRequest) (*system_pb.AddDomainResponse, error) {
//TODO: should be solved in interceptor
ctx = authz.WithInstanceID(ctx, req.InstanceId)
instance, err := s.query.Instance(ctx, true)
if err != nil {
return nil, err
@@ -174,7 +167,6 @@ func (s *Server) AddDomain(ctx context.Context, req *system_pb.AddDomainRequest)
}
func (s *Server) RemoveDomain(ctx context.Context, req *system_pb.RemoveDomainRequest) (*system_pb.RemoveDomainResponse, error) {
ctx = authz.WithInstanceID(ctx, req.InstanceId)
details, err := s.command.RemoveInstanceDomain(ctx, req.Domain)
if err != nil {
return nil, err
@@ -185,7 +177,6 @@ func (s *Server) RemoveDomain(ctx context.Context, req *system_pb.RemoveDomainRe
}
func (s *Server) SetPrimaryDomain(ctx context.Context, req *system_pb.SetPrimaryDomainRequest) (*system_pb.SetPrimaryDomainResponse, error) {
ctx = authz.WithInstanceID(ctx, req.InstanceId)
details, err := s.command.SetPrimaryInstanceDomain(ctx, req.Domain)
if err != nil {
return nil, err

View File

@@ -0,0 +1,32 @@
package system
import (
"context"
"github.com/zitadel/zitadel/internal/api/grpc/object"
"github.com/zitadel/zitadel/pkg/grpc/system"
system_pb "github.com/zitadel/zitadel/pkg/grpc/system"
)
func (s *Server) AddQuota(ctx context.Context, req *system.AddQuotaRequest) (*system.AddQuotaResponse, error) {
details, err := s.command.AddQuota(
ctx,
instanceQuotaPbToCommand(req),
)
if err != nil {
return nil, err
}
return &system_pb.AddQuotaResponse{
Details: object.AddToDetailsPb(details.Sequence, details.EventDate, details.ResourceOwner),
}, nil
}
func (s *Server) RemoveQuota(ctx context.Context, req *system.RemoveQuotaRequest) (*system.RemoveQuotaResponse, error) {
details, err := s.command.RemoveQuota(ctx, instanceQuotaUnitPbToCommand(req.Unit))
if err != nil {
return nil, err
}
return &system_pb.RemoveQuotaResponse{
Details: object.ChangeToDetailsPb(details.Sequence, details.EventDate, details.ResourceOwner),
}, nil
}

View File

@@ -0,0 +1,43 @@
package system
import (
"github.com/zitadel/zitadel/internal/command"
"github.com/zitadel/zitadel/pkg/grpc/quota"
"github.com/zitadel/zitadel/pkg/grpc/system"
)
func instanceQuotaPbToCommand(req *system.AddQuotaRequest) *command.AddQuota {
return &command.AddQuota{
Unit: instanceQuotaUnitPbToCommand(req.Unit),
From: req.From.AsTime(),
ResetInterval: req.ResetInterval.AsDuration(),
Amount: req.Amount,
Limit: req.Limit,
Notifications: instanceQuotaNotificationsPbToCommand(req.Notifications),
}
}
func instanceQuotaUnitPbToCommand(unit quota.Unit) command.QuotaUnit {
switch unit {
case quota.Unit_UNIT_REQUESTS_ALL_AUTHENTICATED:
return command.QuotaRequestsAllAuthenticated
case quota.Unit_UNIT_ACTIONS_ALL_RUN_SECONDS:
return command.QuotaActionsAllRunsSeconds
case quota.Unit_UNIT_UNIMPLEMENTED:
fallthrough
default:
return command.QuotaUnit(unit.String())
}
}
func instanceQuotaNotificationsPbToCommand(req []*quota.Notification) command.QuotaNotifications {
notifications := make([]*command.QuotaNotification, len(req))
for idx, item := range req {
notifications[idx] = &command.QuotaNotification{
Percent: uint16(item.Percent),
Repeat: item.Repeat,
CallURL: item.CallUrl,
}
}
return notifications
}

View File

@@ -72,7 +72,9 @@ func WithPath(path string) CookieHandlerOpt {
func WithMaxAge(maxAge int) CookieHandlerOpt {
return func(c *CookieHandler) {
c.maxAge = maxAge
c.securecookie.MaxAge(maxAge)
if c.securecookie != nil {
c.securecookie.MaxAge(maxAge)
}
}
}

View File

@@ -0,0 +1,102 @@
package middleware
import (
"math"
"net/http"
"net/url"
"strconv"
"time"
"github.com/zitadel/logging"
"github.com/zitadel/zitadel/internal/api/authz"
http_utils "github.com/zitadel/zitadel/internal/api/http"
"github.com/zitadel/zitadel/internal/logstore"
"github.com/zitadel/zitadel/internal/logstore/emitters/access"
"github.com/zitadel/zitadel/internal/telemetry/tracing"
)
type AccessInterceptor struct {
svc *logstore.Service
cookieHandler *http_utils.CookieHandler
limitConfig *AccessConfig
}
type AccessConfig struct {
ExhaustedCookieKey string
ExhaustedCookieMaxAge time.Duration
}
func NewAccessInterceptor(svc *logstore.Service, cookieConfig *AccessConfig) *AccessInterceptor {
return &AccessInterceptor{
svc: svc,
cookieHandler: http_utils.NewCookieHandler(
http_utils.WithUnsecure(),
http_utils.WithMaxAge(int(math.Floor(cookieConfig.ExhaustedCookieMaxAge.Seconds()))),
),
limitConfig: cookieConfig,
}
}
func (a *AccessInterceptor) Handle(next http.Handler) http.Handler {
if !a.svc.Enabled() {
return next
}
return http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) {
ctx := request.Context()
var err error
tracingCtx, span := tracing.NewServerInterceptorSpan(ctx)
defer func() { span.EndWithError(err) }()
wrappedWriter := &statusRecorder{ResponseWriter: writer, status: 0}
instance := authz.GetInstance(ctx)
remaining := a.svc.Limit(tracingCtx, instance.InstanceID())
limit := remaining != nil && *remaining == 0
a.cookieHandler.SetCookie(wrappedWriter, a.limitConfig.ExhaustedCookieKey, request.Host, strconv.FormatBool(limit))
if limit {
wrappedWriter.WriteHeader(http.StatusTooManyRequests)
wrappedWriter.ignoreWrites = true
}
next.ServeHTTP(wrappedWriter, request)
requestURL := request.RequestURI
unescapedURL, err := url.QueryUnescape(requestURL)
if err != nil {
logging.WithError(err).WithField("url", requestURL).Warning("failed to unescape request url")
// err = nil is effective because of deferred tracing span end
err = nil
}
a.svc.Handle(tracingCtx, &access.Record{
LogDate: time.Now(),
Protocol: access.HTTP,
RequestURL: unescapedURL,
ResponseStatus: uint32(wrappedWriter.status),
RequestHeaders: request.Header,
ResponseHeaders: writer.Header(),
InstanceID: instance.InstanceID(),
ProjectID: instance.ProjectID(),
RequestedDomain: instance.RequestedDomain(),
RequestedHost: instance.RequestedHost(),
})
})
}
type statusRecorder struct {
http.ResponseWriter
status int
ignoreWrites bool
}
func (r *statusRecorder) WriteHeader(status int) {
if r.ignoreWrites {
return
}
r.status = status
r.ResponseWriter.WriteHeader(status)
}

View File

@@ -244,6 +244,10 @@ func (m *mockInstanceVerifier) InstanceByHost(_ context.Context, host string) (a
return &mockInstance{}, nil
}
func (m *mockInstanceVerifier) InstanceByID(context.Context) (authz.Instance, error) {
return nil, nil
}
type mockInstance struct{}
func (m *mockInstance) InstanceID() string {

View File

@@ -427,7 +427,7 @@ func (o *OPStorage) userinfoFlows(ctx context.Context, resourceOwner string, use
apiFields,
action.Script,
action.Name,
append(actions.ActionToOptions(action), actions.WithHTTP(actionCtx), actions.WithLogger(actions.ServerLog))...,
append(actions.ActionToOptions(action), actions.WithHTTP(actionCtx))...,
)
cancel()
if err != nil {
@@ -583,7 +583,7 @@ func (o *OPStorage) privateClaimsFlows(ctx context.Context, userID string, claim
apiFields,
action.Script,
action.Name,
append(actions.ActionToOptions(action), actions.WithHTTP(actionCtx), actions.WithLogger(actions.ServerLog))...,
append(actions.ActionToOptions(action), actions.WithHTTP(actionCtx))...,
)
cancel()
if err != nil {

View File

@@ -73,13 +73,13 @@ type OPStorage struct {
assetAPIPrefix func(ctx context.Context) string
}
func NewProvider(ctx context.Context, config Config, defaultLogoutRedirectURI string, externalSecure bool, command *command.Commands, query *query.Queries, repo repository.Repository, encryptionAlg crypto.EncryptionAlgorithm, cryptoKey []byte, es *eventstore.Eventstore, projections *sql.DB, userAgentCookie, instanceHandler func(http.Handler) http.Handler) (op.OpenIDProvider, error) {
func NewProvider(ctx context.Context, config Config, defaultLogoutRedirectURI string, externalSecure bool, command *command.Commands, query *query.Queries, repo repository.Repository, encryptionAlg crypto.EncryptionAlgorithm, cryptoKey []byte, es *eventstore.Eventstore, projections *sql.DB, userAgentCookie, instanceHandler, accessHandler func(http.Handler) http.Handler) (op.OpenIDProvider, error) {
opConfig, err := createOPConfig(config, defaultLogoutRedirectURI, cryptoKey)
if err != nil {
return nil, caos_errs.ThrowInternal(err, "OIDC-EGrqd", "cannot create op config: %w")
}
storage := newStorage(config, command, query, repo, encryptionAlg, es, projections, externalSecure)
options, err := createOptions(config, externalSecure, userAgentCookie, instanceHandler)
options, err := createOptions(config, externalSecure, userAgentCookie, instanceHandler, accessHandler)
if err != nil {
return nil, caos_errs.ThrowInternal(err, "OIDC-D3gq1", "cannot create options: %w")
}
@@ -117,7 +117,7 @@ func createOPConfig(config Config, defaultLogoutRedirectURI string, cryptoKey []
return opConfig, nil
}
func createOptions(config Config, externalSecure bool, userAgentCookie, instanceHandler func(http.Handler) http.Handler) ([]op.Option, error) {
func createOptions(config Config, externalSecure bool, userAgentCookie, instanceHandler, accessHandler func(http.Handler) http.Handler) ([]op.Option, error) {
metricTypes := []metrics.MetricType{metrics.MetricTypeRequestCount, metrics.MetricTypeStatusCode, metrics.MetricTypeTotalCount}
options := []op.Option{
op.WithHttpInterceptors(
@@ -127,6 +127,7 @@ func createOptions(config Config, externalSecure bool, userAgentCookie, instance
instanceHandler,
userAgentCookie,
http_utils.CopyHeadersToContext,
accessHandler,
),
}
if !externalSecure {

View File

@@ -40,7 +40,8 @@ func NewProvider(
es *eventstore.Eventstore,
projections *sql.DB,
instanceHandler,
userAgentCookie func(http.Handler) http.Handler,
userAgentCookie,
accessHandler func(http.Handler) http.Handler,
) (*provider.Provider, error) {
metricTypes := []metrics.MetricType{metrics.MetricTypeRequestCount, metrics.MetricTypeStatusCode, metrics.MetricTypeTotalCount}
@@ -64,6 +65,7 @@ func NewProvider(
middleware.NoCacheInterceptor().Handler,
instanceHandler,
userAgentCookie,
accessHandler,
http_utils.CopyHeadersToContext,
),
}

View File

@@ -88,7 +88,7 @@ func (f *file) Stat() (_ fs.FileInfo, err error) {
return f, nil
}
func Start(config Config, externalSecure bool, issuer op.IssuerFromRequest, instanceHandler func(http.Handler) http.Handler, customerPortal string) (http.Handler, error) {
func Start(config Config, externalSecure bool, issuer op.IssuerFromRequest, instanceHandler, accessInterceptor func(http.Handler) http.Handler, customerPortal string) (http.Handler, error) {
fSys, err := fs.Sub(static, "static")
if err != nil {
return nil, err
@@ -103,7 +103,7 @@ func Start(config Config, externalSecure bool, issuer op.IssuerFromRequest, inst
handler := mux.NewRouter()
handler.Use(instanceHandler, security)
handler.Use(instanceHandler, security, accessInterceptor)
handler.Handle(envRequestPath, middleware.TelemetryHandler()(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
url := http_util.BuildOrigin(r.Host, externalSecure)
environmentJSON, err := createEnvironmentJSON(url, issuer(r), authz.GetInstance(r.Context()).ConsoleClientID(), customerPortal)

View File

@@ -106,7 +106,7 @@ func (l *Login) runPostExternalAuthenticationActions(
apiFields,
a.Script,
a.Name,
append(actions.ActionToOptions(a), actions.WithHTTP(actionCtx), actions.WithLogger(actions.ServerLog))...,
append(actions.ActionToOptions(a), actions.WithHTTP(actionCtx))...,
)
cancel()
if err != nil {
@@ -175,7 +175,7 @@ func (l *Login) runPostInternalAuthenticationActions(
apiFields,
a.Script,
a.Name,
append(actions.ActionToOptions(a), actions.WithHTTP(actionCtx), actions.WithLogger(actions.ServerLog))...,
append(actions.ActionToOptions(a), actions.WithHTTP(actionCtx))...,
)
cancel()
if err != nil {
@@ -274,7 +274,7 @@ func (l *Login) runPreCreationActions(
apiFields,
a.Script,
a.Name,
append(actions.ActionToOptions(a), actions.WithHTTP(actionCtx), actions.WithLogger(actions.ServerLog))...,
append(actions.ActionToOptions(a), actions.WithHTTP(actionCtx))...,
)
cancel()
if err != nil {
@@ -332,7 +332,7 @@ func (l *Login) runPostCreationActions(
apiFields,
a.Script,
a.Name,
append(actions.ActionToOptions(a), actions.WithHTTP(actionCtx), actions.WithLogger(actions.ServerLog))...,
append(actions.ActionToOptions(a), actions.WithHTTP(actionCtx))...,
)
cancel()
if err != nil {

View File

@@ -69,6 +69,7 @@ func CreateLogin(config Config,
oidcInstanceHandler,
samlInstanceHandler mux.MiddlewareFunc,
assetCache mux.MiddlewareFunc,
accessHandler mux.MiddlewareFunc,
userCodeAlg crypto.EncryptionAlgorithm,
idpConfigAlg crypto.EncryptionAlgorithm,
csrfCookieKey []byte,
@@ -94,7 +95,7 @@ func CreateLogin(config Config,
cacheInterceptor := createCacheInterceptor(config.Cache.MaxAge, config.Cache.SharedMaxAge, assetCache)
security := middleware.SecurityHeaders(csp(), login.cspErrorHandler)
login.router = CreateRouter(login, statikFS, middleware.TelemetryHandler(IgnoreInstanceEndpoints...), oidcInstanceHandler, samlInstanceHandler, csrfInterceptor, cacheInterceptor, security, userAgentCookie, issuerInterceptor)
login.router = CreateRouter(login, statikFS, middleware.TelemetryHandler(IgnoreInstanceEndpoints...), oidcInstanceHandler, samlInstanceHandler, csrfInterceptor, cacheInterceptor, security, userAgentCookie, issuerInterceptor, accessHandler)
login.renderer = CreateRenderer(HandlerPrefix, statikFS, staticStorage, config.LanguageCookieName)
login.parser = form.NewParser()
return login, nil