Mirror of https://github.com/zitadel/zitadel.git (synced 2025-02-28 22:37:24 +00:00)
perf: project quotas and usages (#6441)
* project quota added
* project quota removed
* add periods table
* make log record generic
* accumulate usage
* query usage
* count action run seconds
* fix filter in ReportQuotaUsage
* fix existing tests
* fix logstore tests
* fix typo
* fix: add quota unit tests command side
* fix: add quota unit tests command side
* fix: add quota unit tests command side
* move notifications into debouncer and improve limit querying
* cleanup
* comment
* fix: add quota unit tests command side
* fix remaining quota usage query
* implement InmemLogStorage
* cleanup and linting
* improve test
* fix: add quota unit tests command side
* fix: add quota unit tests command side
* fix: add quota unit tests command side
* fix: add quota unit tests command side
* action notifications and fixes for notifications query
* revert console prefix
* fix: add quota unit tests command side
* fix: add quota integration tests
* improve accountable requests
* improve accountable requests
* fix: add quota integration tests
* fix: add quota integration tests
* fix: add quota integration tests
* comment
* remove ability to store logs in db and other changes requested from review
* changes requested from review
* changes requested from review
* Update internal/api/http/middleware/access_interceptor.go
  Co-authored-by: Silvan <silvan.reusser@gmail.com>
* tests: fix quotas integration tests
* improve incrementUsageStatement
* linting
* fix: delete e2e tests as integration tests cover functionality
* Update internal/api/http/middleware/access_interceptor.go
  Co-authored-by: Silvan <silvan.reusser@gmail.com>
* backup
* fix conflict
* create rc
* create prerelease
* remove issue release labeling
* fix tracing

---------

Co-authored-by: Livio Spring <livio.a@gmail.com>
Co-authored-by: Stefan Benz <stefan@caos.ch>
Co-authored-by: adlerhurst <silvan.reusser@gmail.com>
(cherry picked from commit 1a49b7d298690ce64846727f1fcf5a325f77c76e)
This commit is contained in: parent b688d6f842, commit 5823fdbef9
@@ -452,54 +452,29 @@ Actions:

LogStore:
  Access:
    Database:
      # If enabled, all access logs are stored in the database table logstore.access
      Enabled: false # ZITADEL_LOGSTORE_ACCESS_DATABASE_ENABLED
      # Logs that are older than the keep duration are cleaned up continuously
      # 2160h is 90 days (3 months)
      Keep: 2160h # ZITADEL_LOGSTORE_ACCESS_DATABASE_KEEP
      # CleanupInterval defines the time between cleanup iterations
      CleanupInterval: 4h # ZITADEL_LOGSTORE_ACCESS_DATABASE_CLEANUPINTERVAL
      # Debouncing emits log entries asynchronously, so normal execution performance is not impaired
      # Log entries are held in memory until either the MinFrequency or the MaxBulkSize condition is met
      Debounce:
        MinFrequency: 2m # ZITADEL_LOGSTORE_ACCESS_DATABASE_DEBOUNCE_MINFREQUENCY
        MaxBulkSize: 100 # ZITADEL_LOGSTORE_ACCESS_DATABASE_DEBOUNCE_MAXBULKSIZE
    Stdout:
      # If enabled, all access logs are printed to the binary's standard output
      Enabled: false # ZITADEL_LOGSTORE_ACCESS_STDOUT_ENABLED
      # Debouncing emits log entries asynchronously, so normal execution performance is not impaired
      # Log entries are held in memory until either the MinFrequency or the MaxBulkSize condition is met
      Debounce:
        MinFrequency: 0s # ZITADEL_LOGSTORE_ACCESS_STDOUT_DEBOUNCE_MINFREQUENCY
        MaxBulkSize: 0 # ZITADEL_LOGSTORE_ACCESS_STDOUT_DEBOUNCE_MAXBULKSIZE
  Execution:
    Database:
      # If enabled, all action execution logs are stored in the database table logstore.execution
      Enabled: false # ZITADEL_LOGSTORE_EXECUTION_DATABASE_ENABLED
      # Logs that are older than the keep duration are cleaned up continuously
      # 2160h is 90 days (3 months)
      Keep: 2160h # ZITADEL_LOGSTORE_EXECUTION_DATABASE_KEEP
      # CleanupInterval defines the time between cleanup iterations
      CleanupInterval: 4h # ZITADEL_LOGSTORE_EXECUTION_DATABASE_CLEANUPINTERVAL
      # Debouncing emits log entries asynchronously, so normal execution performance is not impaired
      # Log entries are held in memory until either the MinFrequency or the MaxBulkSize condition is met
      Debounce:
        MinFrequency: 0s # ZITADEL_LOGSTORE_EXECUTION_DATABASE_DEBOUNCE_MINFREQUENCY
        MaxBulkSize: 0 # ZITADEL_LOGSTORE_EXECUTION_DATABASE_DEBOUNCE_MAXBULKSIZE
    Stdout:
      # If enabled, all execution logs are printed to the binary's standard output
      Enabled: true # ZITADEL_LOGSTORE_EXECUTION_STDOUT_ENABLED
      # Debouncing emits log entries asynchronously, so normal execution performance is not impaired
      # Log entries are held in memory until either the MinFrequency or the MaxBulkSize condition is met
      Debounce:
        MinFrequency: 0s # ZITADEL_LOGSTORE_EXECUTION_STDOUT_DEBOUNCE_MINFREQUENCY
        MaxBulkSize: 0 # ZITADEL_LOGSTORE_EXECUTION_STDOUT_DEBOUNCE_MAXBULKSIZE

Quotas:
  Access:
    # If enabled, authenticated requests are counted and potentially limited depending on the configured quota of the instance
    Enabled: false # ZITADEL_QUOTAS_ACCESS_ENABLED
    Debounce:
      MinFrequency: 0s # ZITADEL_QUOTAS_ACCESS_DEBOUNCE_MINFREQUENCY
      MaxBulkSize: 0 # ZITADEL_QUOTAS_ACCESS_DEBOUNCE_MAXBULKSIZE
    ExhaustedCookieKey: "zitadel.quota.exhausted" # ZITADEL_QUOTAS_ACCESS_EXHAUSTEDCOOKIEKEY
    ExhaustedCookieMaxAge: "300s" # ZITADEL_QUOTAS_ACCESS_EXHAUSTEDCOOKIEMAXAGE
  Execution:
    # If enabled, all action executions are counted and potentially limited depending on the configured quota of the instance
    Enabled: false # ZITADEL_QUOTAS_EXECUTION_DATABASE_ENABLED
    Debounce:
      MinFrequency: 0s # ZITADEL_QUOTAS_EXECUTION_DEBOUNCE_MINFREQUENCY
      MaxBulkSize: 0 # ZITADEL_QUOTAS_EXECUTION_DEBOUNCE_MAXBULKSIZE

Eventstore:
  PushTimeout: 15s # ZITADEL_EVENTSTORE_PUSHTIMEOUT
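In plain terms, the Debounce settings above describe a buffer with two flush triggers: a size threshold (MaxBulkSize) and a timer (MinFrequency); a value of 0s / 0 effectively disables the buffering. A minimal, self-contained Go sketch of that idea follows; the names (debouncer, emit, flushNow) are illustrative and are not ZITADEL's actual logstore types.

package main

import (
    "fmt"
    "sync"
    "time"
)

// debouncer buffers records in memory and flushes them either when
// maxBulkSize records have accumulated or when minFrequency elapses,
// whichever comes first.
type debouncer struct {
    mu           sync.Mutex
    buf          []string
    maxBulkSize  int
    minFrequency time.Duration
    flush        func([]string)
}

func newDebouncer(minFrequency time.Duration, maxBulkSize int, flush func([]string)) *debouncer {
    d := &debouncer{minFrequency: minFrequency, maxBulkSize: maxBulkSize, flush: flush}
    if minFrequency > 0 {
        go func() {
            for range time.Tick(minFrequency) { // periodic flush trigger
                d.flushNow()
            }
        }()
    }
    return d
}

func (d *debouncer) emit(record string) {
    d.mu.Lock()
    d.buf = append(d.buf, record)
    full := d.maxBulkSize > 0 && len(d.buf) >= d.maxBulkSize
    d.mu.Unlock()
    if full || d.minFrequency == 0 {
        d.flushNow() // size trigger, or debouncing disabled (0s / 0)
    }
}

func (d *debouncer) flushNow() {
    d.mu.Lock()
    records := d.buf
    d.buf = nil
    d.mu.Unlock()
    if len(records) > 0 {
        d.flush(records)
    }
}

func main() {
    d := newDebouncer(2*time.Minute, 100, func(records []string) {
        fmt.Printf("bulk-inserting %d records\n", len(records))
    })
    d.emit("GET /users/me -> 200")
    d.flushNow() // force the flush so the example prints immediately
}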
@@ -72,7 +72,11 @@ type Config struct {
}

type QuotasConfig struct {
	Access *middleware.AccessConfig
	Access struct {
		logstore.EmitterConfig  `mapstructure:",squash"`
		middleware.AccessConfig `mapstructure:",squash"`
	}
	Execution *logstore.EmitterConfig
}

func MustNewConfig(v *viper.Viper) *Config {
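The mapstructure:",squash" tags flatten both embedded structs into the same configuration level, which is why Enabled, Debounce, ExhaustedCookieKey, and ExhaustedCookieMaxAge all sit directly under Quotas.Access in the YAML above. A small stand-alone illustration of that decoding behavior, using simplified stand-in structs rather than ZITADEL's real config types:

package main

import (
    "fmt"

    "github.com/mitchellh/mapstructure"
)

type EmitterConfig struct {
    Enabled bool
}

type AccessConfig struct {
    ExhaustedCookieKey string
}

// Access mirrors the pattern in the diff: both embedded structs are
// decoded from the same map level thanks to the ",squash" tag.
type Access struct {
    EmitterConfig `mapstructure:",squash"`
    AccessConfig  `mapstructure:",squash"`
}

func main() {
    var a Access
    _ = mapstructure.Decode(map[string]interface{}{
        "Enabled":            true,
        "ExhaustedCookieKey": "zitadel.quota.exhausted",
    }, &a)
    fmt.Println(a.Enabled, a.ExhaustedCookieKey) // true zitadel.quota.exhausted
}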
@@ -64,6 +64,8 @@ import (
	"github.com/zitadel/zitadel/internal/logstore/emitters/access"
	"github.com/zitadel/zitadel/internal/logstore/emitters/execution"
	"github.com/zitadel/zitadel/internal/logstore/emitters/stdout"
	"github.com/zitadel/zitadel/internal/logstore/record"
	"github.com/zitadel/zitadel/internal/net"
	"github.com/zitadel/zitadel/internal/notification"
	"github.com/zitadel/zitadel/internal/query"
	"github.com/zitadel/zitadel/internal/static"
@@ -108,7 +110,6 @@ type Server struct {
	AuthzRepo authz_repo.Repository
	Storage   static.Storage
	Commands  *command.Commands
	LogStore  *logstore.Service
	Router    *mux.Router
	TLSConfig *tls.Config
	Shutdown  chan<- os.Signal
@@ -209,17 +210,16 @@ func startZitadel(config *Config, masterKey string, server chan<- *Server) error
	}

	clock := clockpkg.New()
	actionsExecutionStdoutEmitter, err := logstore.NewEmitter(ctx, clock, config.LogStore.Execution.Stdout, stdout.NewStdoutEmitter())
	actionsExecutionStdoutEmitter, err := logstore.NewEmitter[*record.ExecutionLog](ctx, clock, &logstore.EmitterConfig{Enabled: config.LogStore.Execution.Stdout.Enabled}, stdout.NewStdoutEmitter[*record.ExecutionLog]())
	if err != nil {
		return err
	}
	actionsExecutionDBEmitter, err := logstore.NewEmitter(ctx, clock, config.LogStore.Execution.Database, execution.NewDatabaseLogStorage(dbClient))
	actionsExecutionDBEmitter, err := logstore.NewEmitter[*record.ExecutionLog](ctx, clock, config.Quotas.Execution, execution.NewDatabaseLogStorage(dbClient, commands, queries))
	if err != nil {
		return err
	}

	usageReporter := logstore.UsageReporterFunc(commands.ReportQuotaUsage)
	actionsLogstoreSvc := logstore.New(queries, usageReporter, actionsExecutionDBEmitter, actionsExecutionStdoutEmitter)
	actionsLogstoreSvc := logstore.New(queries, actionsExecutionDBEmitter, actionsExecutionStdoutEmitter)
	actions.SetLogstoreService(actionsLogstoreSvc)

	notification.Start(
@@ -259,8 +259,6 @@ func startZitadel(config *Config, masterKey string, server chan<- *Server) error
		storage,
		authZRepo,
		keys,
		queries,
		usageReporter,
		permissionCheck,
	)
	if err != nil {
@@ -281,7 +279,6 @@ func startZitadel(config *Config, masterKey string, server chan<- *Server) error
		AuthzRepo: authZRepo,
		Storage:   storage,
		Commands:  commands,
		LogStore:  actionsLogstoreSvc,
		Router:    router,
		TLSConfig: tlsConfig,
		Shutdown:  shutdown,
@@ -304,8 +301,6 @@ func startAPIs(
	store static.Storage,
	authZRepo authz_repo.Repository,
	keys *encryptionKeys,
	quotaQuerier logstore.QuotaQuerier,
	usageReporter logstore.UsageReporter,
	permissionCheck domain.PermissionCheck,
) error {
	repo := struct {
@@ -321,22 +316,22 @@ func startAPIs(
		return err
	}

	accessStdoutEmitter, err := logstore.NewEmitter(ctx, clock, config.LogStore.Access.Stdout, stdout.NewStdoutEmitter())
	accessStdoutEmitter, err := logstore.NewEmitter[*record.AccessLog](ctx, clock, &logstore.EmitterConfig{Enabled: config.LogStore.Access.Stdout.Enabled}, stdout.NewStdoutEmitter[*record.AccessLog]())
	if err != nil {
		return err
	}
	accessDBEmitter, err := logstore.NewEmitter(ctx, clock, config.LogStore.Access.Database, access.NewDatabaseLogStorage(dbClient))
	accessDBEmitter, err := logstore.NewEmitter[*record.AccessLog](ctx, clock, &config.Quotas.Access.EmitterConfig, access.NewDatabaseLogStorage(dbClient, commands, queries))
	if err != nil {
		return err
	}

	accessSvc := logstore.New(quotaQuerier, usageReporter, accessDBEmitter, accessStdoutEmitter)
	accessSvc := logstore.New[*record.AccessLog](queries, accessDBEmitter, accessStdoutEmitter)
	exhaustedCookieHandler := http_util.NewCookieHandler(
		http_util.WithUnsecure(),
		http_util.WithNonHttpOnly(),
		http_util.WithMaxAge(int(math.Floor(config.Quotas.Access.ExhaustedCookieMaxAge.Seconds()))),
	)
	limitingAccessInterceptor := middleware.NewAccessInterceptor(accessSvc, exhaustedCookieHandler, config.Quotas.Access)
	limitingAccessInterceptor := middleware.NewAccessInterceptor(accessSvc, exhaustedCookieHandler, &config.Quotas.Access.AccessConfig)
	apis, err := api.New(ctx, config.Port, router, queries, verifier, config.InternalAuthZ, tlsConfig, config.HTTP2HostHeader, config.HTTP1HostHeader, limitingAccessInterceptor)
	if err != nil {
		return fmt.Errorf("error creating api %w", err)
@@ -399,13 +394,13 @@ func startAPIs(
	}
	apis.RegisterHandlerOnPrefix(openapi.HandlerPrefix, openAPIHandler)

	oidcProvider, err := oidc.NewProvider(config.OIDC, login.DefaultLoggedOutPath, config.ExternalSecure, commands, queries, authRepo, keys.OIDC, keys.OIDCKey, eventstore, dbClient, userAgentInterceptor, instanceInterceptor.Handler, limitingAccessInterceptor.Handle)
	oidcProvider, err := oidc.NewProvider(config.OIDC, login.DefaultLoggedOutPath, config.ExternalSecure, commands, queries, authRepo, keys.OIDC, keys.OIDCKey, eventstore, dbClient, userAgentInterceptor, instanceInterceptor.Handler, limitingAccessInterceptor)
	if err != nil {
		return fmt.Errorf("unable to start oidc provider: %w", err)
	}
	apis.RegisterHandlerPrefixes(oidcProvider.HttpHandler(), "/.well-known/openid-configuration", "/oidc/v1", "/oauth/v2")

	samlProvider, err := saml.NewProvider(config.SAML, config.ExternalSecure, commands, queries, authRepo, keys.OIDC, keys.SAML, eventstore, dbClient, instanceInterceptor.Handler, userAgentInterceptor, limitingAccessInterceptor.Handle)
	samlProvider, err := saml.NewProvider(config.SAML, config.ExternalSecure, commands, queries, authRepo, keys.OIDC, keys.SAML, eventstore, dbClient, instanceInterceptor.Handler, userAgentInterceptor, limitingAccessInterceptor)
	if err != nil {
		return fmt.Errorf("unable to start saml provider: %w", err)
	}
@@ -417,7 +412,7 @@ func startAPIs(
	}
	apis.RegisterHandlerOnPrefix(console.HandlerPrefix, c)

	l, err := login.CreateLogin(config.Login, commands, queries, authRepo, store, console.HandlerPrefix+"/", op.AuthCallbackURL(oidcProvider), provider.AuthCallbackURL(samlProvider), config.ExternalSecure, userAgentInterceptor, op.NewIssuerInterceptor(oidcProvider.IssuerFromRequest).Handler, provider.NewIssuerInterceptor(samlProvider.IssuerFromRequest).Handler, instanceInterceptor.Handler, assetsCache.Handler, limitingAccessInterceptor.Handle, keys.User, keys.IDPConfig, keys.CSRFCookieKey)
	l, err := login.CreateLogin(config.Login, commands, queries, authRepo, store, console.HandlerPrefix+"/", op.AuthCallbackURL(oidcProvider), provider.AuthCallbackURL(samlProvider), config.ExternalSecure, userAgentInterceptor, op.NewIssuerInterceptor(oidcProvider.IssuerFromRequest).Handler, provider.NewIssuerInterceptor(samlProvider.IssuerFromRequest).Handler, instanceInterceptor.Handler, assetsCache.Handler, limitingAccessInterceptor.WithoutLimiting().Handle, keys.User, keys.IDPConfig, keys.CSRFCookieKey)
	if err != nil {
		return fmt.Errorf("unable to start login: %w", err)
	}
@@ -438,7 +433,7 @@ func listen(ctx context.Context, router *mux.Router, port uint16, tlsConfig *tls
	http2Server := &http2.Server{}
	http1Server := &http.Server{Handler: h2c.NewHandler(router, http2Server), TLSConfig: tlsConfig}

	lc := listenConfig()
	lc := net.ListenConfig()
	lis, err := lc.Listen(ctx, "tcp", fmt.Sprintf(":%d", port))
	if err != nil {
		return fmt.Errorf("tcp listener on %d failed: %w", port, err)
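The commit message's "make log record generic" is visible here: emitters and services are now parameterized by the record type (logstore.NewEmitter[*record.ExecutionLog], logstore.New[*record.AccessLog]), so one service implementation can fan out both access and execution logs. A minimal sketch of that pattern, with invented names that only approximate the real logstore package:

package main

import (
    "context"
    "fmt"
)

// LogRecord is the constraint a record type must satisfy in this sketch;
// it is modeled on the generic signatures in the diff, not copied from them.
type LogRecord interface {
    Normalize() LogRecord
}

type AccessLog struct{ RequestURL string }

func (a *AccessLog) Normalize() LogRecord { return a }

// Service fans a record out to all configured emitters.
type Service[T LogRecord] struct {
    emitters []func(context.Context, T)
}

func New[T LogRecord](emitters ...func(context.Context, T)) *Service[T] {
    return &Service[T]{emitters: emitters}
}

func (s *Service[T]) Handle(ctx context.Context, record T) {
    for _, emit := range s.emitters {
        emit(ctx, record)
    }
}

func main() {
    svc := New[*AccessLog](func(_ context.Context, r *AccessLog) {
        fmt.Println("emitting", r.RequestURL) // stand-in for a DB or stdout emitter
    })
    svc.Handle(context.Background(), &AccessLog{RequestURL: "/users/me"})
}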
@@ -28,12 +28,7 @@ LogStore:
    Database:
      Enabled: true
    Stdout:
      Enabled: false

Quotas:
  Access:
    ExhaustedCookieKey: "zitadel.quota.limiting"
    ExhaustedCookieMaxAge: "600s"
    Enabled: true

Console:
  InstanceManagementURL: "https://example.com/instances/{{.InstanceID}}"
@@ -43,9 +38,25 @@ Projections:
  NotificationsQuotas:
    RequeueEvery: 1s

Quotas:
  Access:
    ExhaustedCookieKey: "zitadel.quota.limiting"
    ExhaustedCookieMaxAge: "600s"

DefaultInstance:
  LoginPolicy:
    MfaInitSkipLifetime: "0"
  Quotas:
    Items:
      - Unit: "actions.all.runs.seconds"
        From: "2023-01-01T00:00:00Z"
        ResetInterval: 5m
        Amount: 20
        Limit: false
        Notifications:
          - Percent: 100
            Repeat: true
            CallURL: "https://httpbin.org/post"

SystemAPIUsers:
  - cypress:
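The Notifications entry above makes ZITADEL POST a JSON body to CallURL when usage crosses the configured percentage. Judging from the fields the integration test further down in this diff unmarshals in awaitNotification, the body looks roughly like the sketch below; treat the struct as an approximation, not a documented schema:

package main

import (
    "encoding/json"
    "fmt"
    "time"
)

// notificationDueEvent mirrors the fields the integration test decodes
// from the webhook body (unit, id, callURL, periodStart, threshold, usage).
type notificationDueEvent struct {
    Unit        int       `json:"unit"`
    ID          string    `json:"id"`
    CallURL     string    `json:"callURL"`
    PeriodStart time.Time `json:"periodStart"`
    Threshold   uint16    `json:"threshold"`
    Usage       uint64    `json:"usage"`
}

func main() {
    body := []byte(`{"unit":1,"id":"quota1","callURL":"https://httpbin.org/post","periodStart":"2023-01-01T00:00:00Z","threshold":100,"usage":20}`)
    var event notificationDueEvent
    if err := json.Unmarshal(body, &event); err != nil {
        panic(err)
    }
    fmt.Printf("quota %s reached %d%% (usage %d)\n", event.ID, event.Threshold, event.Usage)
}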
@@ -1,324 +0,0 @@
import { addQuota, ensureQuotaIsAdded, ensureQuotaIsRemoved, removeQuota, Unit } from 'support/api/quota';
import { createHumanUser, ensureUserDoesntExist } from 'support/api/users';
import { Context } from 'support/commands';
import { ZITADELWebhookEvent } from 'support/types';
import { textChangeRangeIsUnchanged } from 'typescript';

beforeEach(() => {
  cy.context().as('ctx');
});

describe('quotas', () => {
  describe('management', () => {
    describe('add one quota', () => {
      it('should add a quota only once per unit', () => {
        cy.get<Context>('@ctx').then((ctx) => {
          addQuota(ctx, Unit.AuthenticatedRequests, true, 1);
          addQuota(ctx, Unit.AuthenticatedRequests, true, 1, undefined, undefined, undefined, false).then((res) => {
            expect(res.status).to.equal(409);
          });
        });
      });

      describe('add two quotas', () => {
        it('should add a quota for each unit', () => {
          cy.get<Context>('@ctx').then((ctx) => {
            addQuota(ctx, Unit.AuthenticatedRequests, true, 1);
            addQuota(ctx, Unit.ExecutionSeconds, true, 1);
          });
        });
      });
    });

    describe('edit', () => {
      describe('remove one quota', () => {
        beforeEach(() => {
          cy.get<Context>('@ctx').then((ctx) => {
            ensureQuotaIsAdded(ctx, Unit.AuthenticatedRequests, true, 1);
          });
        });
        it('should remove a quota only once per unit', () => {
          cy.get<Context>('@ctx').then((ctx) => {
            removeQuota(ctx, Unit.AuthenticatedRequests);
          });
          cy.get<Context>('@ctx').then((ctx) => {
            removeQuota(ctx, Unit.AuthenticatedRequests, false).then((res) => {
              expect(res.status).to.equal(404);
            });
          });
        });

        describe('remove two quotas', () => {
          beforeEach(() => {
            cy.get<Context>('@ctx').then((ctx) => {
              ensureQuotaIsAdded(ctx, Unit.AuthenticatedRequests, true, 1);
              ensureQuotaIsAdded(ctx, Unit.ExecutionSeconds, true, 1);
            });
          });
          it('should remove a quota for each unit', () => {
            cy.get<Context>('@ctx').then((ctx) => {
              removeQuota(ctx, Unit.AuthenticatedRequests);
              removeQuota(ctx, Unit.ExecutionSeconds);
            });
          });
        });
      });
    });
  });

  describe('usage', () => {
    beforeEach(() => {
      cy.get<Context>('@ctx')
        .then((ctx) => {
          return [
            `${ctx.api.oidcBaseURL}/userinfo`,
            `${ctx.api.authBaseURL}/users/me`,
            `${ctx.api.mgmtBaseURL}/iam`,
            `${ctx.api.adminBaseURL}/instances/me`,
            `${ctx.api.oauthBaseURL}/keys`,
            `${ctx.api.samlBaseURL}/certificate`,
          ];
        })
        .as('authenticatedUrls');
    });

    describe('authenticated requests', () => {
      const testUserName = 'shouldNotBeCreated';
      beforeEach(() => {
        cy.get<Array<string>>('@authenticatedUrls').then((urls) => {
          cy.get<Context>('@ctx').then((ctx) => {
            ensureUserDoesntExist(ctx.api, testUserName);
            ensureQuotaIsAdded(ctx, Unit.AuthenticatedRequests, true, urls.length);
            cy.task('runSQL', `TRUNCATE logstore.access;`);
          });
        });
      });

      it('only authenticated requests are limited', () => {
        cy.get<Array<string>>('@authenticatedUrls').then((urls) => {
          cy.get<Context>('@ctx').then((ctx) => {
            const start = new Date();
            urls.forEach((url) => {
              cy.request({
                url: url,
                method: 'GET',
                auth: {
                  bearer: ctx.api.token,
                },
              });
            });
            expectCookieDoesntExist();
            const expiresMax = new Date();
            expiresMax.setMinutes(expiresMax.getMinutes() + 20);
            cy.request({
              url: urls[1],
              method: 'GET',
              auth: {
                bearer: ctx.api.token,
              },
              failOnStatusCode: false,
            }).then((res) => {
              expect(res.status).to.equal(429);
            });
            cy.getCookie('zitadel.quota.limiting').then((cookie) => {
              expect(cookie.value).to.equal('true');
              const cookieExpiry = new Date();
              cookieExpiry.setTime(cookie.expiry * 1000);
              expect(cookieExpiry).to.be.within(start, expiresMax);
            });
            createHumanUser(ctx.api, testUserName, false).then((res) => {
              expect(res.status).to.equal(429);
            });
            // visit limited console
            // cy.visit('/users/me');
            // cy.contains('#authenticated-requests-exhausted-dialog button', 'Continue').click();
            // const upgradeInstancePage = `https://example.com/instances/${ctx.instanceId}`;
            // cy.origin(upgradeInstancePage, { args: { upgradeInstancePage } }, ({ upgradeInstancePage }) => {
            //   cy.location('href').should('equal', upgradeInstancePage);
            // });
            // upgrade instance
            ensureQuotaIsRemoved(ctx, Unit.AuthenticatedRequests);
            // visit upgraded console again
            cy.visit('/users/me');
            cy.get('[data-e2e="top-view-title"]');
            expectCookieDoesntExist();
            createHumanUser(ctx.api, testUserName);
            expectCookieDoesntExist();
          });
        });
      });
    });

    describe.skip('notifications', () => {
      const callURL = `http://${Cypress.env('WEBHOOK_HANDLER_HOST')}:${Cypress.env('WEBHOOK_HANDLER_PORT')}/do_something`;

      beforeEach(() => cy.task('resetWebhookEvents'));

      const amount = 100;
      const percent = 10;
      const usage = 35;

      describe('without repetition', () => {
        beforeEach(() => {
          cy.get<Context>('@ctx').then((ctx) => {
            ensureQuotaIsAdded(ctx, Unit.AuthenticatedRequests, false, amount, [
              {
                callUrl: callURL,
                percent: percent,
                repeat: false,
              },
            ]);
            cy.task('runSQL', `TRUNCATE logstore.access;`);
          });
        });

        it('fires at least once with the expected payload', () => {
          cy.get<Array<string>>('@authenticatedUrls').then((urls) => {
            cy.get<Context>('@ctx').then((ctx) => {
              for (let i = 0; i < usage; i++) {
                cy.request({
                  url: urls[0],
                  method: 'GET',
                  auth: {
                    bearer: ctx.api.token,
                  },
                });
              }
            });
            cy.waitUntil(
              () =>
                cy.task<Array<ZITADELWebhookEvent>>('handledWebhookEvents').then((events) => {
                  if (events.length < 1) {
                    return false;
                  }
                  return Cypress._.matches(<ZITADELWebhookEvent>{
                    sentStatus: 200,
                    payload: {
                      callURL: callURL,
                      threshold: percent,
                      unit: 1,
                      usage: percent,
                    },
                  })(events[0]);
                }),
              { timeout: 60_000 },
            );
          });
        });

        it('fires until the webhook returns a successful message', () => {
          cy.task('failWebhookEvents', 8);
          cy.get<Array<string>>('@authenticatedUrls').then((urls) => {
            cy.get<Context>('@ctx').then((ctx) => {
              for (let i = 0; i < usage; i++) {
                cy.request({
                  url: urls[0],
                  method: 'GET',
                  auth: {
                    bearer: ctx.api.token,
                  },
                });
              }
            });
            cy.waitUntil(
              () =>
                cy.task<Array<ZITADELWebhookEvent>>('handledWebhookEvents').then((events) => {
                  if (events.length != 9) {
                    return false;
                  }
                  return events.reduce<boolean>((a, b, i) => {
                    return !a
                      ? a
                      : i < 8
                      ? Cypress._.matches(<ZITADELWebhookEvent>{
                          sentStatus: 500,
                          payload: {
                            callURL: callURL,
                            threshold: percent,
                            unit: 1,
                            usage: percent,
                          },
                        })(b)
                      : Cypress._.matches(<ZITADELWebhookEvent>{
                          sentStatus: 200,
                          payload: {
                            callURL: callURL,
                            threshold: percent,
                            unit: 1,
                            usage: percent,
                          },
                        })(b);
                  }, true);
                }),
              { timeout: 60_000 },
            );
          });
        });
      });

      describe('with repetition', () => {
        beforeEach(() => {
          cy.get<Array<string>>('@authenticatedUrls').then((urls) => {
            cy.get<Context>('@ctx').then((ctx) => {
              ensureQuotaIsAdded(ctx, Unit.AuthenticatedRequests, false, amount, [
                {
                  callUrl: callURL,
                  percent: percent,
                  repeat: true,
                },
              ]);
              cy.task('runSQL', `TRUNCATE logstore.access;`);
            });
          });
        });

        it('fires repeatedly with the expected payloads', () => {
          cy.get<Array<string>>('@authenticatedUrls').then((urls) => {
            cy.get<Context>('@ctx').then((ctx) => {
              for (let i = 0; i < usage; i++) {
                cy.request({
                  url: urls[0],
                  method: 'GET',
                  auth: {
                    bearer: ctx.api.token,
                  },
                });
              }
            });
          });
          cy.waitUntil(
            () =>
              cy.task<Array<ZITADELWebhookEvent>>('handledWebhookEvents').then((events) => {
                let foundExpected = 0;
                for (let i = 0; i < events.length; i++) {
                  for (let expect = 10; expect <= 30; expect += 10) {
                    if (
                      Cypress._.matches(<ZITADELWebhookEvent>{
                        sentStatus: 200,
                        payload: {
                          callURL: callURL,
                          threshold: expect,
                          unit: 1,
                          usage: expect,
                        },
                      })(events[i])
                    ) {
                      foundExpected++;
                    }
                  }
                }
                return foundExpected >= 3;
              }),
            { timeout: 60_000 },
          );
        });
      });
    });
  });
});

function expectCookieDoesntExist() {
  cy.getCookie('zitadel.quota.limiting').then((cookie) => {
    expect(cookie).to.be.null;
  });
}
@@ -7,11 +7,13 @@ import (
	"time"

	"github.com/dop251/goja"

	"github.com/zitadel/zitadel/internal/logstore"
	"github.com/zitadel/zitadel/internal/logstore/record"
)

func TestRun(t *testing.T) {
	SetLogstoreService(logstore.New(nil, nil, nil))
	SetLogstoreService(logstore.New[*record.ExecutionLog](nil, nil))
	type args struct {
		timeout time.Duration
		api     apiFields
@@ -5,11 +5,13 @@ import (
	"testing"

	"github.com/dop251/goja"

	"github.com/zitadel/zitadel/internal/logstore"
	"github.com/zitadel/zitadel/internal/logstore/record"
)

func TestSetFields(t *testing.T) {
	SetLogstoreService(logstore.New(nil, nil, nil))
	SetLogstoreService(logstore.New[*record.ExecutionLog](nil, nil))
	primitveFn := func(a string) { fmt.Println(a) }
	complexFn := func(*FieldConfig) interface{} {
		return primitveFn
@@ -10,12 +10,14 @@ import (
	"testing"

	"github.com/dop251/goja"

	"github.com/zitadel/zitadel/internal/errors"
	"github.com/zitadel/zitadel/internal/logstore"
	"github.com/zitadel/zitadel/internal/logstore/record"
)

func Test_isHostBlocked(t *testing.T) {
	SetLogstoreService(logstore.New(nil, nil, nil))
	SetLogstoreService(logstore.New[*record.ExecutionLog](nil, nil))
	var denyList = []AddressChecker{
		mustNewIPChecker(t, "192.168.5.0/24"),
		mustNewIPChecker(t, "127.0.0.1"),
@@ -10,15 +10,15 @@ import (

	"github.com/zitadel/zitadel/internal/api/authz"
	"github.com/zitadel/zitadel/internal/logstore"
	"github.com/zitadel/zitadel/internal/logstore/emitters/execution"
	"github.com/zitadel/zitadel/internal/logstore/record"
)

var (
	logstoreService *logstore.Service
	logstoreService *logstore.Service[*record.ExecutionLog]
	_ console.Printer = (*logger)(nil)
)

func SetLogstoreService(svc *logstore.Service) {
func SetLogstoreService(svc *logstore.Service[*record.ExecutionLog]) {
	logstoreService = svc
}

@@ -55,19 +55,16 @@ func (l *logger) log(msg string, level logrus.Level, last bool) {
	if l.started.IsZero() {
		l.started = ts
	}

	record := &execution.Record{
	r := &record.ExecutionLog{
		LogDate:    ts,
		InstanceID: l.instanceID,
		Message:    msg,
		LogLevel:   level,
	}

	if last {
		record.Took = ts.Sub(l.started)
		r.Took = ts.Sub(l.started)
	}

	logstoreService.Handle(l.ctx, record)
	logstoreService.Handle(l.ctx, r)
}

func withLogger(ctx context.Context) Option {
@@ -10,11 +10,11 @@ import (

	"github.com/zitadel/zitadel/internal/api/authz"
	"github.com/zitadel/zitadel/internal/logstore"
	"github.com/zitadel/zitadel/internal/logstore/emitters/access"
	"github.com/zitadel/zitadel/internal/logstore/record"
	"github.com/zitadel/zitadel/internal/telemetry/tracing"
)

func AccessStorageInterceptor(svc *logstore.Service) grpc.UnaryServerInterceptor {
func AccessStorageInterceptor(svc *logstore.Service[*record.AccessLog]) grpc.UnaryServerInterceptor {
	return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (_ interface{}, err error) {
		if !svc.Enabled() {
			return handler(ctx, req)
@@ -36,9 +36,9 @@ func AccessStorageInterceptor(svc *logstore.Service) grpc.UnaryServerInterceptor
		resMd, _ := metadata.FromOutgoingContext(ctx)
		instance := authz.GetInstance(ctx)

		record := &access.Record{
		r := &record.AccessLog{
			LogDate:        time.Now(),
			Protocol:       access.GRPC,
			Protocol:       record.GRPC,
			RequestURL:     info.FullMethod,
			ResponseStatus: respStatus,
			RequestHeaders: reqMd,
@@ -49,7 +49,7 @@ func AccessStorageInterceptor(svc *logstore.Service) grpc.UnaryServerInterceptor
			RequestedHost: instance.RequestedHost(),
		}

		svc.Handle(interceptorCtx, record)
		svc.Handle(interceptorCtx, r)
		return resp, handlerErr
	}
}
@@ -9,19 +9,16 @@ import (
	"github.com/zitadel/zitadel/internal/api/authz"
	"github.com/zitadel/zitadel/internal/errors"
	"github.com/zitadel/zitadel/internal/logstore"
	"github.com/zitadel/zitadel/internal/logstore/record"
	"github.com/zitadel/zitadel/internal/telemetry/tracing"
)

func QuotaExhaustedInterceptor(svc *logstore.Service, ignoreService ...string) grpc.UnaryServerInterceptor {

	prunedIgnoredServices := make([]string, len(ignoreService))
func QuotaExhaustedInterceptor(svc *logstore.Service[*record.AccessLog], ignoreService ...string) grpc.UnaryServerInterceptor {
	for idx, service := range ignoreService {
		if !strings.HasPrefix(service, "/") {
			service = "/" + service
			ignoreService[idx] = "/" + service
		}
		prunedIgnoredServices[idx] = service
	}

	return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (_ interface{}, err error) {
		if !svc.Enabled() {
			return handler(ctx, req)
@@ -29,7 +26,13 @@ func QuotaExhaustedInterceptor(svc *logstore.Service, ignoreService ...string) g
		interceptorCtx, span := tracing.NewServerInterceptorSpan(ctx)
		defer func() { span.EndWithError(err) }()

		for _, service := range prunedIgnoredServices {
		// The auth interceptor will ensure that only authorized or public requests are allowed.
		// So if there's no authorization context, we don't need to check for limitation
		if authz.GetCtxData(ctx).IsZero() {
			return handler(ctx, req)
		}

		for _, service := range ignoreService {
			if strings.HasPrefix(info.FullMethod, service) {
				return handler(ctx, req)
			}
@@ -12,6 +12,7 @@ import (
	grpc_api "github.com/zitadel/zitadel/internal/api/grpc"
	"github.com/zitadel/zitadel/internal/api/grpc/server/middleware"
	"github.com/zitadel/zitadel/internal/logstore"
	"github.com/zitadel/zitadel/internal/logstore/record"
	"github.com/zitadel/zitadel/internal/query"
	"github.com/zitadel/zitadel/internal/telemetry/metrics"
	system_pb "github.com/zitadel/zitadel/pkg/grpc/system"
@@ -39,7 +40,7 @@ func CreateServer(
	queries *query.Queries,
	hostHeaderName string,
	tlsConfig *tls.Config,
	accessSvc *logstore.Service,
	accessSvc *logstore.Service[*record.AccessLog],
) *grpc.Server {
	metricTypes := []metrics.MetricType{metrics.MetricTypeTotalCount, metrics.MetricTypeRequestCount, metrics.MetricTypeStatusCode}
	serverOptions := []grpc.ServerOption{
@@ -49,14 +50,14 @@ func CreateServer(
			middleware.DefaultTracingServer(),
			middleware.MetricsHandler(metricTypes, grpc_api.Probes...),
			middleware.NoCacheInterceptor(),
			middleware.ErrorHandler(),
			middleware.InstanceInterceptor(queries, hostHeaderName, system_pb.SystemService_ServiceDesc.ServiceName, healthpb.Health_ServiceDesc.ServiceName),
			middleware.AccessStorageInterceptor(accessSvc),
			middleware.ErrorHandler(),
			middleware.AuthorizationInterceptor(verifier, authConfig),
			middleware.QuotaExhaustedInterceptor(accessSvc, system_pb.SystemService_ServiceDesc.ServiceName),
			middleware.TranslationHandler(),
			middleware.ValidationHandler(),
			middleware.ServiceHandler(),
			middleware.QuotaExhaustedInterceptor(accessSvc, system_pb.SystemService_ServiceDesc.ServiceName),
		),
	),
}
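Note the reordering in this hunk: QuotaExhaustedInterceptor now runs after AuthorizationInterceptor, so only authenticated requests are checked against the quota, and ErrorHandler moved ahead of AccessStorageInterceptor. Chained unary interceptors execute in registration order, which the following stand-alone sketch demonstrates (the interceptor names are placeholders, not ZITADEL's middleware):

package main

import (
    "context"
    "fmt"

    "google.golang.org/grpc"
)

// named returns a unary interceptor that logs its position in the chain.
func named(name string) grpc.UnaryServerInterceptor {
    return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
        fmt.Println("enter", name)
        return handler(ctx, req)
    }
}

func main() {
    // Interceptors run in the order they are chained, so placing the quota
    // check after authorization means unauthorized calls never reach it.
    _ = grpc.NewServer(grpc.ChainUnaryInterceptor(
        named("authorization"),
        named("quotaExhausted"),
    ))
}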
internal/api/grpc/system/quota_integration_test.go (new file, 203 lines)
@@ -0,0 +1,203 @@
//go:build integration

package system_test

import (
	"bytes"
	"encoding/json"
	"strconv"
	"testing"
	"time"

	"github.com/stretchr/testify/require"
	"google.golang.org/protobuf/types/known/durationpb"
	"google.golang.org/protobuf/types/known/timestamppb"

	"github.com/zitadel/zitadel/internal/integration"
	"github.com/zitadel/zitadel/internal/repository/quota"
	"github.com/zitadel/zitadel/pkg/grpc/admin"
	quota_pb "github.com/zitadel/zitadel/pkg/grpc/quota"
	"github.com/zitadel/zitadel/pkg/grpc/system"
)

var callURL = "http://localhost:" + integration.PortQuotaServer

func TestServer_QuotaNotification_Limit(t *testing.T) {
	_, instanceID, iamOwnerCtx := Tester.UseIsolatedInstance(CTX, SystemCTX)
	amount := 10
	percent := 50
	percentAmount := amount * percent / 100

	_, err := Tester.Client.System.AddQuota(SystemCTX, &system.AddQuotaRequest{
		InstanceId:    instanceID,
		Unit:          quota_pb.Unit_UNIT_REQUESTS_ALL_AUTHENTICATED,
		From:          timestamppb.Now(),
		ResetInterval: durationpb.New(time.Minute * 5),
		Amount:        uint64(amount),
		Limit:         true,
		Notifications: []*quota_pb.Notification{
			{
				Percent: uint32(percent),
				Repeat:  true,
				CallUrl: callURL,
			},
			{
				Percent: 100,
				Repeat:  true,
				CallUrl: callURL,
			},
		},
	})
	require.NoError(t, err)

	for i := 0; i < percentAmount; i++ {
		_, err := Tester.Client.Admin.GetDefaultOrg(iamOwnerCtx, &admin.GetDefaultOrgRequest{})
		require.NoErrorf(t, err, "error in %d call of %d", i, percentAmount)
	}
	awaitNotification(t, Tester.QuotaNotificationChan, quota.RequestsAllAuthenticated, percent)

	for i := 0; i < (amount - percentAmount); i++ {
		_, err := Tester.Client.Admin.GetDefaultOrg(iamOwnerCtx, &admin.GetDefaultOrgRequest{})
		require.NoErrorf(t, err, "error in %d call of %d", i, percentAmount)
	}
	awaitNotification(t, Tester.QuotaNotificationChan, quota.RequestsAllAuthenticated, 100)

	_, limitErr := Tester.Client.Admin.GetDefaultOrg(iamOwnerCtx, &admin.GetDefaultOrgRequest{})
	require.Error(t, limitErr)
}

func TestServer_QuotaNotification_NoLimit(t *testing.T) {
	_, instanceID, iamOwnerCtx := Tester.UseIsolatedInstance(CTX, SystemCTX)
	amount := 10
	percent := 50
	percentAmount := amount * percent / 100

	_, err := Tester.Client.System.AddQuota(SystemCTX, &system.AddQuotaRequest{
		InstanceId:    instanceID,
		Unit:          quota_pb.Unit_UNIT_REQUESTS_ALL_AUTHENTICATED,
		From:          timestamppb.Now(),
		ResetInterval: durationpb.New(time.Minute * 5),
		Amount:        uint64(amount),
		Limit:         false,
		Notifications: []*quota_pb.Notification{
			{
				Percent: uint32(percent),
				Repeat:  false,
				CallUrl: callURL,
			},
			{
				Percent: 100,
				Repeat:  true,
				CallUrl: callURL,
			},
		},
	})
	require.NoError(t, err)

	for i := 0; i < percentAmount; i++ {
		_, err := Tester.Client.Admin.GetDefaultOrg(iamOwnerCtx, &admin.GetDefaultOrgRequest{})
		require.NoErrorf(t, err, "error in %d call of %d", i, percentAmount)
	}
	awaitNotification(t, Tester.QuotaNotificationChan, quota.RequestsAllAuthenticated, percent)

	for i := 0; i < (amount - percentAmount); i++ {
		_, err := Tester.Client.Admin.GetDefaultOrg(iamOwnerCtx, &admin.GetDefaultOrgRequest{})
		require.NoErrorf(t, err, "error in %d call of %d", i, percentAmount)
	}
	awaitNotification(t, Tester.QuotaNotificationChan, quota.RequestsAllAuthenticated, 100)

	for i := 0; i < amount; i++ {
		_, err := Tester.Client.Admin.GetDefaultOrg(iamOwnerCtx, &admin.GetDefaultOrgRequest{})
		require.NoErrorf(t, err, "error in %d call of %d", i, percentAmount)
	}
	awaitNotification(t, Tester.QuotaNotificationChan, quota.RequestsAllAuthenticated, 200)

	_, limitErr := Tester.Client.Admin.GetDefaultOrg(iamOwnerCtx, &admin.GetDefaultOrgRequest{})
	require.NoError(t, limitErr)
}

func awaitNotification(t *testing.T, bodies chan []byte, unit quota.Unit, percent int) {
	for {
		select {
		case body := <-bodies:
			plain := new(bytes.Buffer)
			if err := json.Indent(plain, body, "", " "); err != nil {
				t.Fatal(err)
			}
			t.Log("received notificationDueEvent", plain.String())
			event := struct {
				Unit        quota.Unit `json:"unit"`
				ID          string     `json:"id"`
				CallURL     string     `json:"callURL"`
				PeriodStart time.Time  `json:"periodStart"`
				Threshold   uint16     `json:"threshold"`
				Usage       uint64     `json:"usage"`
			}{}
			if err := json.Unmarshal(body, &event); err != nil {
				t.Error(err)
			}
			if event.ID == "" {
				continue
			}
			if event.Unit == unit && event.Threshold == uint16(percent) {
				return
			}
		case <-time.After(60 * time.Second):
			t.Fatalf("timed out waiting for unit %s and percent %d", strconv.Itoa(int(unit)), percent)
		}
	}
}

func TestServer_AddAndRemoveQuota(t *testing.T) {
	_, instanceID, _ := Tester.UseIsolatedInstance(CTX, SystemCTX)

	got, err := Tester.Client.System.AddQuota(SystemCTX, &system.AddQuotaRequest{
		InstanceId:    instanceID,
		Unit:          quota_pb.Unit_UNIT_REQUESTS_ALL_AUTHENTICATED,
		From:          timestamppb.Now(),
		ResetInterval: durationpb.New(time.Minute),
		Amount:        10,
		Limit:         true,
		Notifications: []*quota_pb.Notification{
			{
				Percent: 20,
				Repeat:  true,
				CallUrl: callURL,
			},
		},
	})
	require.NoError(t, err)
	require.Equal(t, got.Details.ResourceOwner, instanceID)

	gotAlreadyExisting, errAlreadyExisting := Tester.Client.System.AddQuota(SystemCTX, &system.AddQuotaRequest{
		InstanceId:    instanceID,
		Unit:          quota_pb.Unit_UNIT_REQUESTS_ALL_AUTHENTICATED,
		From:          timestamppb.Now(),
		ResetInterval: durationpb.New(time.Minute),
		Amount:        10,
		Limit:         true,
		Notifications: []*quota_pb.Notification{
			{
				Percent: 20,
				Repeat:  true,
				CallUrl: callURL,
			},
		},
	})
	require.Error(t, errAlreadyExisting)
	require.Nil(t, gotAlreadyExisting)

	gotRemove, errRemove := Tester.Client.System.RemoveQuota(SystemCTX, &system.RemoveQuotaRequest{
		InstanceId: instanceID,
		Unit:       quota_pb.Unit_UNIT_REQUESTS_ALL_AUTHENTICATED,
	})
	require.NoError(t, errRemove)
	require.Equal(t, gotRemove.Details.ResourceOwner, instanceID)

	gotRemoveAlready, errRemoveAlready := Tester.Client.System.RemoveQuota(SystemCTX, &system.RemoveQuotaRequest{
		InstanceId: instanceID,
		Unit:       quota_pb.Unit_UNIT_REQUESTS_ALL_AUTHENTICATED,
	})
	require.Error(t, errRemoveAlready)
	require.Nil(t, gotRemoveAlready)
}
internal/api/grpc/system/server_integration_test.go (new file, 32 lines)
@@ -0,0 +1,32 @@
//go:build integration

package system_test

import (
	"context"
	"os"
	"testing"
	"time"

	"github.com/zitadel/zitadel/internal/integration"
)

var (
	CTX       context.Context
	SystemCTX context.Context
	Tester    *integration.Tester
)

func TestMain(m *testing.M) {
	os.Exit(func() int {
		ctx, _, cancel := integration.Contexts(5 * time.Minute)
		defer cancel()
		CTX = ctx

		Tester = integration.NewTester(ctx)
		defer Tester.Done()

		SystemCTX = Tester.WithAuthorization(ctx, integration.SystemUser)
		return m.Run()
	}())
}
@@ -14,12 +14,12 @@ import (
	"github.com/zitadel/zitadel/internal/api/grpc/server/middleware"
	http_utils "github.com/zitadel/zitadel/internal/api/http"
	"github.com/zitadel/zitadel/internal/logstore"
	"github.com/zitadel/zitadel/internal/logstore/emitters/access"
	"github.com/zitadel/zitadel/internal/logstore/record"
	"github.com/zitadel/zitadel/internal/telemetry/tracing"
)

type AccessInterceptor struct {
	svc           *logstore.Service
	svc           *logstore.Service[*record.AccessLog]
	cookieHandler *http_utils.CookieHandler
	limitConfig   *AccessConfig
	storeOnly     bool
@@ -33,7 +33,7 @@ type AccessConfig struct {
// NewAccessInterceptor intercepts all requests and stores them to the logstore.
// If storeOnly is false, it also checks if requests are exhausted.
// If requests are exhausted, it also returns http.StatusTooManyRequests and sets a cookie
func NewAccessInterceptor(svc *logstore.Service, cookieHandler *http_utils.CookieHandler, cookieConfig *AccessConfig) *AccessInterceptor {
func NewAccessInterceptor(svc *logstore.Service[*record.AccessLog], cookieHandler *http_utils.CookieHandler, cookieConfig *AccessConfig) *AccessInterceptor {
	return &AccessInterceptor{
		svc:           svc,
		cookieHandler: cookieHandler,
@@ -50,7 +50,7 @@ func (a *AccessInterceptor) WithoutLimiting() *AccessInterceptor {
	}
}

func (a *AccessInterceptor) AccessService() *logstore.Service {
func (a *AccessInterceptor) AccessService() *logstore.Service[*record.AccessLog] {
	return a.svc
}

@@ -81,14 +81,32 @@ func (a *AccessInterceptor) DeleteExhaustedCookie(writer http.ResponseWriter) {
	a.cookieHandler.DeleteCookie(writer, a.limitConfig.ExhaustedCookieKey)
}

func (a *AccessInterceptor) HandleIgnorePathPrefixes(ignoredPathPrefixes []string) func(next http.Handler) http.Handler {
	return a.handle(ignoredPathPrefixes...)
}

func (a *AccessInterceptor) Handle(next http.Handler) http.Handler {
	return a.handle()(next)
}

func (a *AccessInterceptor) handle(ignoredPathPrefixes ...string) func(http.Handler) http.Handler {
	return func(next http.Handler) http.Handler {
		if !a.svc.Enabled() {
			return next
		}
		return http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) {
			ctx := request.Context()
			tracingCtx, checkSpan := tracing.NewNamedSpan(ctx, "checkAccess")
			tracingCtx, checkSpan := tracing.NewNamedSpan(ctx, "checkAccessQuota")
			wrappedWriter := &statusRecorder{ResponseWriter: writer, status: 0}
			for _, ignoredPathPrefix := range ignoredPathPrefixes {
				if !strings.HasPrefix(request.RequestURI, ignoredPathPrefix) {
					continue
				}
				checkSpan.End()
				next.ServeHTTP(wrappedWriter, request)
				a.writeLog(tracingCtx, wrappedWriter, writer, request, true)
				return
			}
			limited := a.Limit(tracingCtx)
			checkSpan.End()
			if limited {
@@ -101,17 +119,23 @@ func (a *AccessInterceptor) Handle(next http.Handler) http.Handler {
			if !limited {
				next.ServeHTTP(wrappedWriter, request)
			}
			tracingCtx, writeSpan := tracing.NewNamedSpan(tracingCtx, "writeAccess")
			a.writeLog(tracingCtx, wrappedWriter, writer, request, a.storeOnly)
		})
	}
}

func (a *AccessInterceptor) writeLog(ctx context.Context, wrappedWriter *statusRecorder, writer http.ResponseWriter, request *http.Request, notCountable bool) {
	ctx, writeSpan := tracing.NewNamedSpan(ctx, "writeAccess")
	defer writeSpan.End()
	requestURL := request.RequestURI
	unescapedURL, err := url.QueryUnescape(requestURL)
	if err != nil {
		logging.WithError(err).WithField("url", requestURL).Warning("failed to unescape request url")
	}
	instance := authz.GetInstance(tracingCtx)
	a.svc.Handle(tracingCtx, &access.Record{
	instance := authz.GetInstance(ctx)
	a.svc.Handle(ctx, &record.AccessLog{
		LogDate:        time.Now(),
		Protocol:       access.HTTP,
		Protocol:       record.HTTP,
		RequestURL:     unescapedURL,
		ResponseStatus: uint32(wrappedWriter.status),
		RequestHeaders: request.Header,
@@ -120,7 +144,7 @@ func (a *AccessInterceptor) Handle(next http.Handler) http.Handler {
		ProjectID:       instance.ProjectID(),
		RequestedDomain: instance.RequestedDomain(),
		RequestedHost:   instance.RequestedHost(),
	})
		NotCountable:    notCountable,
	})
}
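The doc comment on NewAccessInterceptor summarizes the behavior: store every request, and unless storeOnly is set, answer exhausted instances with http.StatusTooManyRequests plus a cookie the console can read. A compact sketch of that shape, with an assumed isExhausted hook and cookie name standing in for ZITADEL's internals:

package main

import (
    "fmt"
    "net/http"
)

// quotaMiddleware sketches the behavior described above: if the quota is
// exhausted, answer 429 and set a cookie; otherwise pass the request through.
func quotaMiddleware(isExhausted func(*http.Request) bool, next http.Handler) http.Handler {
    return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
        if isExhausted(r) {
            http.SetCookie(w, &http.Cookie{Name: "zitadel.quota.exhausted", Value: "true", MaxAge: 300})
            http.Error(w, "quota exhausted", http.StatusTooManyRequests)
            return
        }
        next.ServeHTTP(w, r)
    })
}

func main() {
    handler := quotaMiddleware(
        func(*http.Request) bool { return true }, // pretend the quota is used up
        http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { fmt.Fprintln(w, "ok") }),
    )
    http.Handle("/", handler)
    // http.ListenAndServe(":8080", nil) // uncomment to try it locally
}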
@@ -7,6 +7,7 @@ import (
	"time"

	"github.com/rakyll/statik/fs"
	"github.com/zitadel/oidc/v2/pkg/oidc"
	"github.com/zitadel/oidc/v2/pkg/op"
	"golang.org/x/text/language"

@@ -79,13 +80,32 @@ type OPStorage struct {
	assetAPIPrefix func(ctx context.Context) string
}

func NewProvider(config Config, defaultLogoutRedirectURI string, externalSecure bool, command *command.Commands, query *query.Queries, repo repository.Repository, encryptionAlg crypto.EncryptionAlgorithm, cryptoKey []byte, es *eventstore.Eventstore, projections *database.DB, userAgentCookie, instanceHandler, accessHandler func(http.Handler) http.Handler) (op.OpenIDProvider, error) {
func NewProvider(
	config Config,
	defaultLogoutRedirectURI string,
	externalSecure bool,
	command *command.Commands,
	query *query.Queries,
	repo repository.Repository,
	encryptionAlg crypto.EncryptionAlgorithm,
	cryptoKey []byte,
	es *eventstore.Eventstore,
	projections *database.DB,
	userAgentCookie, instanceHandler func(http.Handler) http.Handler,
	accessHandler *middleware.AccessInterceptor,
) (op.OpenIDProvider, error) {
	opConfig, err := createOPConfig(config, defaultLogoutRedirectURI, cryptoKey)
	if err != nil {
		return nil, caos_errs.ThrowInternal(err, "OIDC-EGrqd", "cannot create op config: %w")
	}
	storage := newStorage(config, command, query, repo, encryptionAlg, es, projections, externalSecure)
	options, err := createOptions(config, externalSecure, userAgentCookie, instanceHandler, accessHandler)
	options, err := createOptions(
		config,
		externalSecure,
		userAgentCookie,
		instanceHandler,
		accessHandler.HandleIgnorePathPrefixes(ignoredQuotaLimitEndpoint(config.CustomEndpoints)),
	)
	if err != nil {
		return nil, caos_errs.ThrowInternal(err, "OIDC-D3gq1", "cannot create options: %w")
	}
@@ -101,6 +121,21 @@ func NewProvider(config Config, defaultLogoutRedirectURI string, externalSecure
	return provider, nil
}

func ignoredQuotaLimitEndpoint(endpoints *EndpointConfig) []string {
	authURL := op.DefaultEndpoints.Authorization.Relative()
	keysURL := op.DefaultEndpoints.JwksURI.Relative()
	if endpoints == nil {
		return []string{oidc.DiscoveryEndpoint, authURL, keysURL}
	}
	if endpoints.Auth != nil && endpoints.Auth.Path != "" {
		authURL = endpoints.Auth.Path
	}
	if endpoints.Keys != nil && endpoints.Keys.Path != "" {
		keysURL = endpoints.Keys.Path
	}
	return []string{oidc.DiscoveryEndpoint, authURL, keysURL}
}

func createOPConfig(config Config, defaultLogoutRedirectURI string, cryptoKey []byte) (*op.Config, error) {
	supportedLanguages, err := getSupportedLanguages()
	if err != nil {
@@ -38,8 +38,8 @@ func NewProvider(
	es *eventstore.Eventstore,
	projections *database.DB,
	instanceHandler,
	userAgentCookie,
	accessHandler func(http.Handler) http.Handler,
	userAgentCookie func(http.Handler) http.Handler,
	accessHandler *middleware.AccessInterceptor,
) (*provider.Provider, error) {
	metricTypes := []metrics.MetricType{metrics.MetricTypeRequestCount, metrics.MetricTypeStatusCode, metrics.MetricTypeTotalCount}

@@ -63,7 +63,7 @@ func NewProvider(
			middleware.NoCacheInterceptor().Handler,
			instanceHandler,
			userAgentCookie,
			accessHandler,
			accessHandler.HandleIgnorePathPrefixes(ignoredQuotaLimitEndpoint(conf.ProviderConfig)),
			http_utils.CopyHeadersToContext,
		),
		provider.WithCustomTimeFormat("2006-01-02T15:04:05.999Z"),
@@ -100,3 +100,22 @@ func newStorage(
		defaultLoginURL: fmt.Sprintf("%s%s?%s=", login.HandlerPrefix, login.EndpointLogin, login.QueryAuthRequestID),
	}, nil
}

func ignoredQuotaLimitEndpoint(config *provider.Config) []string {
	metadataEndpoint := HandlerPrefix + provider.DefaultMetadataEndpoint
	certificateEndpoint := HandlerPrefix + provider.DefaultCertificateEndpoint
	ssoEndpoint := HandlerPrefix + provider.DefaultSingleSignOnEndpoint
	if config.MetadataConfig != nil && config.MetadataConfig.Path != "" {
		metadataEndpoint = HandlerPrefix + config.MetadataConfig.Path
	}
	if config.IDPConfig == nil || config.IDPConfig.Endpoints == nil {
		return []string{metadataEndpoint, certificateEndpoint, ssoEndpoint}
	}
	if config.IDPConfig.Endpoints.Certificate != nil && config.IDPConfig.Endpoints.Certificate.Relative() != "" {
		certificateEndpoint = HandlerPrefix + config.IDPConfig.Endpoints.Certificate.Relative()
	}
	if config.IDPConfig.Endpoints.SingleSignOn != nil && config.IDPConfig.Endpoints.SingleSignOn.Relative() != "" {
		ssoEndpoint = HandlerPrefix + config.IDPConfig.Endpoints.SingleSignOn.Relative()
	}
	return []string{metadataEndpoint, certificateEndpoint, ssoEndpoint}
}
@@ -283,10 +283,7 @@ func (c *Commands) SetUpInstance(ctx context.Context, setup *InstanceSetup) (str
			if err != nil {
				return "", "", nil, nil, err
			}

			quotaAggregate := quota.NewAggregate(quotaId, instanceID, instanceID)

			validations = append(validations, c.AddQuotaCommand(quotaAggregate, q))
			validations = append(validations, c.AddQuotaCommand(quota.NewAggregate(quotaId, instanceID), q))
		}
	}
@@ -26,6 +26,7 @@ import (
	"github.com/zitadel/zitadel/internal/repository/oidcsession"
	"github.com/zitadel/zitadel/internal/repository/org"
	proj_repo "github.com/zitadel/zitadel/internal/repository/project"
	quota_repo "github.com/zitadel/zitadel/internal/repository/quota"
	"github.com/zitadel/zitadel/internal/repository/session"
	usr_repo "github.com/zitadel/zitadel/internal/repository/user"
	"github.com/zitadel/zitadel/internal/repository/usergrant"
@@ -50,6 +51,7 @@ func eventstoreExpect(t *testing.T, expects ...expect) *eventstore.Eventstore {
	idpintent.RegisterEventMappers(es)
	authrequest.RegisterEventMappers(es)
	oidcsession.RegisterEventMappers(es)
	quota_repo.RegisterEventMappers(es)
	return es
}
@@ -21,8 +21,8 @@ const (
	QuotaActionsAllRunsSeconds QuotaUnit = "actions.all.runs.seconds"
)

func (q *QuotaUnit) Enum() quota.Unit {
	switch *q {
func (q QuotaUnit) Enum() quota.Unit {
	switch q {
	case QuotaRequestsAllAuthenticated:
		return quota.RequestsAllAuthenticated
	case QuotaActionsAllRunsSeconds:
@@ -46,14 +46,11 @@ func (c *Commands) AddQuota(
	if wm.active {
		return nil, errors.ThrowAlreadyExists(nil, "COMMAND-WDfFf", "Errors.Quota.AlreadyExists")
	}

	aggregateId, err := c.idGenerator.Next()
	if err != nil {
		return nil, err
	}
	aggregate := quota.NewAggregate(aggregateId, instanceId, instanceId)

	cmds, err := preparation.PrepareCommands(ctx, c.eventstore.Filter, c.AddQuotaCommand(aggregate, q))
	cmds, err := preparation.PrepareCommands(ctx, c.eventstore.Filter, c.AddQuotaCommand(quota.NewAggregate(aggregateId, instanceId), q))
	if err != nil {
		return nil, err
	}
@@ -80,7 +77,7 @@ func (c *Commands) RemoveQuota(ctx context.Context, unit QuotaUnit) (*domain.Obj
		return nil, errors.ThrowNotFound(nil, "COMMAND-WDfFf", "Errors.Quota.NotFound")
	}

	aggregate := quota.NewAggregate(wm.AggregateID, instanceId, instanceId)
	aggregate := quota.NewAggregate(wm.AggregateID, instanceId)

	events := []eventstore.Command{
		quota.NewRemovedEvent(ctx, &aggregate.Aggregate, unit.Enum()),
@@ -109,6 +106,22 @@ type QuotaNotification struct {

type QuotaNotifications []*QuotaNotification

func (q *QuotaNotification) validate() error {
	u, err := url.Parse(q.CallURL)
	if err != nil {
		return errors.ThrowInvalidArgument(err, "QUOTA-bZ0Fj", "Errors.Quota.Invalid.CallURL")
	}

	if !u.IsAbs() || u.Host == "" {
		return errors.ThrowInvalidArgument(nil, "QUOTA-HAYmN", "Errors.Quota.Invalid.CallURL")
	}

	if q.Percent < 1 {
		return errors.ThrowInvalidArgument(nil, "QUOTA-pBfjq", "Errors.Quota.Invalid.Percent")
	}
	return nil
}

func (q *QuotaNotifications) toAddedEventNotifications(idGenerator id.Generator) ([]*quota.AddedEventNotification, error) {
	if q == nil {
		return nil, nil
@@ -144,17 +157,8 @@ type AddQuota struct {

func (q *AddQuota) validate() error {
	for _, notification := range q.Notifications {
		u, err := url.Parse(notification.CallURL)
		if err != nil {
			return errors.ThrowInvalidArgument(err, "QUOTA-bZ0Fj", "Errors.Quota.Invalid.CallURL")
		}

		if !u.IsAbs() || u.Host == "" {
			return errors.ThrowInvalidArgument(nil, "QUOTA-HAYmN", "Errors.Quota.Invalid.CallURL")
		}

		if notification.Percent < 1 {
			return errors.ThrowInvalidArgument(nil, "QUOTA-pBfjq", "Errors.Quota.Invalid.Percent")
		if err := notification.validate(); err != nil {
			return err
		}
	}

@@ -169,11 +173,6 @@ func (q *AddQuota) validate() error {
	if q.ResetInterval < time.Minute {
		return errors.ThrowInvalidArgument(nil, "QUOTA-R5otd", "Errors.Quota.Invalid.ResetInterval")
	}

	if !q.Limit && len(q.Notifications) == 0 {
		return errors.ThrowInvalidArgument(nil, "QUOTA-4Nv68", "Errors.Quota.Invalid.Noop")
	}

	return nil
}
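QuotaNotification.validate's URL checks (parseable, absolute, non-empty host) form a small reusable pattern. Here it is extracted into a stand-alone sketch with plain errors instead of ZITADEL's error package:

package main

import (
    "fmt"
    "net/url"
)

// validateCallURL mirrors the checks in QuotaNotification.validate above:
// the URL must parse, be absolute, and carry a host.
func validateCallURL(raw string) error {
    u, err := url.Parse(raw)
    if err != nil {
        return fmt.Errorf("invalid call URL: %w", err)
    }
    if !u.IsAbs() || u.Host == "" {
        return fmt.Errorf("call URL must be absolute with a host: %q", raw)
    }
    return nil
}

func main() {
    fmt.Println(validateCallURL("https://httpbin.org/post")) // <nil>
    fmt.Println(validateCallURL("/relative/path"))           // error
}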
@@ -41,6 +41,7 @@ func (wm *quotaWriteModel) Reduce() error {
		switch e := event.(type) {
		case *quota.AddedEvent:
			wm.AggregateID = e.Aggregate().ID
			wm.ChangeDate = e.CreationDate()
			wm.active = true
		case *quota.RemovedEvent:
			wm.AggregateID = e.Aggregate().ID
@ -5,16 +5,47 @@ import (

	"github.com/zitadel/zitadel/internal/eventstore"
	"github.com/zitadel/zitadel/internal/repository/quota"
	"github.com/zitadel/zitadel/internal/telemetry/tracing"
)

// ReportQuotaUsage writes a slice of *quota.NotificationDueEvent directly to the eventstore
func (c *Commands) ReportQuotaUsage(ctx context.Context, dueNotifications []*quota.NotificationDueEvent) error {
	cmds := make([]eventstore.Command, len(dueNotifications))
	for idx, notification := range dueNotifications {
		cmds[idx] = notification
func (c *Commands) ReportQuotaUsage(ctx context.Context, dueNotifications []*quota.NotificationDueEvent) (err error) {
	ctx, span := tracing.NewSpan(ctx)
	defer func() { span.EndWithError(err) }()

	cmds := make([]eventstore.Command, 0, len(dueNotifications))
	for _, notification := range dueNotifications {
		ctxFilter, spanFilter := tracing.NewNamedSpan(ctx, "filterNotificationDueEvents")
		events, errFilter := c.eventstore.Filter(
			ctxFilter,
			eventstore.NewSearchQueryBuilder(eventstore.ColumnsEvent).
				InstanceID(notification.Aggregate().InstanceID).
				AddQuery().
				AggregateTypes(quota.AggregateType).
				AggregateIDs(notification.Aggregate().ID).
				EventTypes(quota.NotificationDueEventType).
				EventData(map[string]interface{}{
					"id":          notification.ID,
					"periodStart": notification.PeriodStart,
					"threshold":   notification.Threshold,
				}).Builder(),
		)
		spanFilter.EndWithError(errFilter)
		if errFilter != nil {
			return errFilter
		}
	_, err := c.eventstore.Push(ctx, cmds...)
	return err
		if len(events) > 0 {
			continue
		}
		cmds = append(cmds, notification)
	}
	if len(cmds) == 0 {
		return nil
	}
	ctxPush, spanPush := tracing.NewNamedSpan(ctx, "pushNotificationDueEvents")
	_, errPush := c.eventstore.Push(ctxPush, cmds...)
	spanPush.EndWithError(errPush)
	return errPush
}

func (c *Commands) UsageNotificationSent(ctx context.Context, dueEvent *quota.NotificationDueEvent) error {
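The rewritten ReportQuotaUsage is idempotent: for every due notification it first filters the eventstore for an already stored due event with the same id, periodStart, and threshold, and only pushes the ones that are new. A reduced sketch of that filter-then-push pattern, with a simplified store interface standing in for the eventstore:

package main

import (
	"context"
	"fmt"
)

// Notification stands in for *quota.NotificationDueEvent in this sketch.
type Notification struct{ ID string }

// store abstracts the two eventstore operations used above (assumed interface).
type store interface {
	Exists(ctx context.Context, id string) (bool, error)
	Push(ctx context.Context, batch []*Notification) error
}

// report mirrors the filter-then-push flow: skip notifications that are
// already stored so repeated calls stay idempotent, then push the rest.
func report(ctx context.Context, s store, due []*Notification) error {
	cmds := make([]*Notification, 0, len(due))
	for _, n := range due {
		seen, err := s.Exists(ctx, n.ID)
		if err != nil {
			return err
		}
		if seen {
			continue
		}
		cmds = append(cmds, n)
	}
	if len(cmds) == 0 {
		return nil
	}
	return s.Push(ctx, cmds)
}

type memStore struct{ seen map[string]bool }

func (m *memStore) Exists(_ context.Context, id string) (bool, error) { return m.seen[id], nil }

func (m *memStore) Push(_ context.Context, batch []*Notification) error {
	for _, n := range batch {
		m.seen[n.ID] = true
	}
	fmt.Println("pushed", len(batch), "notifications")
	return nil
}

func main() {
	s := &memStore{seen: map[string]bool{"id2": true}}
	due := []*Notification{{ID: "id1"}, {ID: "id2"}, {ID: "id3"}}
	_ = report(context.Background(), s, due) // pushed 2 notifications (id2 is skipped)
}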
internal/command/quota_report_test.go (new file)
@ -0,0 +1,307 @@
|
||||
package command
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/api/authz"
|
||||
"github.com/zitadel/zitadel/internal/eventstore"
|
||||
"github.com/zitadel/zitadel/internal/eventstore/repository"
|
||||
"github.com/zitadel/zitadel/internal/id"
|
||||
id_mock "github.com/zitadel/zitadel/internal/id/mock"
|
||||
"github.com/zitadel/zitadel/internal/repository/quota"
|
||||
)
|
||||
|
||||
func TestQuotaReport_ReportQuotaUsage(t *testing.T) {
|
||||
type fields struct {
|
||||
eventstore *eventstore.Eventstore
|
||||
}
|
||||
type args struct {
|
||||
ctx context.Context
|
||||
dueNotifications []*quota.NotificationDueEvent
|
||||
}
|
||||
type res struct {
|
||||
err func(error) bool
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
fields fields
|
||||
args args
|
||||
res res
|
||||
}{
|
||||
{
|
||||
name: "no due events",
|
||||
fields: fields{
|
||||
eventstore: eventstoreExpect(
|
||||
t,
|
||||
),
|
||||
},
|
||||
args: args{
|
||||
ctx: authz.WithInstanceID(context.Background(), "INSTANCE"),
|
||||
},
|
||||
res: res{},
|
||||
},
|
||||
{
|
||||
name: "due event already reported",
|
||||
fields: fields{
|
||||
eventstore: eventstoreExpect(
|
||||
t,
|
||||
expectFilter(
|
||||
eventFromEventPusherWithInstanceID(
|
||||
"INSTANCE",
|
||||
quota.NewNotificationDueEvent(context.Background(),
|
||||
"a.NewAggregate("quota1", "INSTANCE").Aggregate,
|
||||
QuotaRequestsAllAuthenticated.Enum(),
|
||||
"id",
|
||||
"url",
|
||||
time.Date(2023, 9, 1, 0, 0, 0, 0, time.UTC),
|
||||
1000,
|
||||
200,
|
||||
),
|
||||
),
|
||||
),
|
||||
),
|
||||
},
|
||||
args: args{
|
||||
ctx: authz.WithInstanceID(context.Background(), "INSTANCE"),
|
||||
dueNotifications: []*quota.NotificationDueEvent{
|
||||
{
|
||||
Unit: QuotaRequestsAllAuthenticated.Enum(),
|
||||
ID: "id",
|
||||
CallURL: "url",
|
||||
PeriodStart: time.Date(2023, 9, 1, 0, 0, 0, 0, time.UTC),
|
||||
Threshold: 1000,
|
||||
Usage: 250,
|
||||
},
|
||||
},
|
||||
},
|
||||
res: res{},
|
||||
},
|
||||
{
|
||||
name: "due event not reported",
|
||||
fields: fields{
|
||||
eventstore: eventstoreExpect(
|
||||
t,
|
||||
expectFilter(),
|
||||
expectPush(
|
||||
[]*repository.Event{
|
||||
eventFromEventPusherWithInstanceID(
|
||||
"INSTANCE",
|
||||
quota.NewNotificationDueEvent(context.Background(),
|
||||
"a.NewAggregate("quota1", "INSTANCE").Aggregate,
|
||||
QuotaRequestsAllAuthenticated.Enum(),
|
||||
"id",
|
||||
"url",
|
||||
time.Date(2023, 9, 1, 0, 0, 0, 0, time.UTC),
|
||||
1000,
|
||||
250,
|
||||
),
|
||||
),
|
||||
},
|
||||
),
|
||||
),
|
||||
},
|
||||
args: args{
|
||||
ctx: authz.WithInstanceID(context.Background(), "INSTANCE"),
|
||||
dueNotifications: []*quota.NotificationDueEvent{
|
||||
quota.NewNotificationDueEvent(
|
||||
context.Background(),
|
||||
"a.NewAggregate("quota1", "INSTANCE").Aggregate,
|
||||
QuotaRequestsAllAuthenticated.Enum(),
|
||||
"id",
|
||||
"url",
|
||||
time.Date(2023, 9, 1, 0, 0, 0, 0, time.UTC),
|
||||
1000,
|
||||
250,
|
||||
),
|
||||
},
|
||||
},
|
||||
res: res{},
|
||||
},
|
||||
{
|
||||
name: "due events",
|
||||
fields: fields{
|
||||
eventstore: eventstoreExpect(
|
||||
t,
|
||||
expectFilter(),
|
||||
expectFilter(
|
||||
eventFromEventPusherWithInstanceID(
|
||||
"INSTANCE",
|
||||
quota.NewNotificationDueEvent(context.Background(),
|
||||
"a.NewAggregate("quota1", "INSTANCE").Aggregate,
|
||||
QuotaRequestsAllAuthenticated.Enum(),
|
||||
"id2",
|
||||
"url",
|
||||
time.Date(2023, 9, 1, 0, 0, 0, 0, time.UTC),
|
||||
1000,
|
||||
250,
|
||||
),
|
||||
),
|
||||
),
|
||||
expectFilter(),
|
||||
expectPush(
|
||||
[]*repository.Event{
|
||||
eventFromEventPusherWithInstanceID(
|
||||
"INSTANCE",
|
||||
quota.NewNotificationDueEvent(context.Background(),
|
||||
"a.NewAggregate("quota1", "INSTANCE").Aggregate,
|
||||
QuotaRequestsAllAuthenticated.Enum(),
|
||||
"id1",
|
||||
"url",
|
||||
time.Date(2023, 9, 1, 0, 0, 0, 0, time.UTC),
|
||||
1000,
|
||||
250,
|
||||
),
|
||||
),
|
||||
eventFromEventPusherWithInstanceID(
|
||||
"INSTANCE",
|
||||
quota.NewNotificationDueEvent(context.Background(),
|
||||
"a.NewAggregate("quota1", "INSTANCE").Aggregate,
|
||||
QuotaRequestsAllAuthenticated.Enum(),
|
||||
"id3",
|
||||
"url",
|
||||
time.Date(2023, 9, 1, 0, 0, 0, 0, time.UTC),
|
||||
1000,
|
||||
250,
|
||||
),
|
||||
),
|
||||
},
|
||||
),
|
||||
),
|
||||
},
|
||||
args: args{
|
||||
ctx: authz.WithInstanceID(context.Background(), "INSTANCE"),
|
||||
dueNotifications: []*quota.NotificationDueEvent{
|
||||
quota.NewNotificationDueEvent(
|
||||
context.Background(),
|
||||
"a.NewAggregate("quota1", "INSTANCE").Aggregate,
|
||||
QuotaRequestsAllAuthenticated.Enum(),
|
||||
"id1",
|
||||
"url",
|
||||
time.Date(2023, 9, 1, 0, 0, 0, 0, time.UTC),
|
||||
1000,
|
||||
250,
|
||||
),
|
||||
quota.NewNotificationDueEvent(
|
||||
context.Background(),
|
||||
"a.NewAggregate("quota1", "INSTANCE").Aggregate,
|
||||
QuotaRequestsAllAuthenticated.Enum(),
|
||||
"id2",
|
||||
"url",
|
||||
time.Date(2023, 9, 1, 0, 0, 0, 0, time.UTC),
|
||||
1000,
|
||||
250,
|
||||
),
|
||||
quota.NewNotificationDueEvent(
|
||||
context.Background(),
|
||||
"a.NewAggregate("quota1", "INSTANCE").Aggregate,
|
||||
QuotaRequestsAllAuthenticated.Enum(),
|
||||
"id3",
|
||||
"url",
|
||||
time.Date(2023, 9, 1, 0, 0, 0, 0, time.UTC),
|
||||
1000,
|
||||
250,
|
||||
),
|
||||
},
|
||||
},
|
||||
res: res{},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
r := &Commands{
|
||||
eventstore: tt.fields.eventstore,
|
||||
}
|
||||
err := r.ReportQuotaUsage(tt.args.ctx, tt.args.dueNotifications)
|
||||
if tt.res.err == nil {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
if tt.res.err != nil && !tt.res.err(err) {
|
||||
t.Errorf("got wrong err: %v ", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestQuotaReport_UsageNotificationSent(t *testing.T) {
|
||||
type fields struct {
|
||||
eventstore *eventstore.Eventstore
|
||||
idGenerator id.Generator
|
||||
}
|
||||
type args struct {
|
||||
ctx context.Context
|
||||
dueNotification *quota.NotificationDueEvent
|
||||
}
|
||||
type res struct {
|
||||
err func(error) bool
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
fields fields
|
||||
args args
|
||||
res res
|
||||
}{
|
||||
{
|
||||
name: "usage notification sent, ok",
|
||||
fields: fields{
|
||||
eventstore: eventstoreExpect(
|
||||
t,
|
||||
expectPush(
|
||||
[]*repository.Event{
|
||||
eventFromEventPusherWithInstanceID(
|
||||
"INSTANCE",
|
||||
quota.NewNotifiedEvent(
|
||||
context.Background(),
|
||||
"quota1",
|
||||
quota.NewNotificationDueEvent(
|
||||
context.Background(),
|
||||
"a.NewAggregate("quota1", "INSTANCE").Aggregate,
|
||||
QuotaRequestsAllAuthenticated.Enum(),
|
||||
"id1",
|
||||
"url",
|
||||
time.Date(2023, 9, 1, 0, 0, 0, 0, time.UTC),
|
||||
1000,
|
||||
250,
|
||||
),
|
||||
),
|
||||
),
|
||||
},
|
||||
),
|
||||
),
|
||||
idGenerator: id_mock.NewIDGeneratorExpectIDs(t, "quota1"),
|
||||
},
|
||||
args: args{
|
||||
ctx: authz.WithInstanceID(context.Background(), "INSTANCE"),
|
||||
dueNotification: quota.NewNotificationDueEvent(
|
||||
context.Background(),
|
||||
"a.NewAggregate("quota1", "INSTANCE").Aggregate,
|
||||
QuotaRequestsAllAuthenticated.Enum(),
|
||||
"id1",
|
||||
"url",
|
||||
time.Date(2023, 9, 1, 0, 0, 0, 0, time.UTC),
|
||||
1000,
|
||||
250,
|
||||
),
|
||||
},
|
||||
res: res{},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
r := &Commands{
|
||||
eventstore: tt.fields.eventstore,
|
||||
idGenerator: tt.fields.idGenerator,
|
||||
}
|
||||
err := r.UsageNotificationSent(tt.args.ctx, tt.args.dueNotification)
|
||||
if tt.res.err == nil {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
if tt.res.err != nil && !tt.res.err(err) {
|
||||
t.Errorf("got wrong err: %v ", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
internal/command/quota_test.go (new file)
@ -0,0 +1,638 @@
|
||||
package command
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/api/authz"
|
||||
"github.com/zitadel/zitadel/internal/domain"
|
||||
caos_errors "github.com/zitadel/zitadel/internal/errors"
|
||||
"github.com/zitadel/zitadel/internal/eventstore"
|
||||
"github.com/zitadel/zitadel/internal/eventstore/repository"
|
||||
"github.com/zitadel/zitadel/internal/id"
|
||||
id_mock "github.com/zitadel/zitadel/internal/id/mock"
|
||||
"github.com/zitadel/zitadel/internal/repository/quota"
|
||||
)
|
||||
|
||||
func TestQuota_AddQuota(t *testing.T) {
|
||||
type fields struct {
|
||||
eventstore *eventstore.Eventstore
|
||||
idGenerator id.Generator
|
||||
}
|
||||
type args struct {
|
||||
ctx context.Context
|
||||
addQuota *AddQuota
|
||||
}
|
||||
type res struct {
|
||||
want *domain.ObjectDetails
|
||||
err func(error) bool
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
fields fields
|
||||
args args
|
||||
res res
|
||||
}{
|
||||
{
|
||||
name: "already existing",
|
||||
fields: fields{
|
||||
eventstore: eventstoreExpect(
|
||||
t,
|
||||
expectFilter(
|
||||
eventFromEventPusher(
|
||||
quota.NewAddedEvent(context.Background(),
|
||||
"a.NewAggregate("quota1", "INSTANCE").Aggregate,
|
||||
QuotaRequestsAllAuthenticated.Enum(),
|
||||
time.Now(),
|
||||
30*24*time.Hour,
|
||||
1000,
|
||||
false,
|
||||
nil,
|
||||
),
|
||||
),
|
||||
),
|
||||
),
|
||||
},
|
||||
args: args{
|
||||
ctx: authz.WithInstanceID(context.Background(), "INSTANCE"),
|
||||
addQuota: &AddQuota{
|
||||
Unit: QuotaRequestsAllAuthenticated,
|
||||
From: time.Time{},
|
||||
ResetInterval: 0,
|
||||
Amount: 0,
|
||||
Limit: false,
|
||||
Notifications: nil,
|
||||
},
|
||||
},
|
||||
res: res{
|
||||
err: caos_errors.IsErrorAlreadyExists,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "create quota, validation fail",
|
||||
fields: fields{
|
||||
eventstore: eventstoreExpect(
|
||||
t,
|
||||
expectFilter(),
|
||||
),
|
||||
idGenerator: id_mock.NewIDGeneratorExpectIDs(t, "quota1"),
|
||||
},
|
||||
args: args{
|
||||
ctx: authz.WithInstanceID(context.Background(), "INSTANCE"),
|
||||
addQuota: &AddQuota{
|
||||
Unit: "unimplemented",
|
||||
From: time.Time{},
|
||||
ResetInterval: 0,
|
||||
Amount: 0,
|
||||
Limit: false,
|
||||
Notifications: nil,
|
||||
},
|
||||
},
|
||||
res: res{
|
||||
err: func(err error) bool {
|
||||
return errors.Is(err, caos_errors.ThrowInvalidArgument(nil, "QUOTA-OTeSh", ""))
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "create quota, ok",
|
||||
fields: fields{
|
||||
eventstore: eventstoreExpect(
|
||||
t,
|
||||
expectFilter(),
|
||||
expectPush(
|
||||
[]*repository.Event{
|
||||
eventFromEventPusherWithInstanceID(
|
||||
"INSTANCE",
|
||||
quota.NewAddedEvent(context.Background(),
|
||||
"a.NewAggregate("quota1", "INSTANCE").Aggregate,
|
||||
QuotaRequestsAllAuthenticated.Enum(),
|
||||
time.Date(2023, 9, 1, 0, 0, 0, 0, time.UTC),
|
||||
30*24*time.Hour,
|
||||
1000,
|
||||
true,
|
||||
nil,
|
||||
),
|
||||
),
|
||||
},
|
||||
uniqueConstraintsFromEventConstraintWithInstanceID("INSTANCE", quota.NewAddQuotaUnitUniqueConstraint(quota.RequestsAllAuthenticated)),
|
||||
),
|
||||
),
|
||||
idGenerator: id_mock.NewIDGeneratorExpectIDs(t, "quota1"),
|
||||
},
|
||||
args: args{
|
||||
ctx: authz.WithInstanceID(context.Background(), "INSTANCE"),
|
||||
addQuota: &AddQuota{
|
||||
Unit: QuotaRequestsAllAuthenticated,
|
||||
From: time.Date(2023, 9, 1, 0, 0, 0, 0, time.UTC),
|
||||
ResetInterval: 30 * 24 * time.Hour,
|
||||
Amount: 1000,
|
||||
Limit: true,
|
||||
Notifications: nil,
|
||||
},
|
||||
},
|
||||
res: res{
|
||||
want: &domain.ObjectDetails{
|
||||
ResourceOwner: "INSTANCE",
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "removed, ok",
|
||||
fields: fields{
|
||||
eventstore: eventstoreExpect(
|
||||
t,
|
||||
expectFilter(
|
||||
eventFromEventPusherWithInstanceID(
|
||||
"INSTANCE",
|
||||
quota.NewAddedEvent(context.Background(),
|
||||
"a.NewAggregate("quota1", "INSTANCE").Aggregate,
|
||||
QuotaRequestsAllAuthenticated.Enum(),
|
||||
time.Now(),
|
||||
30*24*time.Hour,
|
||||
1000,
|
||||
true,
|
||||
nil,
|
||||
),
|
||||
),
|
||||
eventFromEventPusherWithInstanceID(
|
||||
"INSTANCE",
|
||||
quota.NewRemovedEvent(context.Background(),
|
||||
"a.NewAggregate("quota1", "INSTANCE").Aggregate,
|
||||
QuotaRequestsAllAuthenticated.Enum(),
|
||||
),
|
||||
),
|
||||
),
|
||||
expectPush(
|
||||
[]*repository.Event{
|
||||
eventFromEventPusherWithInstanceID(
|
||||
"INSTANCE",
|
||||
quota.NewAddedEvent(context.Background(),
|
||||
"a.NewAggregate("quota1", "INSTANCE").Aggregate,
|
||||
QuotaRequestsAllAuthenticated.Enum(),
|
||||
time.Date(2023, 9, 1, 0, 0, 0, 0, time.UTC),
|
||||
30*24*time.Hour,
|
||||
1000,
|
||||
true,
|
||||
nil,
|
||||
),
|
||||
),
|
||||
},
|
||||
uniqueConstraintsFromEventConstraintWithInstanceID("INSTANCE", quota.NewAddQuotaUnitUniqueConstraint(quota.RequestsAllAuthenticated)),
|
||||
),
|
||||
),
|
||||
idGenerator: id_mock.NewIDGeneratorExpectIDs(t, "quota1"),
|
||||
},
|
||||
args: args{
|
||||
ctx: authz.WithInstanceID(context.Background(), "INSTANCE"),
|
||||
addQuota: &AddQuota{
|
||||
Unit: QuotaRequestsAllAuthenticated,
|
||||
From: time.Date(2023, 9, 1, 0, 0, 0, 0, time.UTC),
|
||||
ResetInterval: 30 * 24 * time.Hour,
|
||||
Amount: 1000,
|
||||
Limit: true,
|
||||
Notifications: nil,
|
||||
},
|
||||
},
|
||||
res: res{
|
||||
want: &domain.ObjectDetails{
|
||||
ResourceOwner: "INSTANCE",
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "create quota with notifications, ok",
|
||||
fields: fields{
|
||||
eventstore: eventstoreExpect(
|
||||
t,
|
||||
expectFilter(),
|
||||
expectPush(
|
||||
[]*repository.Event{
|
||||
eventFromEventPusherWithInstanceID(
|
||||
"INSTANCE",
|
||||
quota.NewAddedEvent(context.Background(),
|
||||
"a.NewAggregate("quota1", "INSTANCE").Aggregate,
|
||||
QuotaRequestsAllAuthenticated.Enum(),
|
||||
time.Date(2023, 9, 1, 0, 0, 0, 0, time.UTC),
|
||||
30*24*time.Hour,
|
||||
1000,
|
||||
true,
|
||||
[]*quota.AddedEventNotification{
|
||||
{
|
||||
ID: "notification1",
|
||||
Percent: 20,
|
||||
Repeat: false,
|
||||
CallURL: "https://url.com",
|
||||
},
|
||||
},
|
||||
),
|
||||
),
|
||||
},
|
||||
uniqueConstraintsFromEventConstraintWithInstanceID("INSTANCE", quota.NewAddQuotaUnitUniqueConstraint(quota.RequestsAllAuthenticated)),
|
||||
),
|
||||
),
|
||||
idGenerator: id_mock.NewIDGeneratorExpectIDs(t, "quota1", "notification1"),
|
||||
},
|
||||
args: args{
|
||||
ctx: authz.WithInstanceID(context.Background(), "INSTANCE"),
|
||||
addQuota: &AddQuota{
|
||||
Unit: QuotaRequestsAllAuthenticated,
|
||||
From: time.Date(2023, 9, 1, 0, 0, 0, 0, time.UTC),
|
||||
ResetInterval: 30 * 24 * time.Hour,
|
||||
Amount: 1000,
|
||||
Limit: true,
|
||||
Notifications: QuotaNotifications{
|
||||
{
|
||||
Percent: 20,
|
||||
Repeat: false,
|
||||
CallURL: "https://url.com",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
res: res{
|
||||
want: &domain.ObjectDetails{
|
||||
ResourceOwner: "INSTANCE",
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
r := &Commands{
|
||||
eventstore: tt.fields.eventstore,
|
||||
idGenerator: tt.fields.idGenerator,
|
||||
}
|
||||
got, err := r.AddQuota(tt.args.ctx, tt.args.addQuota)
|
||||
if tt.res.err == nil {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
if tt.res.err != nil && !tt.res.err(err) {
|
||||
t.Errorf("got wrong err: %v ", err)
|
||||
}
|
||||
if tt.res.err == nil {
|
||||
assert.Equal(t, tt.res.want, got)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestQuota_RemoveQuota(t *testing.T) {
|
||||
type fields struct {
|
||||
eventstore *eventstore.Eventstore
|
||||
}
|
||||
type args struct {
|
||||
ctx context.Context
|
||||
unit QuotaUnit
|
||||
}
|
||||
type res struct {
|
||||
want *domain.ObjectDetails
|
||||
err func(error) bool
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
fields fields
|
||||
args args
|
||||
res res
|
||||
}{
|
||||
{
|
||||
name: "not found",
|
||||
fields: fields{
|
||||
eventstore: eventstoreExpect(
|
||||
t,
|
||||
expectFilter(),
|
||||
),
|
||||
},
|
||||
args: args{
|
||||
ctx: authz.WithInstanceID(context.Background(), "INSTANCE"),
|
||||
unit: QuotaRequestsAllAuthenticated,
|
||||
},
|
||||
res: res{
|
||||
err: func(err error) bool {
|
||||
return errors.Is(err, caos_errors.ThrowNotFound(nil, "COMMAND-WDfFf", ""))
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "already removed",
|
||||
fields: fields{
|
||||
eventstore: eventstoreExpect(
|
||||
t,
|
||||
expectFilter(
|
||||
eventFromEventPusherWithInstanceID(
|
||||
"INSTANCE",
|
||||
quota.NewAddedEvent(context.Background(),
|
||||
"a.NewAggregate("quota1", "INSTANCE").Aggregate,
|
||||
QuotaRequestsAllAuthenticated.Enum(),
|
||||
time.Now(),
|
||||
30*24*time.Hour,
|
||||
1000,
|
||||
true,
|
||||
nil,
|
||||
),
|
||||
),
|
||||
eventFromEventPusherWithInstanceID(
|
||||
"INSTANCE",
|
||||
quota.NewRemovedEvent(context.Background(),
|
||||
"a.NewAggregate("quota1", "INSTANCE").Aggregate,
|
||||
QuotaRequestsAllAuthenticated.Enum(),
|
||||
),
|
||||
),
|
||||
),
|
||||
),
|
||||
},
|
||||
args: args{
|
||||
ctx: authz.WithInstanceID(context.Background(), "INSTANCE"),
|
||||
unit: QuotaRequestsAllAuthenticated,
|
||||
},
|
||||
res: res{
|
||||
err: func(err error) bool {
|
||||
return errors.Is(err, caos_errors.ThrowNotFound(nil, "COMMAND-WDfFf", ""))
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "remove quota, ok",
|
||||
fields: fields{
|
||||
eventstore: eventstoreExpect(
|
||||
t,
|
||||
expectFilter(
|
||||
eventFromEventPusherWithInstanceID(
|
||||
"INSTANCE",
|
||||
quota.NewAddedEvent(context.Background(),
|
||||
"a.NewAggregate("quota1", "INSTANCE").Aggregate,
|
||||
QuotaRequestsAllAuthenticated.Enum(),
|
||||
time.Now(),
|
||||
30*24*time.Hour,
|
||||
1000,
|
||||
false,
|
||||
nil,
|
||||
),
|
||||
),
|
||||
),
|
||||
expectPush(
|
||||
[]*repository.Event{
|
||||
eventFromEventPusherWithInstanceID(
|
||||
"INSTANCE",
|
||||
quota.NewRemovedEvent(context.Background(),
|
||||
"a.NewAggregate("quota1", "INSTANCE").Aggregate,
|
||||
QuotaRequestsAllAuthenticated.Enum(),
|
||||
),
|
||||
),
|
||||
},
|
||||
uniqueConstraintsFromEventConstraintWithInstanceID("INSTANCE", quota.NewRemoveQuotaNameUniqueConstraint(quota.RequestsAllAuthenticated)),
|
||||
),
|
||||
),
|
||||
},
|
||||
args: args{
|
||||
ctx: authz.WithInstanceID(context.Background(), "INSTANCE"),
|
||||
unit: QuotaRequestsAllAuthenticated,
|
||||
},
|
||||
res: res{
|
||||
want: &domain.ObjectDetails{
|
||||
ResourceOwner: "INSTANCE",
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
r := &Commands{
|
||||
eventstore: tt.fields.eventstore,
|
||||
}
|
||||
got, err := r.RemoveQuota(tt.args.ctx, tt.args.unit)
|
||||
if tt.res.err == nil {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
if tt.res.err != nil && !tt.res.err(err) {
|
||||
t.Errorf("got wrong err: %v ", err)
|
||||
}
|
||||
if tt.res.err == nil {
|
||||
assert.Equal(t, tt.res.want, got)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestQuota_QuotaNotification_validate(t *testing.T) {
|
||||
type args struct {
|
||||
quotaNotification *QuotaNotification
|
||||
}
|
||||
type res struct {
|
||||
err func(error) bool
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
res res
|
||||
}{
|
||||
{
|
||||
name: "notification url parse failed",
|
||||
args: args{
|
||||
quotaNotification: &QuotaNotification{
|
||||
Percent: 20,
|
||||
Repeat: false,
|
||||
CallURL: "%",
|
||||
},
|
||||
},
|
||||
res: res{
|
||||
err: func(err error) bool {
|
||||
return errors.Is(err, caos_errors.ThrowInvalidArgument(nil, "QUOTA-bZ0Fj", ""))
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "notification url parse empty schema",
|
||||
args: args{
|
||||
quotaNotification: &QuotaNotification{
|
||||
Percent: 20,
|
||||
Repeat: false,
|
||||
CallURL: "localhost:8080",
|
||||
},
|
||||
},
|
||||
res: res{
|
||||
err: func(err error) bool {
|
||||
return errors.Is(err, caos_errors.ThrowInvalidArgument(nil, "QUOTA-HAYmN", ""))
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "notification url parse empty host",
|
||||
args: args{
|
||||
quotaNotification: &QuotaNotification{
|
||||
Percent: 20,
|
||||
Repeat: false,
|
||||
CallURL: "https://",
|
||||
},
|
||||
},
|
||||
res: res{
|
||||
err: func(err error) bool {
|
||||
return errors.Is(err, caos_errors.ThrowInvalidArgument(nil, "QUOTA-HAYmN", ""))
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "notification url parse percent 0",
|
||||
args: args{
|
||||
quotaNotification: &QuotaNotification{
|
||||
Percent: 0,
|
||||
Repeat: false,
|
||||
CallURL: "https://localhost:8080",
|
||||
},
|
||||
},
|
||||
res: res{
|
||||
err: func(err error) bool {
|
||||
return errors.Is(err, caos_errors.ThrowInvalidArgument(nil, "QUOTA-pBfjq", ""))
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "notification, ok",
|
||||
args: args{
|
||||
quotaNotification: &QuotaNotification{
|
||||
Percent: 20,
|
||||
Repeat: false,
|
||||
CallURL: "https://localhost:8080",
|
||||
},
|
||||
},
|
||||
res: res{
|
||||
err: nil,
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := tt.args.quotaNotification.validate()
|
||||
if tt.res.err == nil {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
if tt.res.err != nil && !tt.res.err(err) {
|
||||
t.Errorf("got wrong err: %v ", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestQuota_AddQuota_validate(t *testing.T) {
|
||||
type args struct {
|
||||
addQuota *AddQuota
|
||||
}
|
||||
type res struct {
|
||||
err func(error) bool
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
res res
|
||||
}{
|
||||
{
|
||||
name: "notification url parse failed",
|
||||
args: args{
|
||||
addQuota: &AddQuota{
|
||||
Unit: QuotaRequestsAllAuthenticated,
|
||||
From: time.Now(),
|
||||
ResetInterval: time.Minute * 10,
|
||||
Amount: 100,
|
||||
Limit: true,
|
||||
Notifications: QuotaNotifications{
|
||||
{
|
||||
Percent: 20,
|
||||
Repeat: false,
|
||||
CallURL: "%",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
res: res{
|
||||
err: func(err error) bool {
|
||||
return errors.Is(err, caos_errors.ThrowInvalidArgument(nil, "QUOTA-bZ0Fj", ""))
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "unit unimplemented",
|
||||
args: args{
|
||||
addQuota: &AddQuota{
|
||||
Unit: "unimplemented",
|
||||
From: time.Now(),
|
||||
ResetInterval: time.Minute * 10,
|
||||
Amount: 100,
|
||||
Limit: true,
|
||||
Notifications: nil,
|
||||
},
|
||||
},
|
||||
res: res{
|
||||
err: func(err error) bool {
|
||||
return errors.Is(err, caos_errors.ThrowInvalidArgument(nil, "QUOTA-OTeSh", ""))
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "amount 0",
|
||||
args: args{
|
||||
addQuota: &AddQuota{
|
||||
Unit: QuotaRequestsAllAuthenticated,
|
||||
From: time.Now(),
|
||||
ResetInterval: time.Minute * 10,
|
||||
Amount: 0,
|
||||
Limit: true,
|
||||
Notifications: nil,
|
||||
},
|
||||
},
|
||||
res: res{
|
||||
err: func(err error) bool {
|
||||
return errors.Is(err, caos_errors.ThrowInvalidArgument(nil, "QUOTA-hOKSJ", ""))
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "reset interval under 1 min",
|
||||
args: args{
|
||||
addQuota: &AddQuota{
|
||||
Unit: QuotaRequestsAllAuthenticated,
|
||||
From: time.Now(),
|
||||
ResetInterval: time.Second * 10,
|
||||
Amount: 100,
|
||||
Limit: true,
|
||||
Notifications: nil,
|
||||
},
|
||||
},
|
||||
res: res{
|
||||
err: func(err error) bool {
|
||||
return errors.Is(err, caos_errors.ThrowInvalidArgument(nil, "QUOTA-R5otd", ""))
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "validate, ok",
|
||||
args: args{
|
||||
addQuota: &AddQuota{
|
||||
Unit: QuotaRequestsAllAuthenticated,
|
||||
From: time.Now(),
|
||||
ResetInterval: time.Minute * 10,
|
||||
Amount: 100,
|
||||
Limit: false,
|
||||
Notifications: nil,
|
||||
},
|
||||
},
|
||||
res: res{
|
||||
err: nil,
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := tt.args.addQuota.validate()
|
||||
if tt.res.err == nil {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
if tt.res.err != nil && !tt.res.err(err) {
|
||||
t.Errorf("got wrong err: %v ", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
@ -3,6 +3,7 @@ package database

import (
	"database/sql/driver"
	"encoding/json"
	"time"

	"github.com/jackc/pgtype"
)

@ -93,3 +94,15 @@ func (m Map[V]) Value() (driver.Value, error) {
	}
	return json.Marshal(m)
}

type Duration time.Duration

// Scan implements the [database/sql.Scanner] interface.
func (d *Duration) Scan(src any) error {
	interval := new(pgtype.Interval)
	if err := interval.Scan(src); err != nil {
		return err
	}
	*d = Duration(time.Duration(interval.Microseconds*1000) + time.Duration(interval.Days)*24*time.Hour + time.Duration(interval.Months)*30*24*time.Hour)
	return nil
}
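Scan converts a Postgres interval into a Go duration, approximating a month as 30 days and a day as 24 hours. A small sketch of the same arithmetic on a pgtype.Interval built by hand:

package main

import (
	"fmt"
	"time"

	"github.com/jackc/pgtype"
)

func main() {
	// One month, two days and 3.5 seconds, as pgtype represents an interval.
	iv := pgtype.Interval{Months: 1, Days: 2, Microseconds: 3_500_000, Status: pgtype.Present}

	// The same conversion Duration.Scan applies: microseconds, plus days
	// at 24h, plus months approximated at 30 days.
	d := time.Duration(iv.Microseconds*1000) +
		time.Duration(iv.Days)*24*time.Hour +
		time.Duration(iv.Months)*30*24*time.Hour

	fmt.Println(d) // 768h0m3.5s
}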
@ -24,8 +24,8 @@ type DetailsMsg interface {
//
// The resource owner is compared with expected and is
// therefore the only value that has to be set.
func AssertDetails[D DetailsMsg](t testing.TB, exptected, actual D) {
	wantDetails, gotDetails := exptected.GetDetails(), actual.GetDetails()
func AssertDetails[D DetailsMsg](t testing.TB, expected, actual D) {
	wantDetails, gotDetails := expected.GetDetails(), actual.GetDetails()
	if wantDetails == nil {
		assert.Nil(t, gotDetails)
		return
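Because AssertDetails compares only the resource owner (and nil-ness) of the details, tests can assert on responses without pinning sequences or change dates. A self-contained sketch of the comparison with local stand-in types (the generated message types are not needed for the idea):

package integration_test

import (
	"testing"

	"github.com/stretchr/testify/assert"
)

// details and resp stand in for the generated object details messages.
type details struct{ ResourceOwner string }

type resp struct{ d *details }

func (r *resp) GetDetails() *details { return r.d }

// assertDetails mirrors the generic helper for this sketch: only the
// resource owner of the details is compared.
func assertDetails(t testing.TB, want, got *resp) {
	if want.GetDetails() == nil {
		assert.Nil(t, got.GetDetails())
		return
	}
	assert.Equal(t, want.GetDetails().ResourceOwner, got.GetDetails().ResourceOwner)
}

func TestAssertDetails(t *testing.T) {
	want := &resp{d: &details{ResourceOwner: "INSTANCE"}}
	got := &resp{d: &details{ResourceOwner: "INSTANCE"}}
	assertDetails(t, want, got)
}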
@ -22,22 +22,17 @@ FirstInstance:
    PasswordChangeRequired: false

LogStore:
  Access:
    Database:
      Enabled: true
      Debounce:
        MinFrequency: 0s
        MaxBulkSize: 0
  Execution:
    Database:
      Enabled: true
    Stdout:
      Enabled: true

Quotas:
  Access:
    Enabled: true
    ExhaustedCookieKey: "zitadel.quota.limiting"
    ExhaustedCookieMaxAge: "60s"
  Execution:
    Enabled: true

Projections:
  Customizations:
@ -8,7 +8,11 @@ import (
|
||||
_ "embed"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"reflect"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
@ -30,6 +34,7 @@ import (
|
||||
"github.com/zitadel/zitadel/internal/domain"
|
||||
caos_errs "github.com/zitadel/zitadel/internal/errors"
|
||||
"github.com/zitadel/zitadel/internal/eventstore/v1/models"
|
||||
"github.com/zitadel/zitadel/internal/net"
|
||||
"github.com/zitadel/zitadel/internal/query"
|
||||
"github.com/zitadel/zitadel/internal/webauthn"
|
||||
"github.com/zitadel/zitadel/pkg/grpc/admin"
|
||||
@ -71,6 +76,11 @@ const (
|
||||
UserPassword = "VeryS3cret!"
|
||||
)
|
||||
|
||||
const (
|
||||
PortMilestoneServer = "8081"
|
||||
PortQuotaServer = "8082"
|
||||
)
|
||||
|
||||
// User information with a Personal Access Token.
|
||||
type User struct {
|
||||
*query.User
|
||||
@ -101,6 +111,11 @@ type Tester struct {
|
||||
Organisation *query.Org
|
||||
Users InstanceUserMap
|
||||
|
||||
MilestoneChan chan []byte
|
||||
milestoneServer *httptest.Server
|
||||
QuotaNotificationChan chan []byte
|
||||
quotaNotificationServer *httptest.Server
|
||||
|
||||
Client Client
|
||||
WebAuthN *webauthn.Client
|
||||
wg sync.WaitGroup // used for shutdown
|
||||
@ -271,6 +286,8 @@ func (s *Tester) Done() {
|
||||
|
||||
s.Shutdown <- os.Interrupt
|
||||
s.wg.Wait()
|
||||
s.milestoneServer.Close()
|
||||
s.quotaNotificationServer.Close()
|
||||
}
|
||||
|
||||
// NewTester starts a new Zitadel server by passing the default commandline.
|
||||
@ -279,13 +296,13 @@ func (s *Tester) Done() {
|
||||
// INTEGRATION_DB_FLAVOR environment variable and can have the values "cockroach"
|
||||
// or "postgres". Defaults to "cockroach".
|
||||
//
|
||||
// The deault Instance and Organisation are read from the DB and system
|
||||
// The default Instance and Organisation are read from the DB and system
|
||||
// users are created as needed.
|
||||
//
|
||||
// After the server is started, a [grpc.ClientConn] will be created and
|
||||
// the server is polled for its health status.
|
||||
//
|
||||
// Note: the database must already be setup and intialized before
|
||||
// Note: the database must already be setup and initialized before
|
||||
// using NewTester. See the CONTRIBUTING.md document for details.
|
||||
func NewTester(ctx context.Context) *Tester {
|
||||
args := strings.Split(commandLine, " ")
|
||||
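A typical integration test entry point built on these helpers might look as follows; the five-minute timeout and the package name are arbitrary choices for this sketch:

package integration_test

import (
	"testing"
	"time"

	"github.com/zitadel/zitadel/internal/integration"
)

var Tester *integration.Tester

func TestMain(m *testing.M) {
	// Contexts yields a deadline-bound ctx, an errCtx meant for negative
	// tests, and the cancel func (see Contexts below).
	ctx, _, cancel := integration.Contexts(5 * time.Minute)
	defer cancel()

	// NewTester boots the server, reads the default instance and
	// organisation, and creates the needed system users.
	Tester = integration.NewTester(ctx)
	defer Tester.Done()

	m.Run()
}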
@ -311,6 +328,13 @@ func NewTester(ctx context.Context) *Tester {
|
||||
tester := Tester{
|
||||
Users: make(InstanceUserMap),
|
||||
}
|
||||
tester.MilestoneChan = make(chan []byte, 100)
|
||||
tester.milestoneServer, err = runMilestoneServer(ctx, tester.MilestoneChan)
|
||||
logging.OnError(err).Fatal()
|
||||
tester.QuotaNotificationChan = make(chan []byte, 100)
|
||||
tester.quotaNotificationServer, err = runQuotaServer(ctx, tester.QuotaNotificationChan)
|
||||
logging.OnError(err).Fatal()
|
||||
|
||||
tester.wg.Add(1)
|
||||
go func(wg *sync.WaitGroup) {
|
||||
logging.OnError(cmd.Execute()).Fatal()
|
||||
@ -328,7 +352,6 @@ func NewTester(ctx context.Context) *Tester {
|
||||
tester.createMachineUserOrgOwner(ctx)
|
||||
tester.createMachineUserInstanceOwner(ctx)
|
||||
tester.WebAuthN = webauthn.NewClient(tester.Config.WebAuthNName, tester.Config.ExternalDomain, "https://"+tester.Host())
|
||||
|
||||
return &tester
|
||||
}
|
||||
|
||||
@ -338,3 +361,51 @@ func Contexts(timeout time.Duration) (ctx, errCtx context.Context, cancel contex
|
||||
ctx, cancel = context.WithTimeout(context.Background(), timeout)
|
||||
return ctx, errCtx, cancel
|
||||
}
|
||||
|
||||
func runMilestoneServer(ctx context.Context, bodies chan []byte) (*httptest.Server, error) {
|
||||
mockServer := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
body, err := io.ReadAll(r.Body)
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
if r.Header.Get("single-value") != "single-value" {
|
||||
http.Error(w, "single-value header not set", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
if !reflect.DeepEqual(r.Header.Get("multi-value"), "multi-value-1,multi-value-2") {
|
||||
http.Error(w, "single-value header not set", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
bodies <- body
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
config := net.ListenConfig()
|
||||
listener, err := config.Listen(ctx, "tcp", ":"+PortMilestoneServer)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
mockServer.Listener = listener
|
||||
mockServer.Start()
|
||||
return mockServer, nil
|
||||
}
|
||||
|
||||
func runQuotaServer(ctx context.Context, bodies chan []byte) (*httptest.Server, error) {
|
||||
mockServer := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
body, err := io.ReadAll(r.Body)
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
bodies <- body
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
config := net.ListenConfig()
|
||||
listener, err := config.Listen(ctx, "tcp", ":"+PortQuotaServer)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
mockServer.Listener = listener
|
||||
mockServer.Start()
|
||||
return mockServer, nil
|
||||
}
|
||||
|
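With the mock server listening on PortQuotaServer, a test can block on QuotaNotificationChan to observe the webhook payloads delivered to the mock endpoint; a sketch, assuming the Tester variable from the previous example:

package integration_test

import (
	"testing"
	"time"
)

func TestQuotaNotificationDelivered(t *testing.T) {
	// ...perform enough authenticated requests to cross a notification
	// threshold of a previously added quota...

	select {
	case body := <-Tester.QuotaNotificationChan:
		t.Logf("quota notification payload: %s", body)
	case <-time.After(time.Minute):
		t.Fatal("timed out waiting for quota notification")
	}
}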
@ -6,6 +6,9 @@ type Configs struct {
}

type Config struct {
	Database *EmitterConfig
	Stdout   *EmitterConfig
	Stdout   *StdConfig
}

type StdConfig struct {
	Enabled bool
}
@ -9,19 +9,17 @@ import (
	"github.com/zitadel/logging"
)

type bulkSink interface {
	sendBulk(ctx context.Context, bulk []LogRecord) error
type bulkSink[T LogRecord[T]] interface {
	SendBulk(ctx context.Context, bulk []T) error
}

var _ bulkSink = bulkSinkFunc(nil)
type bulkSinkFunc[T LogRecord[T]] func(ctx context.Context, bulk []T) error

type bulkSinkFunc func(ctx context.Context, items []LogRecord) error

func (s bulkSinkFunc) sendBulk(ctx context.Context, items []LogRecord) error {
	return s(ctx, items)
func (s bulkSinkFunc[T]) SendBulk(ctx context.Context, bulk []T) error {
	return s(ctx, bulk)
}

type debouncer struct {
type debouncer[T LogRecord[T]] struct {
	// Storing context.Context in a struct is generally bad practice
	// https://go.dev/blog/context-and-structs
	// However, debouncer starts a go routine that triggers side effects itself.

@ -33,8 +31,8 @@ type debouncer struct {
	ticker   *clock.Ticker
	mux      sync.Mutex
	cfg      DebouncerConfig
	storage  bulkSink
	cache    []LogRecord
	storage  bulkSink[T]
	cache    []T
	cacheLen uint
}

@ -43,8 +41,8 @@ type DebouncerConfig struct {
	MaxBulkSize uint
}

func newDebouncer(binarySignaledCtx context.Context, cfg DebouncerConfig, clock clock.Clock, ship bulkSink) *debouncer {
	a := &debouncer{
func newDebouncer[T LogRecord[T]](binarySignaledCtx context.Context, cfg DebouncerConfig, clock clock.Clock, ship bulkSink[T]) *debouncer[T] {
	a := &debouncer[T]{
		binarySignaledCtx: binarySignaledCtx,
		clock:             clock,
		cfg:               cfg,

@ -58,7 +56,7 @@ func newDebouncer(binarySignaledCtx context.Context, cfg DebouncerConfig, clock
	return a
}

func (d *debouncer) add(item LogRecord) {
func (d *debouncer[T]) add(item T) {
	d.mux.Lock()
	defer d.mux.Unlock()
	d.cache = append(d.cache, item)

@ -69,13 +67,13 @@ func (d *debouncer) add(item LogRecord) {
	}
}

func (d *debouncer) ship() {
func (d *debouncer[T]) ship() {
	if d.cacheLen == 0 {
		return
	}
	d.mux.Lock()
	defer d.mux.Unlock()
	if err := d.storage.sendBulk(d.binarySignaledCtx, d.cache); err != nil {
	if err := d.storage.SendBulk(d.binarySignaledCtx, d.cache); err != nil {
		logging.WithError(err).WithField("size", len(d.cache)).Error("storing bulk failed")
	}
	d.cache = nil

@ -85,7 +83,7 @@ func (d *debouncer) ship() {
	}
}

func (d *debouncer) shipOnTicks() {
func (d *debouncer[T]) shipOnTicks() {
	for range d.ticker.C {
		d.ship()
	}
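With the type parameter, a storage backend now only has to satisfy bulkSink[T] for its own record type, and the compiler rules out mixed batches. A standalone sketch of the same shape, reduced to the size-triggered half of the debouncer (names simplified; the real types stay unexported):

package main

import (
	"context"
	"fmt"
	"sync"
)

// BulkSink mirrors bulkSink[T]: a whole batch is shipped in one call.
type BulkSink[T any] interface {
	SendBulk(ctx context.Context, bulk []T) error
}

// SinkFunc adapts a plain function to BulkSink, like bulkSinkFunc[T].
type SinkFunc[T any] func(ctx context.Context, bulk []T) error

func (f SinkFunc[T]) SendBulk(ctx context.Context, bulk []T) error { return f(ctx, bulk) }

// Buffer collects records and ships them once maxBulk is reached;
// the ticker-driven half of the real debouncer is omitted here.
type Buffer[T any] struct {
	mu      sync.Mutex
	cache   []T
	maxBulk int
	sink    BulkSink[T]
}

func (b *Buffer[T]) Add(ctx context.Context, item T) {
	b.mu.Lock()
	defer b.mu.Unlock()
	b.cache = append(b.cache, item)
	if len(b.cache) >= b.maxBulk {
		_ = b.sink.SendBulk(ctx, b.cache)
		b.cache = nil
	}
}

func main() {
	sink := SinkFunc[string](func(_ context.Context, bulk []string) error {
		fmt.Println("shipping", len(bulk), "records")
		return nil
	})
	buf := &Buffer[string]{maxBulk: 2, sink: sink}
	for _, r := range []string{"a", "b", "c"} {
		buf.Add(context.Background(), r)
	}
	// "shipping 2 records" is printed once; "c" stays buffered.
}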
@ -2,56 +2,52 @@ package logstore
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/benbjohnson/clock"
|
||||
"github.com/zitadel/logging"
|
||||
)
|
||||
|
||||
type EmitterConfig struct {
|
||||
Enabled bool
|
||||
Keep time.Duration
|
||||
CleanupInterval time.Duration
|
||||
Debounce *DebouncerConfig
|
||||
}
|
||||
|
||||
type emitter struct {
|
||||
type emitter[T LogRecord[T]] struct {
|
||||
enabled bool
|
||||
ctx context.Context
|
||||
debouncer *debouncer
|
||||
emitter LogEmitter
|
||||
debouncer *debouncer[T]
|
||||
emitter LogEmitter[T]
|
||||
clock clock.Clock
|
||||
}
|
||||
|
||||
type LogRecord interface {
|
||||
Normalize() LogRecord
|
||||
type LogRecord[T any] interface {
|
||||
Normalize() T
|
||||
}
|
||||
|
||||
type LogRecordFunc func() LogRecord
|
||||
type LogRecordFunc[T any] func() T
|
||||
|
||||
func (r LogRecordFunc) Normalize() LogRecord {
|
||||
func (r LogRecordFunc[T]) Normalize() T {
|
||||
return r()
|
||||
}
|
||||
|
||||
type LogEmitter interface {
|
||||
Emit(ctx context.Context, bulk []LogRecord) error
|
||||
type LogEmitter[T LogRecord[T]] interface {
|
||||
Emit(ctx context.Context, bulk []T) error
|
||||
}
|
||||
|
||||
type LogEmitterFunc func(ctx context.Context, bulk []LogRecord) error
|
||||
type LogEmitterFunc[T LogRecord[T]] func(ctx context.Context, bulk []T) error
|
||||
|
||||
func (l LogEmitterFunc) Emit(ctx context.Context, bulk []LogRecord) error {
|
||||
func (l LogEmitterFunc[T]) Emit(ctx context.Context, bulk []T) error {
|
||||
return l(ctx, bulk)
|
||||
}
|
||||
|
||||
type LogCleanupper interface {
|
||||
LogEmitter
|
||||
type LogCleanupper[T LogRecord[T]] interface {
|
||||
Cleanup(ctx context.Context, keep time.Duration) error
|
||||
LogEmitter[T]
|
||||
}
|
||||
|
||||
// NewEmitter accepts Clock from github.com/benbjohnson/clock so we can control timers and tickers in the unit tests
|
||||
func NewEmitter(ctx context.Context, clock clock.Clock, cfg *EmitterConfig, logger LogEmitter) (*emitter, error) {
|
||||
svc := &emitter{
|
||||
func NewEmitter[T LogRecord[T]](ctx context.Context, clock clock.Clock, cfg *EmitterConfig, logger LogEmitter[T]) (*emitter[T], error) {
|
||||
svc := &emitter[T]{
|
||||
enabled: cfg != nil && cfg.Enabled,
|
||||
ctx: ctx,
|
||||
emitter: logger,
|
||||
@ -63,36 +59,12 @@ func NewEmitter(ctx context.Context, clock clock.Clock, cfg *EmitterConfig, logg
|
||||
}
|
||||
|
||||
if cfg.Debounce != nil && (cfg.Debounce.MinFrequency > 0 || cfg.Debounce.MaxBulkSize > 0) {
|
||||
svc.debouncer = newDebouncer(ctx, *cfg.Debounce, clock, newStorageBulkSink(svc.emitter))
|
||||
}
|
||||
|
||||
cleanupper, ok := logger.(LogCleanupper)
|
||||
if !ok {
|
||||
if cfg.Keep != 0 {
|
||||
return nil, fmt.Errorf("cleaning up for this storage type is not supported, so keep duration must be 0, but is %d", cfg.Keep)
|
||||
}
|
||||
if cfg.CleanupInterval != 0 {
|
||||
return nil, fmt.Errorf("cleaning up for this storage type is not supported, so cleanup interval duration must be 0, but is %d", cfg.Keep)
|
||||
}
|
||||
|
||||
return svc, nil
|
||||
}
|
||||
|
||||
if cfg.Keep != 0 && cfg.CleanupInterval != 0 {
|
||||
go svc.startCleanupping(cleanupper, cfg.CleanupInterval, cfg.Keep)
|
||||
svc.debouncer = newDebouncer[T](ctx, *cfg.Debounce, clock, newStorageBulkSink(svc.emitter))
|
||||
}
|
||||
return svc, nil
|
||||
}
|
||||
|
||||
func (s *emitter) startCleanupping(cleanupper LogCleanupper, cleanupInterval, keep time.Duration) {
|
||||
for range s.clock.Tick(cleanupInterval) {
|
||||
if err := cleanupper.Cleanup(s.ctx, keep); err != nil {
|
||||
logging.WithError(err).Error("cleaning up logs failed")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *emitter) Emit(ctx context.Context, record LogRecord) (err error) {
|
||||
func (s *emitter[T]) Emit(ctx context.Context, record T) (err error) {
|
||||
if !s.enabled {
|
||||
return nil
|
||||
}
|
||||
@ -102,11 +74,11 @@ func (s *emitter) Emit(ctx context.Context, record LogRecord) (err error) {
|
||||
return nil
|
||||
}
|
||||
|
||||
return s.emitter.Emit(ctx, []LogRecord{record})
|
||||
return s.emitter.Emit(ctx, []T{record})
|
||||
}
|
||||
|
||||
func newStorageBulkSink(emitter LogEmitter) bulkSinkFunc {
|
||||
return func(ctx context.Context, bulk []LogRecord) error {
|
||||
func newStorageBulkSink[T LogRecord[T]](emitter LogEmitter[T]) bulkSinkFunc[T] {
|
||||
return func(ctx context.Context, bulk []T) error {
|
||||
return emitter.Emit(ctx, bulk)
|
||||
}
|
||||
}
|
||||
|
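A record type satisfies LogRecord[T] by returning its own type from Normalize, which is what keeps the emitter and debouncer fully typed end to end. A minimal sketch with an invented record type:

package main

import "fmt"

// LogRecord matches the generic interface above: Normalize returns
// the record's own type.
type LogRecord[T any] interface{ Normalize() T }

// PingLog is an invented record type; its Normalize truncates the
// message, in the spirit of the field normalization the access records do.
type PingLog struct{ Message string }

func (p PingLog) Normalize() PingLog {
	if len(p.Message) > 10 {
		p.Message = p.Message[:10]
	}
	return p
}

// Compile-time proof that PingLog instantiates LogRecord[PingLog].
var _ LogRecord[PingLog] = PingLog{}

func main() {
	fmt.Println(PingLog{Message: "hello, quota world"}.Normalize().Message) // "hello, quo"
}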
@ -3,167 +3,91 @@ package access
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
"errors"
|
||||
"time"
|
||||
|
||||
"github.com/Masterminds/squirrel"
|
||||
"github.com/zitadel/logging"
|
||||
"google.golang.org/grpc/codes"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/api/call"
|
||||
zitadel_http "github.com/zitadel/zitadel/internal/api/http"
|
||||
"github.com/zitadel/zitadel/internal/api/authz"
|
||||
"github.com/zitadel/zitadel/internal/command"
|
||||
"github.com/zitadel/zitadel/internal/database"
|
||||
caos_errors "github.com/zitadel/zitadel/internal/errors"
|
||||
"github.com/zitadel/zitadel/internal/logstore"
|
||||
"github.com/zitadel/zitadel/internal/logstore/record"
|
||||
"github.com/zitadel/zitadel/internal/query"
|
||||
"github.com/zitadel/zitadel/internal/query/projection"
|
||||
"github.com/zitadel/zitadel/internal/repository/quota"
|
||||
"github.com/zitadel/zitadel/internal/telemetry/tracing"
|
||||
)
|
||||
|
||||
const (
|
||||
accessLogsTable = "logstore.access"
|
||||
accessTimestampCol = "log_date"
|
||||
accessProtocolCol = "protocol"
|
||||
accessRequestURLCol = "request_url"
|
||||
accessResponseStatusCol = "response_status"
|
||||
accessRequestHeadersCol = "request_headers"
|
||||
accessResponseHeadersCol = "response_headers"
|
||||
accessInstanceIdCol = "instance_id"
|
||||
accessProjectIdCol = "project_id"
|
||||
accessRequestedDomainCol = "requested_domain"
|
||||
accessRequestedHostCol = "requested_host"
|
||||
)
|
||||
|
||||
var _ logstore.UsageQuerier = (*databaseLogStorage)(nil)
|
||||
var _ logstore.LogCleanupper = (*databaseLogStorage)(nil)
|
||||
var _ logstore.UsageStorer[*record.AccessLog] = (*databaseLogStorage)(nil)
|
||||
|
||||
type databaseLogStorage struct {
|
||||
dbClient *database.DB
|
||||
commands *command.Commands
|
||||
queries *query.Queries
|
||||
}
|
||||
|
||||
func NewDatabaseLogStorage(dbClient *database.DB) *databaseLogStorage {
|
||||
return &databaseLogStorage{dbClient: dbClient}
|
||||
func NewDatabaseLogStorage(dbClient *database.DB, commands *command.Commands, queries *query.Queries) *databaseLogStorage {
|
||||
return &databaseLogStorage{dbClient: dbClient, commands: commands, queries: queries}
|
||||
}
|
||||
|
||||
func (l *databaseLogStorage) QuotaUnit() quota.Unit {
|
||||
return quota.RequestsAllAuthenticated
|
||||
}
|
||||
|
||||
func (l *databaseLogStorage) Emit(ctx context.Context, bulk []logstore.LogRecord) error {
|
||||
func (l *databaseLogStorage) Emit(ctx context.Context, bulk []*record.AccessLog) error {
|
||||
if len(bulk) == 0 {
|
||||
return nil
|
||||
}
|
||||
builder := squirrel.Insert(accessLogsTable).
|
||||
Columns(
|
||||
accessTimestampCol,
|
||||
accessProtocolCol,
|
||||
accessRequestURLCol,
|
||||
accessResponseStatusCol,
|
||||
accessRequestHeadersCol,
|
||||
accessResponseHeadersCol,
|
||||
accessInstanceIdCol,
|
||||
accessProjectIdCol,
|
||||
accessRequestedDomainCol,
|
||||
accessRequestedHostCol,
|
||||
).
|
||||
PlaceholderFormat(squirrel.Dollar)
|
||||
|
||||
for idx := range bulk {
|
||||
item := bulk[idx].(*Record)
|
||||
builder = builder.Values(
|
||||
item.LogDate,
|
||||
item.Protocol,
|
||||
item.RequestURL,
|
||||
item.ResponseStatus,
|
||||
item.RequestHeaders,
|
||||
item.ResponseHeaders,
|
||||
item.InstanceID,
|
||||
item.ProjectID,
|
||||
item.RequestedDomain,
|
||||
item.RequestedHost,
|
||||
)
|
||||
}
|
||||
|
||||
stmt, args, err := builder.ToSql()
|
||||
if err != nil {
|
||||
return caos_errors.ThrowInternal(err, "ACCESS-KOS7I", "Errors.Internal")
|
||||
}
|
||||
|
||||
result, err := l.dbClient.ExecContext(ctx, stmt, args...)
|
||||
if err != nil {
|
||||
return caos_errors.ThrowInternal(err, "ACCESS-alnT9", "Errors.Access.StorageFailed")
|
||||
}
|
||||
|
||||
rows, err := result.RowsAffected()
|
||||
if err != nil {
|
||||
return caos_errors.ThrowInternal(err, "ACCESS-7KIpL", "Errors.Internal")
|
||||
}
|
||||
|
||||
logging.WithFields("rows", rows).Debug("successfully stored access logs")
|
||||
return nil
|
||||
return l.incrementUsage(ctx, bulk)
|
||||
}
|
||||
|
||||
func (l *databaseLogStorage) QueryUsage(ctx context.Context, instanceId string, start time.Time) (uint64, error) {
|
||||
stmt, args, err := squirrel.Select(
|
||||
fmt.Sprintf("count(%s)", accessInstanceIdCol),
|
||||
).
|
||||
From(accessLogsTable + l.dbClient.Timetravel(call.Took(ctx))).
|
||||
Where(squirrel.And{
|
||||
squirrel.Eq{accessInstanceIdCol: instanceId},
|
||||
squirrel.GtOrEq{accessTimestampCol: start},
|
||||
squirrel.Expr(fmt.Sprintf(`%s #>> '{%s,0}' = '[REDACTED]'`, accessRequestHeadersCol, strings.ToLower(zitadel_http.Authorization))),
|
||||
squirrel.NotLike{accessRequestURLCol: "%/zitadel.system.v1.SystemService/%"},
|
||||
squirrel.NotLike{accessRequestURLCol: "%/system/v1/%"},
|
||||
squirrel.Or{
|
||||
squirrel.And{
|
||||
squirrel.Eq{accessProtocolCol: HTTP},
|
||||
squirrel.NotEq{accessResponseStatusCol: http.StatusForbidden},
|
||||
squirrel.NotEq{accessResponseStatusCol: http.StatusInternalServerError},
|
||||
squirrel.NotEq{accessResponseStatusCol: http.StatusTooManyRequests},
|
||||
},
|
||||
squirrel.And{
|
||||
squirrel.Eq{accessProtocolCol: GRPC},
|
||||
squirrel.NotEq{accessResponseStatusCol: codes.PermissionDenied},
|
||||
squirrel.NotEq{accessResponseStatusCol: codes.Internal},
|
||||
squirrel.NotEq{accessResponseStatusCol: codes.ResourceExhausted},
|
||||
},
|
||||
},
|
||||
}).
|
||||
PlaceholderFormat(squirrel.Dollar).
|
||||
ToSql()
|
||||
func (l *databaseLogStorage) incrementUsage(ctx context.Context, bulk []*record.AccessLog) (err error) {
|
||||
ctx, span := tracing.NewSpan(ctx)
|
||||
defer func() { span.EndWithError(err) }()
|
||||
|
||||
if err != nil {
|
||||
return 0, caos_errors.ThrowInternal(err, "ACCESS-V9Sde", "Errors.Internal")
|
||||
byInstance := make(map[string][]*record.AccessLog)
|
||||
for _, r := range bulk {
|
||||
if r.InstanceID != "" {
|
||||
byInstance[r.InstanceID] = append(byInstance[r.InstanceID], r)
|
||||
}
|
||||
|
||||
var count uint64
|
||||
err = l.dbClient.
|
||||
QueryRowContext(ctx,
|
||||
func(row *sql.Row) error {
|
||||
return row.Scan(&count)
|
||||
},
|
||||
stmt, args...,
|
||||
)
|
||||
|
||||
if err != nil {
|
||||
return 0, caos_errors.ThrowInternal(err, "ACCESS-pBPrM", "Errors.Logstore.Access.ScanFailed")
|
||||
}
|
||||
|
||||
return count, nil
|
||||
}
|
||||
|
||||
func (l *databaseLogStorage) Cleanup(ctx context.Context, keep time.Duration) error {
|
||||
stmt, args, err := squirrel.Delete(accessLogsTable).
|
||||
Where(squirrel.LtOrEq{accessTimestampCol: time.Now().Add(-keep)}).
|
||||
PlaceholderFormat(squirrel.Dollar).
|
||||
ToSql()
|
||||
|
||||
if err != nil {
|
||||
return caos_errors.ThrowInternal(err, "ACCESS-2oTh6", "Errors.Internal")
|
||||
for instanceID, instanceBulk := range byInstance {
|
||||
q, getQuotaErr := l.queries.GetQuota(ctx, instanceID, quota.RequestsAllAuthenticated)
|
||||
if errors.Is(getQuotaErr, sql.ErrNoRows) {
|
||||
continue
|
||||
}
|
||||
err = errors.Join(err, getQuotaErr)
|
||||
if getQuotaErr != nil {
|
||||
continue
|
||||
}
|
||||
sum, incrementErr := l.incrementUsageFromAccessLogs(ctx, instanceID, q.CurrentPeriodStart, instanceBulk)
|
||||
err = errors.Join(err, incrementErr)
|
||||
if incrementErr != nil {
|
||||
continue
|
||||
}
|
||||
notifications, getNotificationErr := l.queries.GetDueQuotaNotifications(ctx, instanceID, quota.RequestsAllAuthenticated, q, q.CurrentPeriodStart, sum)
|
||||
err = errors.Join(err, getNotificationErr)
|
||||
if getNotificationErr != nil || len(notifications) == 0 {
|
||||
continue
|
||||
}
|
||||
ctx = authz.WithInstanceID(ctx, instanceID)
|
||||
reportErr := l.commands.ReportQuotaUsage(ctx, notifications)
|
||||
err = errors.Join(err, reportErr)
|
||||
if reportErr != nil {
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
execCtx, cancel := context.WithTimeout(ctx, 10*time.Second)
|
||||
defer cancel()
|
||||
_, err = l.dbClient.ExecContext(execCtx, stmt, args...)
|
||||
return err
|
||||
}
|
||||
|
||||
func (l *databaseLogStorage) incrementUsageFromAccessLogs(ctx context.Context, instanceID string, periodStart time.Time, records []*record.AccessLog) (sum uint64, err error) {
|
||||
ctx, span := tracing.NewSpan(ctx)
|
||||
defer func() { span.EndWithError(err) }()
|
||||
|
||||
var count uint64
|
||||
for _, r := range records {
|
||||
if r.IsAuthenticated() {
|
||||
count++
|
||||
}
|
||||
}
|
||||
return projection.QuotaProjection.IncrementUsage(ctx, quota.RequestsAllAuthenticated, instanceID, periodStart, count)
|
||||
}
|
||||
|
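Note how incrementUsage keeps iterating when one instance fails: each per-instance error is folded in with errors.Join and the loop continues, so a single broken quota cannot stall usage accounting for all other instances. The core of that pattern in isolation:

package main

import (
	"errors"
	"fmt"
)

// processAll applies fn per key and joins failures instead of returning
// at the first one, like the per-instance loop above.
func processAll(keys []string, fn func(string) error) (err error) {
	for _, k := range keys {
		if stepErr := fn(k); stepErr != nil {
			err = errors.Join(err, stepErr)
			continue // the remaining instances are still processed
		}
	}
	return err
}

func main() {
	err := processAll([]string{"inst1", "inst2", "inst3"}, func(k string) error {
		if k == "inst2" {
			return fmt.Errorf("quota lookup failed for %s", k)
		}
		return nil
	})
	fmt.Println(err) // quota lookup failed for inst2
}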
@ -1,97 +0,0 @@
|
||||
package access
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
zitadel_http "github.com/zitadel/zitadel/internal/api/http"
|
||||
"github.com/zitadel/zitadel/internal/logstore"
|
||||
)
|
||||
|
||||
var _ logstore.LogRecord = (*Record)(nil)
|
||||
|
||||
type Record struct {
|
||||
LogDate time.Time `json:"logDate"`
|
||||
Protocol Protocol `json:"protocol"`
|
||||
RequestURL string `json:"requestUrl"`
|
||||
ResponseStatus uint32 `json:"responseStatus"`
|
||||
// RequestHeaders are plain maps so varying implementation
|
||||
// between HTTP and gRPC don't interfere with each other
|
||||
RequestHeaders map[string][]string `json:"requestHeaders"`
|
||||
// ResponseHeaders are plain maps so varying implementation
|
||||
// between HTTP and gRPC don't interfere with each other
|
||||
ResponseHeaders map[string][]string `json:"responseHeaders"`
|
||||
InstanceID string `json:"instanceId"`
|
||||
ProjectID string `json:"projectId"`
|
||||
RequestedDomain string `json:"requestedDomain"`
|
||||
RequestedHost string `json:"requestedHost"`
|
||||
}
|
||||
|
||||
type Protocol uint8
|
||||
|
||||
const (
|
||||
GRPC Protocol = iota
|
||||
HTTP
|
||||
|
||||
redacted = "[REDACTED]"
|
||||
)
|
||||
|
||||
func (a Record) Normalize() logstore.LogRecord {
|
||||
a.RequestedDomain = cutString(a.RequestedDomain, 200)
|
||||
a.RequestURL = cutString(a.RequestURL, 200)
|
||||
a.RequestHeaders = normalizeHeaders(a.RequestHeaders, strings.ToLower(zitadel_http.Authorization), "grpcgateway-authorization", "cookie", "grpcgateway-cookie")
|
||||
a.ResponseHeaders = normalizeHeaders(a.ResponseHeaders, "set-cookie")
|
||||
return &a
|
||||
}
|
||||
|
||||
// normalizeHeaders lowers all header keys and redacts secrets
|
||||
func normalizeHeaders(header map[string][]string, redactKeysLower ...string) map[string][]string {
|
||||
return pruneKeys(redactKeys(lowerKeys(header), redactKeysLower...))
|
||||
}
|
||||
|
||||
func lowerKeys(header map[string][]string) map[string][]string {
|
||||
lower := make(map[string][]string, len(header))
|
||||
for k, v := range header {
|
||||
lower[strings.ToLower(k)] = v
|
||||
}
|
||||
return lower
|
||||
}
|
||||
|
||||
func redactKeys(header map[string][]string, redactKeysLower ...string) map[string][]string {
|
||||
redactedKeys := make(map[string][]string, len(header))
|
||||
for k, v := range header {
|
||||
redactedKeys[k] = v
|
||||
}
|
||||
for _, redactKey := range redactKeysLower {
|
||||
if _, ok := redactedKeys[redactKey]; ok {
|
||||
redactedKeys[redactKey] = []string{redacted}
|
||||
}
|
||||
}
|
||||
return redactedKeys
|
||||
}
|
||||
|
||||
const maxValuesPerKey = 10
|
||||
|
||||
func pruneKeys(header map[string][]string) map[string][]string {
|
||||
prunedKeys := make(map[string][]string, len(header))
|
||||
for key, value := range header {
|
||||
valueItems := make([]string, 0, maxValuesPerKey)
|
||||
for i, valueItem := range value {
|
||||
// Max 10 header values per key
|
||||
if i > maxValuesPerKey {
|
||||
break
|
||||
}
|
||||
// Max 200 value length
|
||||
valueItems = append(valueItems, cutString(valueItem, 200))
|
||||
}
|
||||
prunedKeys[key] = valueItems
|
||||
}
|
||||
return prunedKeys
|
||||
}
|
||||
|
||||
func cutString(str string, pos int) string {
|
||||
if len(str) <= pos {
|
||||
return str
|
||||
}
|
||||
return str[:pos-1]
|
||||
}
@@ -3,142 +3,84 @@ package execution
import (
	"context"
	"database/sql"
	"fmt"
	"errors"
	"math"
	"time"

	"github.com/Masterminds/squirrel"
	"github.com/zitadel/logging"

	"github.com/zitadel/zitadel/internal/api/call"
	"github.com/zitadel/zitadel/internal/api/authz"
	"github.com/zitadel/zitadel/internal/command"
	"github.com/zitadel/zitadel/internal/database"
	caos_errors "github.com/zitadel/zitadel/internal/errors"
	"github.com/zitadel/zitadel/internal/logstore"
	"github.com/zitadel/zitadel/internal/logstore/record"
	"github.com/zitadel/zitadel/internal/query"
	"github.com/zitadel/zitadel/internal/query/projection"
	"github.com/zitadel/zitadel/internal/repository/quota"
)

const (
	executionLogsTable     = "logstore.execution"
	executionTimestampCol  = "log_date"
	executionTookCol       = "took"
	executionMessageCol    = "message"
	executionLogLevelCol   = "loglevel"
	executionInstanceIdCol = "instance_id"
	executionActionIdCol   = "action_id"
	executionMetadataCol   = "metadata"
)

var _ logstore.UsageQuerier = (*databaseLogStorage)(nil)
var _ logstore.LogCleanupper = (*databaseLogStorage)(nil)
var _ logstore.UsageStorer[*record.ExecutionLog] = (*databaseLogStorage)(nil)

type databaseLogStorage struct {
	dbClient *database.DB
	commands *command.Commands
	queries  *query.Queries
}

func NewDatabaseLogStorage(dbClient *database.DB) *databaseLogStorage {
	return &databaseLogStorage{dbClient: dbClient}
func NewDatabaseLogStorage(dbClient *database.DB, commands *command.Commands, queries *query.Queries) *databaseLogStorage {
	return &databaseLogStorage{dbClient: dbClient, commands: commands, queries: queries}
}

func (l *databaseLogStorage) QuotaUnit() quota.Unit {
	return quota.ActionsAllRunsSeconds
}

func (l *databaseLogStorage) Emit(ctx context.Context, bulk []logstore.LogRecord) error {
func (l *databaseLogStorage) Emit(ctx context.Context, bulk []*record.ExecutionLog) error {
	if len(bulk) == 0 {
		return nil
	}
	builder := squirrel.Insert(executionLogsTable).
		Columns(
			executionTimestampCol,
			executionTookCol,
			executionMessageCol,
			executionLogLevelCol,
			executionInstanceIdCol,
			executionActionIdCol,
			executionMetadataCol,
		).
		PlaceholderFormat(squirrel.Dollar)

	for idx := range bulk {
		item := bulk[idx].(*Record)

		var took interface{}
		if item.Took > 0 {
			took = item.Took
		}

		builder = builder.Values(
			item.LogDate,
			took,
			item.Message,
			item.LogLevel,
			item.InstanceID,
			item.ActionID,
			item.Metadata,
		)
	}

	stmt, args, err := builder.ToSql()
	if err != nil {
		return caos_errors.ThrowInternal(err, "EXEC-KOS7I", "Errors.Internal")
	}

	result, err := l.dbClient.ExecContext(ctx, stmt, args...)
	if err != nil {
		return caos_errors.ThrowInternal(err, "EXEC-0j6i5", "Errors.Access.StorageFailed")
	}

	rows, err := result.RowsAffected()
	if err != nil {
		return caos_errors.ThrowInternal(err, "EXEC-MGchJ", "Errors.Internal")
	}

	logging.WithFields("rows", rows).Debug("successfully stored execution logs")
	return nil
	return l.incrementUsage(ctx, bulk)
}

func (l *databaseLogStorage) QueryUsage(ctx context.Context, instanceId string, start time.Time) (uint64, error) {
	stmt, args, err := squirrel.Select(
		fmt.Sprintf("COALESCE(SUM(%s)::INT,0)", executionTookCol),
	).
		From(executionLogsTable + l.dbClient.Timetravel(call.Took(ctx))).
		Where(squirrel.And{
			squirrel.Eq{executionInstanceIdCol: instanceId},
			squirrel.GtOrEq{executionTimestampCol: start},
			squirrel.NotEq{executionTookCol: nil},
		}).
		PlaceholderFormat(squirrel.Dollar).
		ToSql()

	if err != nil {
		return 0, caos_errors.ThrowInternal(err, "EXEC-DXtzg", "Errors.Internal")
func (l *databaseLogStorage) incrementUsage(ctx context.Context, bulk []*record.ExecutionLog) (err error) {
	byInstance := make(map[string][]*record.ExecutionLog)
	for _, r := range bulk {
		if r.InstanceID != "" {
			byInstance[r.InstanceID] = append(byInstance[r.InstanceID], r)
		}
	}
	for instanceID, instanceBulk := range byInstance {
		q, getQuotaErr := l.queries.GetQuota(ctx, instanceID, quota.ActionsAllRunsSeconds)
		if errors.Is(getQuotaErr, sql.ErrNoRows) {
			continue
		}
		err = errors.Join(err, getQuotaErr)
		if getQuotaErr != nil {
			continue
		}
		sum, incrementErr := l.incrementUsageFromExecutionLogs(ctx, instanceID, q.CurrentPeriodStart, instanceBulk)
		err = errors.Join(err, incrementErr)
		if incrementErr != nil {
			continue
		}

	var durationSeconds uint64
	err = l.dbClient.
		QueryRowContext(ctx,
			func(row *sql.Row) error {
				return row.Scan(&durationSeconds)
			},
			stmt, args...,
		)
	if err != nil {
		return 0, caos_errors.ThrowInternal(err, "EXEC-Ad8nP", "Errors.Logstore.Execution.ScanFailed")
		notifications, getNotificationErr := l.queries.GetDueQuotaNotifications(ctx, instanceID, quota.ActionsAllRunsSeconds, q, q.CurrentPeriodStart, sum)
		err = errors.Join(err, getNotificationErr)
		if getNotificationErr != nil || len(notifications) == 0 {
			continue
		}
		ctx = authz.WithInstanceID(ctx, instanceID)
		reportErr := l.commands.ReportQuotaUsage(ctx, notifications)
		err = errors.Join(err, reportErr)
		if reportErr != nil {
			continue
		}
	return durationSeconds, nil
	}

func (l *databaseLogStorage) Cleanup(ctx context.Context, keep time.Duration) error {
	stmt, args, err := squirrel.Delete(executionLogsTable).
		Where(squirrel.LtOrEq{executionTimestampCol: time.Now().Add(-keep)}).
		PlaceholderFormat(squirrel.Dollar).
		ToSql()

	if err != nil {
		return caos_errors.ThrowInternal(err, "EXEC-Bja8V", "Errors.Internal")
	}

	execCtx, cancel := context.WithTimeout(ctx, 10*time.Second)
	defer cancel()
	_, err = l.dbClient.ExecContext(execCtx, stmt, args...)
	return err
}

func (l *databaseLogStorage) incrementUsageFromExecutionLogs(ctx context.Context, instanceID string, periodStart time.Time, records []*record.ExecutionLog) (sum uint64, err error) {
	var total time.Duration
	for _, r := range records {
		total += r.Took
	}
	return projection.QuotaProjection.IncrementUsage(ctx, quota.ActionsAllRunsSeconds, instanceID, periodStart, uint64(math.Floor(total.Seconds())))
}
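A note on the accumulation above: the bulk's Took durations are summed first and only then floored to whole seconds, so sub-second action runs still add up across a bulk. A minimal sketch with hypothetical values (not from this commit):

	// Three action runs of 700ms each: summing first yields
	// floor(2.1s) = 2 counted seconds, while flooring each run
	// individually would have counted 0.
	total := 700*time.Millisecond + 700*time.Millisecond + 700*time.Millisecond
	seconds := uint64(math.Floor(total.Seconds())) // 2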
@@ -1,89 +0,0 @@
package mock

import (
	"context"
	"sync"
	"time"

	"github.com/benbjohnson/clock"

	"github.com/zitadel/zitadel/internal/logstore"
	"github.com/zitadel/zitadel/internal/repository/quota"
)

var _ logstore.UsageQuerier = (*InmemLogStorage)(nil)
var _ logstore.LogCleanupper = (*InmemLogStorage)(nil)

type InmemLogStorage struct {
	mux     sync.Mutex
	clock   clock.Clock
	emitted []*record
	bulks   []int
}

func NewInMemoryStorage(clock clock.Clock) *InmemLogStorage {
	return &InmemLogStorage{
		clock:   clock,
		emitted: make([]*record, 0),
		bulks:   make([]int, 0),
	}
}

func (l *InmemLogStorage) QuotaUnit() quota.Unit {
	return quota.Unimplemented
}

func (l *InmemLogStorage) Emit(_ context.Context, bulk []logstore.LogRecord) error {
	if len(bulk) == 0 {
		return nil
	}
	l.mux.Lock()
	defer l.mux.Unlock()
	for idx := range bulk {
		l.emitted = append(l.emitted, bulk[idx].(*record))
	}
	l.bulks = append(l.bulks, len(bulk))
	return nil
}

func (l *InmemLogStorage) QueryUsage(_ context.Context, _ string, start time.Time) (uint64, error) {
	l.mux.Lock()
	defer l.mux.Unlock()

	var count uint64
	for _, r := range l.emitted {
		if r.ts.After(start) {
			count++
		}
	}
	return count, nil
}

func (l *InmemLogStorage) Cleanup(_ context.Context, keep time.Duration) error {
	l.mux.Lock()
	defer l.mux.Unlock()

	clean := make([]*record, 0)
	from := l.clock.Now().Add(-(keep + 1))
	for _, r := range l.emitted {
		if r.ts.After(from) {
			clean = append(clean, r)
		}
	}
	l.emitted = clean
	return nil
}

func (l *InmemLogStorage) Bulks() []int {
	l.mux.Lock()
	defer l.mux.Unlock()

	return l.bulks
}

func (l *InmemLogStorage) Len() int {
	l.mux.Lock()
	defer l.mux.Unlock()

	return len(l.emitted)
}
@@ -9,8 +9,8 @@ import (
	"github.com/zitadel/zitadel/internal/logstore"
)

func NewStdoutEmitter() logstore.LogEmitter {
	return logstore.LogEmitterFunc(func(ctx context.Context, bulk []logstore.LogRecord) error {
func NewStdoutEmitter[T logstore.LogRecord[T]]() logstore.LogEmitter[T] {
	return logstore.LogEmitterFunc[T](func(ctx context.Context, bulk []T) error {
		for idx := range bulk {
			bytes, err := json.Marshal(bulk[idx])
			if err != nil {
@@ -4,7 +4,7 @@ import (
	"time"

	"github.com/zitadel/zitadel/internal/logstore"
	"github.com/zitadel/zitadel/internal/repository/quota"
	"github.com/zitadel/zitadel/internal/query"
)

type emitterOption func(config *logstore.EmitterConfig)
@@ -12,8 +12,6 @@ type emitterOption func(config *logstore.EmitterConfig)
func emitterConfig(options ...emitterOption) *logstore.EmitterConfig {
	cfg := &logstore.EmitterConfig{
		Enabled:         true,
		Keep:            time.Hour,
		CleanupInterval: time.Hour,
		Debounce: &logstore.DebouncerConfig{
			MinFrequency: 0,
			MaxBulkSize:  0,
@@ -37,17 +35,10 @@ func withDisabled() emitterOption {
	}
}

func withCleanupping(keep, interval time.Duration) emitterOption {
	return func(c *logstore.EmitterConfig) {
		c.Keep = keep
		c.CleanupInterval = interval
	}
}
type quotaOption func(config *query.Quota)

type quotaOption func(config *quota.AddedEvent)

func quotaConfig(quotaOptions ...quotaOption) quota.AddedEvent {
	q := &quota.AddedEvent{
func quotaConfig(quotaOptions ...quotaOption) *query.Quota {
	q := &query.Quota{
		Amount:        90,
		Limit:         false,
		ResetInterval: 90 * time.Second,
@@ -56,18 +47,18 @@ func quotaConfig(quotaOptions ...quotaOption) quota.AddedEvent {
	for _, opt := range quotaOptions {
		opt(q)
	}
	return *q
	return q
}

func withAmountAndInterval(n uint64) quotaOption {
	return func(c *quota.AddedEvent) {
	return func(c *query.Quota) {
		c.Amount = n
		c.ResetInterval = time.Duration(n) * time.Second
	}
}

func withLimiting() quotaOption {
	return func(c *quota.AddedEvent) {
	return func(c *query.Quota) {
		c.Limit = true
	}
}
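Composed, these option helpers read roughly like this in the table-driven tests further below (a sketch):

	// A 90-unit quota that resets every 90 seconds and enforces its limit.
	cfg := quotaConfig(withAmountAndInterval(90), withLimiting())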
internal/logstore/mock/inmem.go (new file, 120 lines)
@@ -0,0 +1,120 @@
package mock

import (
	"context"
	"sync"
	"time"

	"github.com/benbjohnson/clock"

	"github.com/zitadel/zitadel/internal/logstore"
	"github.com/zitadel/zitadel/internal/query"
	"github.com/zitadel/zitadel/internal/repository/quota"
)

var _ logstore.UsageStorer[*Record] = (*InmemLogStorage)(nil)
var _ logstore.LogCleanupper[*Record] = (*InmemLogStorage)(nil)
var _ logstore.Queries = (*InmemLogStorage)(nil)

type InmemLogStorage struct {
	mux     sync.Mutex
	clock   clock.Clock
	emitted []*Record
	bulks   []int
	quota   *query.Quota
}

func NewInMemoryStorage(clock clock.Clock, quota *query.Quota) *InmemLogStorage {
	return &InmemLogStorage{
		clock:   clock,
		emitted: make([]*Record, 0),
		bulks:   make([]int, 0),
		quota:   quota,
	}
}

func (l *InmemLogStorage) QuotaUnit() quota.Unit {
	return quota.Unimplemented
}

func (l *InmemLogStorage) Emit(_ context.Context, bulk []*Record) error {
	if len(bulk) == 0 {
		return nil
	}
	l.mux.Lock()
	defer l.mux.Unlock()
	l.emitted = append(l.emitted, bulk...)
	l.bulks = append(l.bulks, len(bulk))
	return nil
}

func (l *InmemLogStorage) QueryUsage(_ context.Context, _ string, start time.Time) (uint64, error) {
	l.mux.Lock()
	defer l.mux.Unlock()

	var count uint64
	for _, r := range l.emitted {
		if r.ts.After(start) {
			count++
		}
	}
	return count, nil
}

func (l *InmemLogStorage) Cleanup(_ context.Context, keep time.Duration) error {
	l.mux.Lock()
	defer l.mux.Unlock()

	clean := make([]*Record, 0)
	from := l.clock.Now().Add(-(keep + 1))
	for _, r := range l.emitted {
		if r.ts.After(from) {
			clean = append(clean, r)
		}
	}
	l.emitted = clean
	return nil
}

func (l *InmemLogStorage) Bulks() []int {
	l.mux.Lock()
	defer l.mux.Unlock()

	return l.bulks
}

func (l *InmemLogStorage) Len() int {
	l.mux.Lock()
	defer l.mux.Unlock()

	return len(l.emitted)
}

func (l *InmemLogStorage) GetQuota(ctx context.Context, instanceID string, unit quota.Unit) (qu *query.Quota, err error) {
	return l.quota, nil
}

func (l *InmemLogStorage) GetQuotaUsage(ctx context.Context, instanceID string, unit quota.Unit, periodStart time.Time) (usage uint64, err error) {
	return uint64(l.Len()), nil
}

func (l *InmemLogStorage) GetRemainingQuotaUsage(ctx context.Context, instanceID string, unit quota.Unit) (remaining *uint64, err error) {
	if !l.quota.Limit {
		return nil, nil
	}
	var r uint64
	used := uint64(l.Len())
	if used > l.quota.Amount {
		return &r, nil
	}
	r = l.quota.Amount - used
	return &r, nil
}

func (l *InmemLogStorage) GetDueQuotaNotifications(ctx context.Context, instanceID string, unit quota.Unit, qu *query.Quota, periodStart time.Time, usedAbs uint64) (dueNotifications []*quota.NotificationDueEvent, err error) {
	return nil, nil
}

func (l *InmemLogStorage) ReportQuotaUsage(ctx context.Context, dueNotifications []*quota.NotificationDueEvent) error {
	return nil
}
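For orientation: the mock doubles as both the emitting sink and the logstore.Queries implementation, so a unit test can run entirely in memory. A sketch of the wiring (ctx and an emitter config are assumed from the surrounding test; compare given() further below):

	mockClock := clock.NewMock()
	// The same storage backs the emitter and answers quota queries.
	storage := mock.NewInMemoryStorage(mockClock, &query.Quota{Amount: 90, Limit: true})
	emitter, err := logstore.NewEmitter[*mock.Record](ctx, mockClock, emitterCfg, storage)
	if err != nil {
		t.Fatal(err)
	}
	svc := logstore.New[*mock.Record](storage, emitter)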
@@ -8,18 +8,18 @@ import (
	"github.com/zitadel/zitadel/internal/logstore"
)

var _ logstore.LogRecord = (*record)(nil)
var _ logstore.LogRecord[*Record] = (*Record)(nil)

func NewRecord(clock clock.Clock) *record {
	return &record{ts: clock.Now()}
func NewRecord(clock clock.Clock) *Record {
	return &Record{ts: clock.Now()}
}

type record struct {
type Record struct {
	ts       time.Time
	redacted bool
}

func (r record) Normalize() logstore.LogRecord {
func (r Record) Normalize() *Record {
	r.redacted = true
	return &r
}
@@ -1,28 +0,0 @@
package mock

import (
	"context"
	"time"

	"github.com/zitadel/zitadel/internal/logstore"
	"github.com/zitadel/zitadel/internal/repository/quota"
)

var _ logstore.QuotaQuerier = (*inmemReporter)(nil)

type inmemReporter struct {
	config      *quota.AddedEvent
	startPeriod time.Time
}

func NewNoopQuerier(quota *quota.AddedEvent, startPeriod time.Time) *inmemReporter {
	return &inmemReporter{config: quota, startPeriod: startPeriod}
}

func (i *inmemReporter) GetCurrentQuotaPeriod(context.Context, string, quota.Unit) (*quota.AddedEvent, time.Time, error) {
	return i.config, i.startPeriod, nil
}

func (*inmemReporter) GetDueQuotaNotifications(context.Context, *quota.AddedEvent, time.Time, uint64) ([]*quota.NotificationDueEvent, error) {
	return nil, nil
}
internal/logstore/record/access.go (new file, 135 lines)
@@ -0,0 +1,135 @@
package record

import (
	"net/http"
	"strings"
	"time"

	"google.golang.org/grpc/codes"

	zitadel_http "github.com/zitadel/zitadel/internal/api/http"
)

type AccessLog struct {
	LogDate        time.Time      `json:"logDate"`
	Protocol       AccessProtocol `json:"protocol"`
	RequestURL     string         `json:"requestUrl"`
	ResponseStatus uint32         `json:"responseStatus"`
	// RequestHeaders and ResponseHeaders are plain maps so varying implementations
	// between HTTP and gRPC don't interfere with each other
	RequestHeaders  map[string][]string `json:"requestHeaders"`
	ResponseHeaders map[string][]string `json:"responseHeaders"`
	InstanceID      string              `json:"instanceId"`
	ProjectID       string              `json:"projectId"`
	RequestedDomain string              `json:"requestedDomain"`
	RequestedHost   string              `json:"requestedHost"`
	// NotCountable can be used by the logging service to explicitly state
	// that the request must not increase the number of countable (authenticated) requests
	NotCountable bool `json:"-"`
	normalized   bool `json:"-"`
}

type AccessProtocol uint8

const (
	GRPC AccessProtocol = iota
	HTTP

	redacted = "[REDACTED]"
)

var (
	unaccountableEndpoints = []string{
		"/zitadel.system.v1.SystemService/",
		"/zitadel.admin.v1.AdminService/Healthz",
		"/zitadel.management.v1.ManagementService/Healthz",
		"/zitadel.management.v1.ManagementService/GetOIDCInformation",
		"/zitadel.auth.v1.AuthService/Healthz",
	}
)

func (a AccessLog) IsAuthenticated() bool {
	if a.NotCountable {
		return false
	}
	if !a.normalized {
		panic("access log not normalized, Normalize() must be called before IsAuthenticated()")
	}
	_, hasHTTPAuthHeader := a.RequestHeaders[strings.ToLower(zitadel_http.Authorization)]
	// ignore requests which were unauthorized or do not require an authorization (even if one was sent)
	// also ignore if the limit was already reached or if the server returned an internal error
	// note that endpoint paths are only checked in their gRPC representation, as HTTP (gateway) requests will not log them
	return hasHTTPAuthHeader &&
		(a.Protocol == HTTP &&
			a.ResponseStatus != http.StatusInternalServerError &&
			a.ResponseStatus != http.StatusTooManyRequests &&
			a.ResponseStatus != http.StatusUnauthorized) ||
		(a.Protocol == GRPC &&
			a.ResponseStatus != uint32(codes.Internal) &&
			a.ResponseStatus != uint32(codes.ResourceExhausted) &&
			a.ResponseStatus != uint32(codes.Unauthenticated) &&
			!a.isUnaccountableEndpoint())
}

func (a AccessLog) isUnaccountableEndpoint() bool {
	for _, endpoint := range unaccountableEndpoints {
		if strings.HasPrefix(a.RequestURL, endpoint) {
			return true
		}
	}
	return false
}

func (a AccessLog) Normalize() *AccessLog {
	a.RequestedDomain = cutString(a.RequestedDomain, 200)
	a.RequestURL = cutString(a.RequestURL, 200)
	a.RequestHeaders = normalizeHeaders(a.RequestHeaders, strings.ToLower(zitadel_http.Authorization), "grpcgateway-authorization", "cookie", "grpcgateway-cookie")
	a.ResponseHeaders = normalizeHeaders(a.ResponseHeaders, "set-cookie")
	a.normalized = true
	return &a
}

// normalizeHeaders lowers all header keys and redacts secrets
func normalizeHeaders(header map[string][]string, redactKeysLower ...string) map[string][]string {
	return pruneKeys(redactKeys(lowerKeys(header), redactKeysLower...))
}

func lowerKeys(header map[string][]string) map[string][]string {
	lower := make(map[string][]string, len(header))
	for k, v := range header {
		lower[strings.ToLower(k)] = v
	}
	return lower
}

func redactKeys(header map[string][]string, redactKeysLower ...string) map[string][]string {
	redactedKeys := make(map[string][]string, len(header))
	for k, v := range header {
		redactedKeys[k] = v
	}
	for _, redactKey := range redactKeysLower {
		if _, ok := redactedKeys[redactKey]; ok {
			redactedKeys[redactKey] = []string{redacted}
		}
	}
	return redactedKeys
}

const maxValuesPerKey = 10

func pruneKeys(header map[string][]string) map[string][]string {
	prunedKeys := make(map[string][]string, len(header))
	for key, value := range header {
		valueItems := make([]string, 0, maxValuesPerKey)
		for i, valueItem := range value {
			// Max 10 header values per key
			if i > maxValuesPerKey {
				break
			}
			// Max 200 value length
			valueItems = append(valueItems, cutString(valueItem, 200))
		}
		prunedKeys[key] = valueItems
	}
	return prunedKeys
}
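To make the lowerKeys → redactKeys → pruneKeys pipeline concrete, here is what normalization does to a typical request header map (a sketch with hypothetical values):

	in := map[string][]string{
		"Authorization": {"Bearer abc123"},
		"User-Agent":    {"curl/8.0"},
	}
	out := normalizeHeaders(in, "authorization")
	// out == map[string][]string{
	//     "authorization": {"[REDACTED]"},
	//     "user-agent":    {"curl/8.0"},
	// }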
@@ -1,20 +1,18 @@
package access_test
package record

import (
	"reflect"
	"testing"

	"github.com/zitadel/zitadel/internal/logstore/emitters/access"
)

func TestRecord_Normalize(t *testing.T) {
	tests := []struct {
		name   string
		record access.Record
		want   *access.Record
		record AccessLog
		want   *AccessLog
	}{{
		name: "headers with certain keys should be redacted",
		record: access.Record{
		record: AccessLog{
			RequestHeaders: map[string][]string{
				"authorization":             {"AValue"},
				"grpcgateway-authorization": {"AValue"},
@@ -24,7 +22,7 @@ func TestRecord_Normalize(t *testing.T) {
				"set-cookie": {"AValue"},
			},
		},
		want: &access.Record{
		want: &AccessLog{
			RequestHeaders: map[string][]string{
				"authorization":             {"[REDACTED]"},
				"grpcgateway-authorization": {"[REDACTED]"},
@@ -36,22 +34,22 @@ func TestRecord_Normalize(t *testing.T) {
		},
	}, {
		name: "header keys should be lower cased",
		record: access.Record{
		record: AccessLog{
			RequestHeaders:  map[string][]string{"AKey": {"AValue"}},
			ResponseHeaders: map[string][]string{"AKey": {"AValue"}}},
		want: &access.Record{
		want: &AccessLog{
			RequestHeaders:  map[string][]string{"akey": {"AValue"}},
			ResponseHeaders: map[string][]string{"akey": {"AValue"}}},
	}, {
		name: "an already pruned record should stay unchanged",
		record: access.Record{
		record: AccessLog{
			RequestURL: "https://my.zitadel.cloud/",
			RequestHeaders: map[string][]string{
				"authorization": {"[REDACTED]"},
			},
			ResponseHeaders: map[string][]string{},
		},
		want: &access.Record{
		want: &AccessLog{
			RequestURL: "https://my.zitadel.cloud/",
			RequestHeaders: map[string][]string{
				"authorization": {"[REDACTED]"},
@@ -60,17 +58,18 @@ func TestRecord_Normalize(t *testing.T) {
		},
	}, {
		name: "empty record should stay empty",
		record: access.Record{
		record: AccessLog{
			RequestHeaders:  map[string][]string{},
			ResponseHeaders: map[string][]string{},
		},
		want: &access.Record{
		want: &AccessLog{
			RequestHeaders:  map[string][]string{},
			ResponseHeaders: map[string][]string{},
		},
	}}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			tt.want.normalized = true
			if got := tt.record.Normalize(); !reflect.DeepEqual(got, tt.want) {
				t.Errorf("Normalize() = %v, want %v", got, tt.want)
			}
@@ -1,16 +1,12 @@
package execution
package record

import (
	"time"

	"github.com/sirupsen/logrus"

	"github.com/zitadel/zitadel/internal/logstore"
)

var _ logstore.LogRecord = (*Record)(nil)

type Record struct {
type ExecutionLog struct {
	LogDate time.Time     `json:"logDate"`
	Took    time.Duration `json:"took"`
	Message string        `json:"message"`
@@ -20,14 +16,7 @@ type Record struct {
	Metadata map[string]interface{} `json:"metadata,omitempty"`
}

func (e Record) Normalize() logstore.LogRecord {
func (e ExecutionLog) Normalize() *ExecutionLog {
	e.Message = cutString(e.Message, 2000)
	return &e
}

func cutString(str string, pos int) string {
	if len(str) <= pos {
		return str
	}
	return str[:pos]
}
internal/logstore/record/prune.go (new file, 8 lines)
@@ -0,0 +1,8 @@
package record

func cutString(str string, pos int) string {
	if len(str) <= pos {
		return str
	}
	return str[:pos-1]
}
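One detail worth seeing in isolation: for strings longer than pos, this consolidated helper slices to pos-1, keeping the access emitter's behavior, while the removed execution copy above sliced to pos. So the result is one character shorter than the limit:

	// cutString("abcdef", 3) == "ab" (pos-1 characters),
	// while strings of length <= pos pass through unchanged.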
@@ -2,124 +2,70 @@ package logstore

import (
	"context"
	"math"
	"time"

	"github.com/zitadel/logging"

	"github.com/zitadel/zitadel/internal/api/authz"
	"github.com/zitadel/zitadel/internal/repository/quota"
)

const handleThresholdTimeout = time.Minute

type QuotaQuerier interface {
	GetCurrentQuotaPeriod(ctx context.Context, instanceID string, unit quota.Unit) (config *quota.AddedEvent, periodStart time.Time, err error)
	GetDueQuotaNotifications(ctx context.Context, config *quota.AddedEvent, periodStart time.Time, used uint64) ([]*quota.NotificationDueEvent, error)
}

type UsageQuerier interface {
	LogEmitter
type UsageStorer[T LogRecord[T]] interface {
	LogEmitter[T]
	QuotaUnit() quota.Unit
	QueryUsage(ctx context.Context, instanceId string, start time.Time) (uint64, error)
}

type UsageReporter interface {
	Report(ctx context.Context, notifications []*quota.NotificationDueEvent) (err error)
}

type UsageReporterFunc func(context.Context, []*quota.NotificationDueEvent) (err error)

func (u UsageReporterFunc) Report(ctx context.Context, notifications []*quota.NotificationDueEvent) (err error) {
	return u(ctx, notifications)
}

type Service struct {
	usageQuerier  UsageQuerier
	quotaQuerier  QuotaQuerier
	usageReporter UsageReporter
	enabledSinks  []*emitter
type Service[T LogRecord[T]] struct {
	queries          Queries
	usageStorer      UsageStorer[T]
	enabledSinks     []*emitter[T]
	sinkEnabled      bool
	reportingEnabled bool
}

func New(quotaQuerier QuotaQuerier, usageReporter UsageReporter, usageQuerierSink *emitter, additionalSink ...*emitter) *Service {
	var usageQuerier UsageQuerier
type Queries interface {
	GetRemainingQuotaUsage(ctx context.Context, instanceID string, unit quota.Unit) (remaining *uint64, err error)
}

func New[T LogRecord[T]](queries Queries, usageQuerierSink *emitter[T], additionalSink ...*emitter[T]) *Service[T] {
	var usageStorer UsageStorer[T]
	if usageQuerierSink != nil {
		usageQuerier = usageQuerierSink.emitter.(UsageQuerier)
		usageStorer = usageQuerierSink.emitter.(UsageStorer[T])
	}

	svc := &Service{
	svc := &Service[T]{
		queries:          queries,
		reportingEnabled: usageQuerierSink != nil && usageQuerierSink.enabled,
		usageQuerier:  usageQuerier,
		quotaQuerier:  quotaQuerier,
		usageReporter: usageReporter,
		usageStorer:      usageStorer,
	}

	for _, s := range append([]*emitter{usageQuerierSink}, additionalSink...) {
	for _, s := range append([]*emitter[T]{usageQuerierSink}, additionalSink...) {
		if s != nil && s.enabled {
			svc.enabledSinks = append(svc.enabledSinks, s)
		}
	}

	svc.sinkEnabled = len(svc.enabledSinks) > 0

	return svc
}

func (s *Service) Enabled() bool {
func (s *Service[T]) Enabled() bool {
	return s.sinkEnabled
}

func (s *Service) Handle(ctx context.Context, record LogRecord) {
func (s *Service[T]) Handle(ctx context.Context, record T) {
	for _, sink := range s.enabledSinks {
		logging.OnError(sink.Emit(ctx, record.Normalize())).WithField("record", record).Warn("failed to emit log record")
	}
}

func (s *Service) Limit(ctx context.Context, instanceID string) *uint64 {
func (s *Service[T]) Limit(ctx context.Context, instanceID string) *uint64 {
	var err error
	defer func() {
		logging.OnError(err).Warn("failed to check is usage should be limited")
		logging.OnError(err).Warn("failed to check if usage should be limited")
	}()

	if !s.reportingEnabled || instanceID == "" {
		return nil
	}

	quota, periodStart, err := s.quotaQuerier.GetCurrentQuotaPeriod(ctx, instanceID, s.usageQuerier.QuotaUnit())
	if err != nil || quota == nil {
		return nil
	}

	usage, err := s.usageQuerier.QueryUsage(ctx, instanceID, periodStart)
	remaining, err := s.queries.GetRemainingQuotaUsage(ctx, instanceID, s.usageStorer.QuotaUnit())
	if err != nil {
		// TODO: shouldn't we just limit then or return the error and decide there?
		return nil
	}

	go s.handleThresholds(ctx, quota, periodStart, usage)

	var remaining *uint64
	if quota.Limit {
		r := uint64(math.Max(0, float64(quota.Amount)-float64(usage)))
		remaining = &r
	}
	return remaining
}

func (s *Service) handleThresholds(ctx context.Context, quota *quota.AddedEvent, periodStart time.Time, usage uint64) {
	var err error
	defer func() {
		logging.OnError(err).Warn("handling quota thresholds failed")
	}()

	detatchedCtx, cancel := context.WithTimeout(authz.Detach(ctx), handleThresholdTimeout)
	defer cancel()

	notifications, err := s.quotaQuerier.GetDueQuotaNotifications(detatchedCtx, quota, periodStart, usage)
	if err != nil || len(notifications) == 0 {
		return
	}

	err = s.usageReporter.Report(detatchedCtx, notifications)
}
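The remaining-quota arithmetic that used to live in Limit (and now sits behind GetRemainingQuotaUsage) boils down to a saturating subtraction:

	// With a limited quota of amount 90 and 60 units used, 30 remain;
	// once usage passes the amount, remaining clamps to 0.
	remaining := uint64(math.Max(0, float64(90)-float64(60))) // 30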
@@ -14,9 +14,8 @@ import (
	"github.com/benbjohnson/clock"

	"github.com/zitadel/zitadel/internal/logstore"
	emittermock "github.com/zitadel/zitadel/internal/logstore/emitters/mock"
	quotaqueriermock "github.com/zitadel/zitadel/internal/logstore/quotaqueriers/mock"
	"github.com/zitadel/zitadel/internal/repository/quota"
	emittermock "github.com/zitadel/zitadel/internal/logstore/mock"
	"github.com/zitadel/zitadel/internal/query"
)

const (
@@ -27,7 +26,7 @@ const (
type args struct {
	mainSink      *logstore.EmitterConfig
	secondarySink *logstore.EmitterConfig
	config        quota.AddedEvent
	config        *query.Quota
}

type want struct {
@@ -137,28 +136,6 @@ func TestService(t *testing.T) {
			len: 0,
		},
	},
}, {
	name: "cleanupping works",
	args: args{
		mainSink: emitterConfig(withCleanupping(17*time.Second, 28*time.Second)),
		secondarySink: emitterConfig(withDebouncerConfig(&logstore.DebouncerConfig{
			MinFrequency: 0,
			MaxBulkSize:  15,
		}), withCleanupping(5*time.Second, 47*time.Second)),
		config: quotaConfig(),
	},
	want: want{
		enabled:   true,
		remaining: nil,
		mainSink: wantSink{
			bulks: repeat(1, 60),
			len:   21,
		},
		secondarySink: wantSink{
			bulks: repeat(15, 4),
			len:   18,
		},
	},
}, {
	name: "when quota has a limit of 90, 30 are remaining",
	args: args{
@@ -232,27 +209,24 @@ func runTest(t *testing.T, name string, args args, want want) bool {
	})
}

func given(t *testing.T, args args, want want) (context.Context, *clock.Mock, *emittermock.InmemLogStorage, *emittermock.InmemLogStorage, *logstore.Service) {
func given(t *testing.T, args args, want want) (context.Context, *clock.Mock, *emittermock.InmemLogStorage, *emittermock.InmemLogStorage, *logstore.Service[*emittermock.Record]) {
	ctx := context.Background()
	clock := clock.NewMock()

	periodStart := time.Time{}
	clock.Set(args.config.From)

	mainStorage := emittermock.NewInMemoryStorage(clock)
	mainEmitter, err := logstore.NewEmitter(ctx, clock, args.mainSink, mainStorage)
	mainStorage := emittermock.NewInMemoryStorage(clock, args.config)
	mainEmitter, err := logstore.NewEmitter[*emittermock.Record](ctx, clock, args.mainSink, mainStorage)
	if err != nil {
		t.Errorf("expected no error but got %v", err)
	}
	secondaryStorage := emittermock.NewInMemoryStorage(clock)
	secondaryEmitter, err := logstore.NewEmitter(ctx, clock, args.secondarySink, secondaryStorage)
	secondaryStorage := emittermock.NewInMemoryStorage(clock, args.config)
	secondaryEmitter, err := logstore.NewEmitter[*emittermock.Record](ctx, clock, args.secondarySink, secondaryStorage)
	if err != nil {
		t.Errorf("expected no error but got %v", err)
	}

	svc := logstore.New(
		quotaqueriermock.NewNoopQuerier(&args.config, periodStart),
		logstore.UsageReporterFunc(func(context.Context, []*quota.NotificationDueEvent) error { return nil }),
	svc := logstore.New[*emittermock.Record](
		mainStorage,
		mainEmitter,
		secondaryEmitter)

@@ -262,7 +236,7 @@ func given(t *testing.T, args args, want want) (context.Context, *clock.Mock, *e
	return ctx, clock, mainStorage, secondaryStorage, svc
}

func when(svc *logstore.Service, ctx context.Context, clock *clock.Mock) *uint64 {
func when(svc *logstore.Service[*emittermock.Record], ctx context.Context, clock *clock.Mock) *uint64 {
	var remaining *uint64
	for i := 0; i < ticks; i++ {
		svc.Handle(ctx, emittermock.NewRecord(clock))
@@ -1,11 +1,11 @@
//go:build !integration

package start
package net

import (
	"net"
)

func listenConfig() *net.ListenConfig {
func ListenConfig() *net.ListenConfig {
	return &net.ListenConfig{}
}
@@ -1,6 +1,6 @@
//go:build integration

package start
package net

import (
	"net"
@@ -9,7 +9,7 @@ import (
	"golang.org/x/sys/unix"
)

func listenConfig() *net.ListenConfig {
func ListenConfig() *net.ListenConfig {
	return &net.ListenConfig{
		Control: reusePort,
	}
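The Control function reusePort is defined elsewhere in this package and is not part of this hunk; it presumably enables SO_REUSEPORT so parallel integration-test listeners can bind the same port. A hypothetical reconstruction under that assumption:

	// Hypothetical sketch: set SO_REUSEPORT on the raw socket before bind.
	func reusePort(network, address string, conn syscall.RawConn) error {
		var sockErr error
		err := conn.Control(func(fd uintptr) {
			sockErr = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_REUSEPORT, 1)
		})
		if err != nil {
			return err
		}
		return sockErr
	}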
@@ -20,11 +20,13 @@ var (
func TestMain(m *testing.M) {
	os.Exit(func() int {
		ctx, _, cancel := integration.Contexts(5 * time.Minute)
		CTX = ctx
		defer cancel()
		CTX = ctx

		Tester = integration.NewTester(ctx)
		SystemCTX = Tester.WithAuthorization(ctx, integration.SystemUser)
		defer Tester.Done()

		SystemCTX = Tester.WithAuthorization(ctx, integration.SystemUser)
		return m.Run()
	}())
}
@@ -5,11 +5,6 @@ package handlers_test
import (
	"bytes"
	"encoding/json"
	"io"
	"net"
	"net/http"
	"net/http/httptest"
	"reflect"
	"testing"
	"time"

@@ -18,49 +13,27 @@ import (
)

func TestServer_TelemetryPushMilestones(t *testing.T) {
	bodies := make(chan []byte, 0)
	mockServer := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		body, err := io.ReadAll(r.Body)
		if err != nil {
			t.Error(err)
		}
		if r.Header.Get("single-value") != "single-value" {
			t.Error("single-value header not set")
		}
		if reflect.DeepEqual(r.Header.Get("multi-value"), "multi-value-1,multi-value-2") {
			t.Error("single-value header not set")
		}
		bodies <- body
		w.WriteHeader(http.StatusOK)
	}))
	listener, err := net.Listen("tcp", "localhost:8081")
	if err != nil {
		t.Fatal(err)
	}
	mockServer.Listener = listener
	mockServer.Start()
	t.Cleanup(mockServer.Close)
	primaryDomain, instanceID, iamOwnerCtx := Tester.UseIsolatedInstance(CTX, SystemCTX)
	t.Log("testing against instance with primary domain", primaryDomain)
	awaitMilestone(t, bodies, primaryDomain, "InstanceCreated")
	awaitMilestone(t, Tester.MilestoneChan, primaryDomain, "InstanceCreated")
	project, err := Tester.Client.Mgmt.AddProject(iamOwnerCtx, &management.AddProjectRequest{Name: "integration"})
	if err != nil {
		t.Fatal(err)
	}
	awaitMilestone(t, bodies, primaryDomain, "ProjectCreated")
	awaitMilestone(t, Tester.MilestoneChan, primaryDomain, "ProjectCreated")
	if _, err = Tester.Client.Mgmt.AddOIDCApp(iamOwnerCtx, &management.AddOIDCAppRequest{
		ProjectId: project.GetId(),
		Name:      "integration",
	}); err != nil {
		t.Fatal(err)
	}
	awaitMilestone(t, bodies, primaryDomain, "ApplicationCreated")
	awaitMilestone(t, Tester.MilestoneChan, primaryDomain, "ApplicationCreated")
	// TODO: trigger and await milestone AuthenticationSucceededOnInstance
	// TODO: trigger and await milestone AuthenticationSucceededOnApplication
	if _, err = Tester.Client.System.RemoveInstance(SystemCTX, &system.RemoveInstanceRequest{InstanceId: instanceID}); err != nil {
		t.Fatal(err)
	}
	awaitMilestone(t, bodies, primaryDomain, "InstanceDeleted")
	awaitMilestone(t, Tester.MilestoneChan, primaryDomain, "InstanceDeleted")
}

func awaitMilestone(t *testing.T, bodies chan []byte, primaryDomain, expectMilestoneType string) {
@@ -13,6 +13,7 @@ import (
	key_repo "github.com/zitadel/zitadel/internal/repository/keypair"
	"github.com/zitadel/zitadel/internal/repository/org"
	proj_repo "github.com/zitadel/zitadel/internal/repository/project"
	quota_repo "github.com/zitadel/zitadel/internal/repository/quota"
	usr_repo "github.com/zitadel/zitadel/internal/repository/user"
	"github.com/zitadel/zitadel/internal/repository/usergrant"
)
@@ -29,6 +30,7 @@ func eventstoreExpect(t *testing.T, expects ...expect) *eventstore.Eventstore {
	org.RegisterEventMappers(es)
	usr_repo.RegisterEventMappers(es)
	proj_repo.RegisterEventMappers(es)
	quota_repo.RegisterEventMappers(es)
	usergrant.RegisterEventMappers(es)
	key_repo.RegisterEventMappers(es)
	action_repo.RegisterEventMappers(es)
@@ -69,6 +69,7 @@ var (
	SessionProjection     *sessionProjection
	AuthRequestProjection *authRequestProjection
	MilestoneProjection   *milestoneProjection
	QuotaProjection       *quotaProjection
)

type projection interface {
@@ -148,6 +149,7 @@ func Create(ctx context.Context, sqlClient *database.DB, es *eventstore.Eventsto
	SessionProjection = newSessionProjection(ctx, applyCustomConfig(projectionConfig, config.Customizations["sessions"]))
	AuthRequestProjection = newAuthRequestProjection(ctx, applyCustomConfig(projectionConfig, config.Customizations["auth_requests"]))
	MilestoneProjection = newMilestoneProjection(ctx, applyCustomConfig(projectionConfig, config.Customizations["milestones"]))
	QuotaProjection = newQuotaProjection(ctx, applyCustomConfig(projectionConfig, config.Customizations["quotas"]))
	newProjectionsList()
	return nil
}
@@ -247,5 +249,6 @@ func newProjectionsList() {
		SessionProjection,
		AuthRequestProjection,
		MilestoneProjection,
		QuotaProjection,
	}
}
internal/query/projection/quota.go (new file, 285 lines)
@@ -0,0 +1,285 @@
package projection

import (
	"context"
	"time"

	"github.com/zitadel/zitadel/internal/database"
	"github.com/zitadel/zitadel/internal/errors"
	"github.com/zitadel/zitadel/internal/eventstore"
	"github.com/zitadel/zitadel/internal/eventstore/handler"
	"github.com/zitadel/zitadel/internal/eventstore/handler/crdb"
	"github.com/zitadel/zitadel/internal/repository/instance"
	"github.com/zitadel/zitadel/internal/repository/quota"
)

const (
	QuotasProjectionTable       = "projections.quotas"
	QuotaPeriodsProjectionTable = QuotasProjectionTable + "_" + quotaPeriodsTableSuffix
	QuotaNotificationsTable     = QuotasProjectionTable + "_" + quotaNotificationsTableSuffix

	QuotaColumnID         = "id"
	QuotaColumnInstanceID = "instance_id"
	QuotaColumnUnit       = "unit"
	QuotaColumnAmount     = "amount"
	QuotaColumnFrom       = "from_anchor"
	QuotaColumnInterval   = "interval"
	QuotaColumnLimit      = "limit_usage"

	quotaPeriodsTableSuffix     = "periods"
	QuotaPeriodColumnInstanceID = "instance_id"
	QuotaPeriodColumnUnit       = "unit"
	QuotaPeriodColumnStart      = "start"
	QuotaPeriodColumnUsage      = "usage"

	quotaNotificationsTableSuffix               = "notifications"
	QuotaNotificationColumnInstanceID           = "instance_id"
	QuotaNotificationColumnUnit                 = "unit"
	QuotaNotificationColumnID                   = "id"
	QuotaNotificationColumnCallURL              = "call_url"
	QuotaNotificationColumnPercent              = "percent"
	QuotaNotificationColumnRepeat               = "repeat"
	QuotaNotificationColumnLatestDuePeriodStart = "latest_due_period_start"
	QuotaNotificationColumnNextDueThreshold     = "next_due_threshold"
)

const (
	incrementQuotaStatement = `INSERT INTO projections.quotas_periods` +
		` (instance_id, unit, start, usage)` +
		` VALUES ($1, $2, $3, $4) ON CONFLICT (instance_id, unit, start)` +
		` DO UPDATE SET usage = projections.quotas_periods.usage + excluded.usage RETURNING usage`
)

type quotaProjection struct {
	crdb.StatementHandler
	client *database.DB
}

func newQuotaProjection(ctx context.Context, config crdb.StatementHandlerConfig) *quotaProjection {
	p := new(quotaProjection)
	config.ProjectionName = QuotasProjectionTable
	config.Reducers = p.reducers()
	config.InitCheck = crdb.NewMultiTableCheck(
		crdb.NewTable(
			[]*crdb.Column{
				crdb.NewColumn(QuotaColumnID, crdb.ColumnTypeText),
				crdb.NewColumn(QuotaColumnInstanceID, crdb.ColumnTypeText),
				crdb.NewColumn(QuotaColumnUnit, crdb.ColumnTypeEnum),
				crdb.NewColumn(QuotaColumnAmount, crdb.ColumnTypeInt64),
				crdb.NewColumn(QuotaColumnFrom, crdb.ColumnTypeTimestamp),
				crdb.NewColumn(QuotaColumnInterval, crdb.ColumnTypeInterval),
				crdb.NewColumn(QuotaColumnLimit, crdb.ColumnTypeBool),
			},
			crdb.NewPrimaryKey(QuotaColumnInstanceID, QuotaColumnUnit),
		),
		crdb.NewSuffixedTable(
			[]*crdb.Column{
				crdb.NewColumn(QuotaPeriodColumnInstanceID, crdb.ColumnTypeText),
				crdb.NewColumn(QuotaPeriodColumnUnit, crdb.ColumnTypeEnum),
				crdb.NewColumn(QuotaPeriodColumnStart, crdb.ColumnTypeTimestamp),
				crdb.NewColumn(QuotaPeriodColumnUsage, crdb.ColumnTypeInt64),
			},
			crdb.NewPrimaryKey(QuotaPeriodColumnInstanceID, QuotaPeriodColumnUnit, QuotaPeriodColumnStart),
			quotaPeriodsTableSuffix,
		),
		crdb.NewSuffixedTable(
			[]*crdb.Column{
				crdb.NewColumn(QuotaNotificationColumnInstanceID, crdb.ColumnTypeText),
				crdb.NewColumn(QuotaNotificationColumnUnit, crdb.ColumnTypeEnum),
				crdb.NewColumn(QuotaNotificationColumnID, crdb.ColumnTypeText),
				crdb.NewColumn(QuotaNotificationColumnCallURL, crdb.ColumnTypeText),
				crdb.NewColumn(QuotaNotificationColumnPercent, crdb.ColumnTypeInt64),
				crdb.NewColumn(QuotaNotificationColumnRepeat, crdb.ColumnTypeBool),
				crdb.NewColumn(QuotaNotificationColumnLatestDuePeriodStart, crdb.ColumnTypeTimestamp, crdb.Nullable()),
				crdb.NewColumn(QuotaNotificationColumnNextDueThreshold, crdb.ColumnTypeInt64, crdb.Nullable()),
			},
			crdb.NewPrimaryKey(QuotaNotificationColumnInstanceID, QuotaNotificationColumnUnit, QuotaNotificationColumnID),
			quotaNotificationsTableSuffix,
		),
	)
	p.StatementHandler = crdb.NewStatementHandler(ctx, config)
	p.client = config.Client
	return p
}

func (q *quotaProjection) reducers() []handler.AggregateReducer {
	return []handler.AggregateReducer{
		{
			Aggregate: instance.AggregateType,
			EventRedusers: []handler.EventReducer{
				{
					Event:  instance.InstanceRemovedEventType,
					Reduce: q.reduceInstanceRemoved,
				},
			},
		},
		{
			Aggregate: quota.AggregateType,
			EventRedusers: []handler.EventReducer{
				{
					Event:  quota.AddedEventType,
					Reduce: q.reduceQuotaAdded,
				},
			},
		},
		{
			Aggregate: quota.AggregateType,
			EventRedusers: []handler.EventReducer{
				{
					Event:  quota.RemovedEventType,
					Reduce: q.reduceQuotaRemoved,
				},
			},
		},
		{
			Aggregate: quota.AggregateType,
			EventRedusers: []handler.EventReducer{
				{
					Event:  quota.NotificationDueEventType,
					Reduce: q.reduceQuotaNotificationDue,
				},
			},
		},
		{
			Aggregate: quota.AggregateType,
			EventRedusers: []handler.EventReducer{
				{
					Event:  quota.NotifiedEventType,
					Reduce: q.reduceQuotaNotified,
				},
			},
		},
	}
}

func (q *quotaProjection) reduceQuotaNotified(event eventstore.Event) (*handler.Statement, error) {
	return crdb.NewNoOpStatement(event), nil
}

func (q *quotaProjection) reduceQuotaAdded(event eventstore.Event) (*handler.Statement, error) {
	e, err := assertEvent[*quota.AddedEvent](event)
	if err != nil {
		return nil, err
	}

	createStatements := make([]func(e eventstore.Event) crdb.Exec, len(e.Notifications)+1)
	createStatements[0] = crdb.AddCreateStatement(
		[]handler.Column{
			handler.NewCol(QuotaColumnID, e.Aggregate().ID),
			handler.NewCol(QuotaColumnInstanceID, e.Aggregate().InstanceID),
			handler.NewCol(QuotaColumnUnit, e.Unit),
			handler.NewCol(QuotaColumnAmount, e.Amount),
			handler.NewCol(QuotaColumnFrom, e.From),
			handler.NewCol(QuotaColumnInterval, e.ResetInterval),
			handler.NewCol(QuotaColumnLimit, e.Limit),
		})
	for i := range e.Notifications {
		notification := e.Notifications[i]
		createStatements[i+1] = crdb.AddCreateStatement(
			[]handler.Column{
				handler.NewCol(QuotaNotificationColumnInstanceID, e.Aggregate().InstanceID),
				handler.NewCol(QuotaNotificationColumnUnit, e.Unit),
				handler.NewCol(QuotaNotificationColumnID, notification.ID),
				handler.NewCol(QuotaNotificationColumnCallURL, notification.CallURL),
				handler.NewCol(QuotaNotificationColumnPercent, notification.Percent),
				handler.NewCol(QuotaNotificationColumnRepeat, notification.Repeat),
			},
			crdb.WithTableSuffix(quotaNotificationsTableSuffix),
		)
	}

	return crdb.NewMultiStatement(e, createStatements...), nil
}

func (q *quotaProjection) reduceQuotaNotificationDue(event eventstore.Event) (*handler.Statement, error) {
	e, err := assertEvent[*quota.NotificationDueEvent](event)
	if err != nil {
		return nil, err
	}
	return crdb.NewUpdateStatement(e,
		[]handler.Column{
			handler.NewCol(QuotaNotificationColumnLatestDuePeriodStart, e.PeriodStart),
			handler.NewCol(QuotaNotificationColumnNextDueThreshold, e.Threshold+100), // next due_threshold is always the reached + 100 => percent (e.g. 90) in the next bucket (e.g. 190)
		},
		[]handler.Condition{
			handler.NewCond(QuotaNotificationColumnInstanceID, e.Aggregate().InstanceID),
			handler.NewCond(QuotaNotificationColumnUnit, e.Unit),
			handler.NewCond(QuotaNotificationColumnID, e.ID),
		},
		crdb.WithTableSuffix(quotaNotificationsTableSuffix),
	), nil
}
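The +100 bucket arithmetic deserves a worked example:

	// Repeating notification configured at percent 90:
	// usage reaches  90% -> due at threshold  90, next_due_threshold becomes 190
	// usage reaches 190% -> due at threshold 190, next_due_threshold becomes 290
	// so a repeating notification fires once per 100-percent bucket: 90, 190, 290, ...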
func (q *quotaProjection) reduceQuotaRemoved(event eventstore.Event) (*handler.Statement, error) {
	e, err := assertEvent[*quota.RemovedEvent](event)
	if err != nil {
		return nil, err
	}
	return crdb.NewMultiStatement(
		e,
		crdb.AddDeleteStatement(
			[]handler.Condition{
				handler.NewCond(QuotaPeriodColumnInstanceID, e.Aggregate().InstanceID),
				handler.NewCond(QuotaPeriodColumnUnit, e.Unit),
			},
			crdb.WithTableSuffix(quotaPeriodsTableSuffix),
		),
		crdb.AddDeleteStatement(
			[]handler.Condition{
				handler.NewCond(QuotaNotificationColumnInstanceID, e.Aggregate().InstanceID),
				handler.NewCond(QuotaNotificationColumnUnit, e.Unit),
			},
			crdb.WithTableSuffix(quotaNotificationsTableSuffix),
		),
		crdb.AddDeleteStatement(
			[]handler.Condition{
				handler.NewCond(QuotaColumnInstanceID, e.Aggregate().InstanceID),
				handler.NewCond(QuotaColumnUnit, e.Unit),
			},
		),
	), nil
}

func (q *quotaProjection) reduceInstanceRemoved(event eventstore.Event) (*handler.Statement, error) {
	// we only assert the event to make sure it is the correct type
	e, err := assertEvent[*instance.InstanceRemovedEvent](event)
	if err != nil {
		return nil, err
	}
	return crdb.NewMultiStatement(
		e,
		crdb.AddDeleteStatement(
			[]handler.Condition{
				handler.NewCond(QuotaPeriodColumnInstanceID, e.Aggregate().InstanceID),
			},
			crdb.WithTableSuffix(quotaPeriodsTableSuffix),
		),
		crdb.AddDeleteStatement(
			[]handler.Condition{
				handler.NewCond(QuotaNotificationColumnInstanceID, e.Aggregate().InstanceID),
			},
			crdb.WithTableSuffix(quotaNotificationsTableSuffix),
		),
		crdb.AddDeleteStatement(
			[]handler.Condition{
				handler.NewCond(QuotaColumnInstanceID, e.Aggregate().InstanceID),
			},
		),
	), nil
}

func (q *quotaProjection) IncrementUsage(ctx context.Context, unit quota.Unit, instanceID string, periodStart time.Time, count uint64) (sum uint64, err error) {
	if count == 0 {
		return 0, nil
	}

	err = q.client.DB.QueryRowContext(
		ctx,
		incrementQuotaStatement,
		instanceID, unit, periodStart, count,
	).Scan(&sum)
	if err != nil {
		return 0, errors.ThrowInternalf(err, "PROJ-SJL3h", "incrementing usage for unit %d failed for at least one quota period", unit)
	}
	return sum, err
}
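The RETURNING clause makes the upsert double as the read: each call atomically adds to the period row and hands back the new total, so callers never need a separate SELECT. Roughly, with hypothetical values:

	// First increment of a fresh period inserts the row:
	//   IncrementUsage(ctx, unit, "inst-1", start, 4) -> sum == 4
	// A later (or concurrent) increment updates the same row:
	//   IncrementUsage(ctx, unit, "inst-1", start, 2) -> sum == 6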
internal/query/projection/quota_test.go (new file, 321 lines)
@@ -0,0 +1,321 @@
package projection

import (
	"context"
	"regexp"
	"testing"
	"time"

	"github.com/DATA-DOG/go-sqlmock"
	"github.com/stretchr/testify/assert"

	"github.com/zitadel/zitadel/internal/database"
	"github.com/zitadel/zitadel/internal/errors"
	"github.com/zitadel/zitadel/internal/eventstore"
	"github.com/zitadel/zitadel/internal/eventstore/handler"
	"github.com/zitadel/zitadel/internal/eventstore/repository"
	"github.com/zitadel/zitadel/internal/repository/instance"
	"github.com/zitadel/zitadel/internal/repository/quota"
)

func TestQuotasProjection_reduces(t *testing.T) {
	type args struct {
		event func(t *testing.T) eventstore.Event
	}
	tests := []struct {
		name   string
		args   args
		reduce func(event eventstore.Event) (*handler.Statement, error)
		want   wantReduce
	}{
		{
			name: "reduceQuotaAdded",
			args: args{
				event: getEvent(testEvent(
					repository.EventType(quota.AddedEventType),
					quota.AggregateType,
					[]byte(`{
						"unit": 1,
						"amount": 10,
						"limit": true,
						"from": "2023-01-01T00:00:00Z",
						"interval": 300000000000
					}`),
				), quota.AddedEventMapper),
			},
			reduce: (&quotaProjection{}).reduceQuotaAdded,
			want: wantReduce{
				aggregateType:    eventstore.AggregateType("quota"),
				sequence:         15,
				previousSequence: 10,
				executer: &testExecuter{
					executions: []execution{
						{
							expectedStmt: "INSERT INTO projections.quotas (id, instance_id, unit, amount, from_anchor, interval, limit_usage) VALUES ($1, $2, $3, $4, $5, $6, $7)",
							expectedArgs: []interface{}{
								"agg-id",
								"instance-id",
								quota.RequestsAllAuthenticated,
								uint64(10),
								time.Date(2023, 1, 1, 0, 0, 0, 0, time.UTC),
								time.Minute * 5,
								true,
							},
						},
					},
				},
			},
		},
		{
			name: "reduceQuotaAdded with notification",
			args: args{
				event: getEvent(testEvent(
					repository.EventType(quota.AddedEventType),
					quota.AggregateType,
					[]byte(`{
						"unit": 1,
						"amount": 10,
						"limit": true,
						"from": "2023-01-01T00:00:00Z",
						"interval": 300000000000,
						"notifications": [
							{
								"id": "id",
								"percent": 100,
								"repeat": true,
								"callURL": "url"
							}
						]
					}`),
				), quota.AddedEventMapper),
			},
			reduce: (&quotaProjection{}).reduceQuotaAdded,
			want: wantReduce{
				aggregateType:    eventstore.AggregateType("quota"),
				sequence:         15,
				previousSequence: 10,
				executer: &testExecuter{
					executions: []execution{
						{
							expectedStmt: "INSERT INTO projections.quotas (id, instance_id, unit, amount, from_anchor, interval, limit_usage) VALUES ($1, $2, $3, $4, $5, $6, $7)",
							expectedArgs: []interface{}{
								"agg-id",
								"instance-id",
								quota.RequestsAllAuthenticated,
								uint64(10),
								time.Date(2023, 1, 1, 0, 0, 0, 0, time.UTC),
								time.Minute * 5,
								true,
							},
						},
						{
							expectedStmt: "INSERT INTO projections.quotas_notifications (instance_id, unit, id, call_url, percent, repeat) VALUES ($1, $2, $3, $4, $5, $6)",
							expectedArgs: []interface{}{
								"instance-id",
								quota.RequestsAllAuthenticated,
								"id",
								"url",
								uint16(100),
								true,
							},
						},
					},
				},
			},
		},
		{
			name: "reduceQuotaNotificationDue",
			args: args{
				event: getEvent(testEvent(
					repository.EventType(quota.NotificationDueEventType),
					quota.AggregateType,
					[]byte(`{
						"id": "id",
						"unit": 1,
						"callURL": "url",
						"periodStart": "2023-01-01T00:00:00Z",
						"threshold": 200,
						"usage": 100
					}`),
				), quota.NotificationDueEventMapper),
			},
			reduce: (&quotaProjection{}).reduceQuotaNotificationDue,
			want: wantReduce{
				aggregateType:    eventstore.AggregateType("quota"),
				sequence:         15,
				previousSequence: 10,
				executer: &testExecuter{
					executions: []execution{
						{
							expectedStmt: "UPDATE projections.quotas_notifications SET (latest_due_period_start, next_due_threshold) = ($1, $2) WHERE (instance_id = $3) AND (unit = $4) AND (id = $5)",
							expectedArgs: []interface{}{
								time.Date(2023, 1, 1, 0, 0, 0, 0, time.UTC),
								uint16(300),
								"instance-id",
								quota.RequestsAllAuthenticated,
								"id",
							},
						},
					},
				},
			},
		},
		{
			name: "reduceQuotaRemoved",
			args: args{
				event: getEvent(testEvent(
					repository.EventType(quota.RemovedEventType),
					quota.AggregateType,
					[]byte(`{
						"unit": 1
					}`),
				), quota.RemovedEventMapper),
			},
			reduce: (&quotaProjection{}).reduceQuotaRemoved,
			want: wantReduce{
				aggregateType:    eventstore.AggregateType("quota"),
				sequence:         15,
				previousSequence: 10,
				executer: &testExecuter{
					executions: []execution{
						{
							expectedStmt: "DELETE FROM projections.quotas_periods WHERE (instance_id = $1) AND (unit = $2)",
							expectedArgs: []interface{}{
								"instance-id",
								quota.RequestsAllAuthenticated,
							},
						},
						{
							expectedStmt: "DELETE FROM projections.quotas_notifications WHERE (instance_id = $1) AND (unit = $2)",
							expectedArgs: []interface{}{
								"instance-id",
								quota.RequestsAllAuthenticated,
							},
						},
						{
							expectedStmt: "DELETE FROM projections.quotas WHERE (instance_id = $1) AND (unit = $2)",
							expectedArgs: []interface{}{
								"instance-id",
								quota.RequestsAllAuthenticated,
							},
						},
					},
				},
			},
		}, {
			name: "reduceInstanceRemoved",
			args: args{
				event: getEvent(testEvent(
					repository.EventType(instance.InstanceRemovedEventType),
					instance.AggregateType,
					[]byte(`{
						"name": "name"
					}`),
				), instance.InstanceRemovedEventMapper),
			},
			reduce: (&quotaProjection{}).reduceInstanceRemoved,
			want: wantReduce{
				aggregateType:    eventstore.AggregateType("instance"),
				sequence:         15,
				previousSequence: 10,
				executer: &testExecuter{
					executions: []execution{
						{
							expectedStmt: "DELETE FROM projections.quotas_periods WHERE (instance_id = $1)",
							expectedArgs: []interface{}{
								"instance-id",
							},
						},
						{
							expectedStmt: "DELETE FROM projections.quotas_notifications WHERE (instance_id = $1)",
							expectedArgs: []interface{}{
								"instance-id",
							},
						},
						{
							expectedStmt: "DELETE FROM projections.quotas WHERE (instance_id = $1)",
							expectedArgs: []interface{}{
								"instance-id",
							},
						},
					},
				},
			},
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			event := baseEvent(t)
			got, err := tt.reduce(event)
			if !errors.IsErrorInvalidArgument(err) {
				t.Errorf("no wrong event mapping: %v, got: %v", err, got)
			}
			event = tt.args.event(t)
			got, err = tt.reduce(event)
			assertReduce(t, got, err, QuotasProjectionTable, tt.want)
		})
	}
}

func Test_quotaProjection_IncrementUsage(t *testing.T) {
	testNow := time.Now()
	type fields struct {
		client *database.DB
	}
	type args struct {
		ctx         context.Context
		unit        quota.Unit
		instanceID  string
		periodStart time.Time
		count       uint64
	}
	type res struct {
		sum uint64
		err error
	}
	tests := []struct {
		name   string
		fields fields
		args   args
		res    res
	}{
		{
			name: "",
			fields: fields{
				client: func() *database.DB {
					db, mock, _ := sqlmock.New()
					mock.ExpectQuery(regexp.QuoteMeta(incrementQuotaStatement)).
						WithArgs(
							"instance_id",
							1,
							testNow,
							2,
						).
						WillReturnRows(sqlmock.NewRows([]string{"key"}).
							AddRow(3))
					return &database.DB{DB: db}
				}(),
			},
			args: args{
				ctx:         context.Background(),
				unit:        quota.RequestsAllAuthenticated,
				instanceID:  "instance_id",
				periodStart: testNow,
				count:       2,
			},
			res: res{
				sum: 3,
			},
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			q := &quotaProjection{
				client: tt.fields.client,
			}
			gotSum, err := q.IncrementUsage(tt.args.ctx, tt.args.unit, tt.args.instanceID, tt.args.periodStart, tt.args.count)
			assert.Equal(t, tt.res.sum, gotSum)
			assert.ErrorIs(t, err, tt.res.err)
		})
	}
}
|
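The incrementQuotaStatement under test lives in the projection package and is not shown in this diff; the test only pins its arguments (instance, unit, period start, count) and the returned running sum (3, i.e. 1 already stored plus the 2 added). Purely as a hypothetical sketch of such an upsert-and-return statement, reusing the projections.quotas_periods table and columns that appear elsewhere in this diff:

// Sketch only - the real incrementQuotaStatement is not part of this hunk.
// It would atomically add to the period counter and return the new total:
const incrementQuotaStatementSketch = `INSERT INTO projections.quotas_periods (instance_id, unit, start, usage)
VALUES ($1, $2, $3, $4)
ON CONFLICT (instance_id, unit, start)
DO UPDATE SET usage = projections.quotas_periods.usage + excluded.usage
RETURNING usage`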
@ -26,6 +26,7 @@ import (
	"github.com/zitadel/zitadel/internal/repository/oidcsession"
	"github.com/zitadel/zitadel/internal/repository/org"
	"github.com/zitadel/zitadel/internal/repository/project"
	"github.com/zitadel/zitadel/internal/repository/quota"
	"github.com/zitadel/zitadel/internal/repository/session"
	usr_repo "github.com/zitadel/zitadel/internal/repository/user"
	"github.com/zitadel/zitadel/internal/repository/usergrant"
@ -93,6 +94,7 @@ func StartQueries(
	idpintent.RegisterEventMappers(repo.eventstore)
	authrequest.RegisterEventMappers(repo.eventstore)
	oidcsession.RegisterEventMappers(repo.eventstore)
	quota.RegisterEventMappers(repo.eventstore)

	repo.idpConfigEncryption = idpConfigEncryption
	repo.multifactors = domain.MultifactorConfigs{
internal/query/quota.go (new file, 121 lines)
@ -0,0 +1,121 @@
package query

import (
	"context"
	"database/sql"
	errs "errors"
	"time"

	sq "github.com/Masterminds/squirrel"

	"github.com/zitadel/zitadel/internal/database"
	"github.com/zitadel/zitadel/internal/errors"
	"github.com/zitadel/zitadel/internal/query/projection"
	"github.com/zitadel/zitadel/internal/repository/quota"
	"github.com/zitadel/zitadel/internal/telemetry/tracing"
)

var (
	quotasTable = table{
		name:          projection.QuotasProjectionTable,
		instanceIDCol: projection.QuotaColumnInstanceID,
	}
	QuotaColumnID = Column{
		name:  projection.QuotaColumnID,
		table: quotasTable,
	}
	QuotaColumnInstanceID = Column{
		name:  projection.QuotaColumnInstanceID,
		table: quotasTable,
	}
	QuotaColumnUnit = Column{
		name:  projection.QuotaColumnUnit,
		table: quotasTable,
	}
	QuotaColumnAmount = Column{
		name:  projection.QuotaColumnAmount,
		table: quotasTable,
	}
	QuotaColumnLimit = Column{
		name:  projection.QuotaColumnLimit,
		table: quotasTable,
	}
	QuotaColumnInterval = Column{
		name:  projection.QuotaColumnInterval,
		table: quotasTable,
	}
	QuotaColumnFrom = Column{
		name:  projection.QuotaColumnFrom,
		table: quotasTable,
	}
)

type Quota struct {
	ID                 string
	From               time.Time
	ResetInterval      time.Duration
	Amount             uint64
	Limit              bool
	CurrentPeriodStart time.Time
}

func (q *Queries) GetQuota(ctx context.Context, instanceID string, unit quota.Unit) (qu *Quota, err error) {
	ctx, span := tracing.NewSpan(ctx)
	defer func() { span.EndWithError(err) }()
	query, scan := prepareQuotaQuery(ctx, q.client)
	stmt, args, err := query.Where(
		sq.Eq{
			QuotaColumnInstanceID.identifier(): instanceID,
			QuotaColumnUnit.identifier():       unit,
		},
	).ToSql()
	if err != nil {
		return nil, errors.ThrowInternal(err, "QUERY-XmYn9", "Errors.Query.SQLStatement")
	}
	err = q.client.QueryRowContext(ctx, func(row *sql.Row) error {
		qu, err = scan(row)
		return err
	}, stmt, args...)
	return qu, err
}

func prepareQuotaQuery(ctx context.Context, db prepareDatabase) (sq.SelectBuilder, func(*sql.Row) (*Quota, error)) {
	return sq.
			Select(
				QuotaColumnID.identifier(),
				QuotaColumnFrom.identifier(),
				QuotaColumnInterval.identifier(),
				QuotaColumnAmount.identifier(),
				QuotaColumnLimit.identifier(),
				"now()",
			).
			From(quotasTable.identifier()).
			PlaceholderFormat(sq.Dollar), func(row *sql.Row) (*Quota, error) {
			q := new(Quota)
			var interval database.Duration
			var now time.Time
			err := row.Scan(&q.ID, &q.From, &interval, &q.Amount, &q.Limit, &now)
			if err != nil {
				if errs.Is(err, sql.ErrNoRows) {
					return nil, errors.ThrowNotFound(err, "QUERY-rDTM6", "Errors.Quota.NotExisting")
				}
				return nil, errors.ThrowInternal(err, "QUERY-LqySK", "Errors.Internal")
			}
			q.ResetInterval = time.Duration(interval)
			q.CurrentPeriodStart = pushPeriodStart(q.From, q.ResetInterval, now)
			return q, nil
		}
}

func pushPeriodStart(from time.Time, interval time.Duration, now time.Time) time.Time {
	if now.IsZero() {
		now = time.Now()
	}
	for {
		next := from.Add(interval)
		if next.After(now) {
			return from
		}
		from = next
	}
}
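For illustration, a minimal standalone program showing how pushPeriodStart anchors the current quota period (a trimmed copy of the helper above, without the IsZero fallback, so the example runs on its own):

package main

import (
	"fmt"
	"time"
)

// Trimmed copy of the pushPeriodStart helper above, for a self-contained example.
func pushPeriodStart(from time.Time, interval time.Duration, now time.Time) time.Time {
	for {
		next := from.Add(interval)
		if next.After(now) {
			return from
		}
		from = next
	}
}

func main() {
	from := time.Date(2023, 1, 1, 0, 0, 0, 0, time.UTC) // quota anchor
	now := time.Date(2023, 1, 3, 15, 4, 0, 0, time.UTC)
	// With a 24h reset interval, the current period started at Jan 3, 00:00.
	fmt.Println(pushPeriodStart(from, 24*time.Hour, now)) // 2023-01-03 00:00:00 +0000 UTC
}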
@ -1,55 +0,0 @@
package query

import (
	"github.com/zitadel/zitadel/internal/eventstore"
	"github.com/zitadel/zitadel/internal/repository/quota"
)

type quotaReadModel struct {
	eventstore.ReadModel
	unit   quota.Unit
	active bool
	config *quota.AddedEvent
}

// newQuotaReadModel aggregateId is filled by reducing unit matching events
func newQuotaReadModel(instanceId, resourceOwner string, unit quota.Unit) *quotaReadModel {
	return &quotaReadModel{
		ReadModel: eventstore.ReadModel{
			InstanceID:    instanceId,
			ResourceOwner: resourceOwner,
		},
		unit: unit,
	}
}

func (rm *quotaReadModel) Query() *eventstore.SearchQueryBuilder {
	query := eventstore.NewSearchQueryBuilder(eventstore.ColumnsEvent).
		ResourceOwner(rm.ResourceOwner).
		AllowTimeTravel().
		AddQuery().
		InstanceID(rm.InstanceID).
		AggregateTypes(quota.AggregateType).
		EventTypes(
			quota.AddedEventType,
			quota.RemovedEventType,
		).EventData(map[string]interface{}{"unit": rm.unit})

	return query.Builder()
}

func (rm *quotaReadModel) Reduce() error {
	for _, event := range rm.Events {
		switch e := event.(type) {
		case *quota.AddedEvent:
			rm.AggregateID = e.Aggregate().ID
			rm.active = true
			rm.config = e
		case *quota.RemovedEvent:
			rm.AggregateID = e.Aggregate().ID
			rm.active = false
			rm.config = nil
		}
	}
	return rm.ReadModel.Reduce()
}
@ -2,58 +2,180 @@ package query

import (
	"context"
	"database/sql"
	errs "errors"
	"math"
	"time"

	"github.com/zitadel/zitadel/internal/eventstore"
	sq "github.com/Masterminds/squirrel"
	"github.com/pkg/errors"

	"github.com/zitadel/zitadel/internal/api/call"
	zitadel_errors "github.com/zitadel/zitadel/internal/errors"
	"github.com/zitadel/zitadel/internal/query/projection"
	"github.com/zitadel/zitadel/internal/repository/quota"
	"github.com/zitadel/zitadel/internal/telemetry/tracing"
)

func (q *Queries) GetDueQuotaNotifications(ctx context.Context, config *quota.AddedEvent, periodStart time.Time, usedAbs uint64) ([]*quota.NotificationDueEvent, error) {
	if len(config.Notifications) == 0 {
var (
	quotaNotificationsTable = table{
		name:          projection.QuotaNotificationsTable,
		instanceIDCol: projection.QuotaNotificationColumnInstanceID,
	}
	QuotaNotificationColumnInstanceID = Column{
		name:  projection.QuotaNotificationColumnInstanceID,
		table: quotaNotificationsTable,
	}
	QuotaNotificationColumnUnit = Column{
		name:  projection.QuotaNotificationColumnUnit,
		table: quotaNotificationsTable,
	}
	QuotaNotificationColumnID = Column{
		name:  projection.QuotaNotificationColumnID,
		table: quotaNotificationsTable,
	}
	QuotaNotificationColumnCallURL = Column{
		name:  projection.QuotaNotificationColumnCallURL,
		table: quotaNotificationsTable,
	}
	QuotaNotificationColumnPercent = Column{
		name:  projection.QuotaNotificationColumnPercent,
		table: quotaNotificationsTable,
	}
	QuotaNotificationColumnRepeat = Column{
		name:  projection.QuotaNotificationColumnRepeat,
		table: quotaNotificationsTable,
	}
	QuotaNotificationColumnLatestDuePeriodStart = Column{
		name:  projection.QuotaNotificationColumnLatestDuePeriodStart,
		table: quotaNotificationsTable,
	}
	QuotaNotificationColumnNextDueThreshold = Column{
		name:  projection.QuotaNotificationColumnNextDueThreshold,
		table: quotaNotificationsTable,
	}
)

func (q *Queries) GetDueQuotaNotifications(ctx context.Context, instanceID string, unit quota.Unit, qu *Quota, periodStart time.Time, usedAbs uint64) (dueNotifications []*quota.NotificationDueEvent, err error) {
	ctx, span := tracing.NewSpan(ctx)
	defer func() { span.EndWithError(err) }()
	usedRel := uint16(math.Floor(float64(usedAbs*100) / float64(qu.Amount)))
	query, scan := prepareQuotaNotificationsQuery(ctx, q.client)
	stmt, args, err := query.Where(
		sq.And{
			sq.Eq{
				QuotaNotificationColumnInstanceID.identifier(): instanceID,
				QuotaNotificationColumnUnit.identifier():       unit,
			},
			sq.Or{
				// If the relative usage is greater than the next due threshold in the current period, it's clear we can notify
				sq.And{
					sq.Eq{QuotaNotificationColumnLatestDuePeriodStart.identifier(): periodStart},
					sq.LtOrEq{QuotaNotificationColumnNextDueThreshold.identifier(): usedRel},
				},
				// In case we haven't seen a due notification for this quota period, we compare against the configured percent
				sq.And{
					sq.Or{
						sq.Expr(QuotaNotificationColumnLatestDuePeriodStart.identifier() + " IS NULL"),
						sq.NotEq{QuotaNotificationColumnLatestDuePeriodStart.identifier(): periodStart},
					},
					sq.LtOrEq{QuotaNotificationColumnPercent.identifier(): usedRel},
				},
			},
		},
	).ToSql()
	if err != nil {
		return nil, zitadel_errors.ThrowInternal(err, "QUERY-XmYn9", "Errors.Query.SQLStatement")
	}
	var notifications *QuotaNotifications
	err = q.client.QueryContext(ctx, func(rows *sql.Rows) error {
		notifications, err = scan(rows)
		return err
	}, stmt, args...)
	if errors.Is(err, sql.ErrNoRows) {
		return nil, nil
	}

	aggregate := config.Aggregate()
	wm, err := q.getQuotaNotificationsReadModel(ctx, aggregate, periodStart)
	if err != nil {
		return nil, err
	}

	usedRel := uint16(math.Floor(float64(usedAbs*100) / float64(config.Amount)))

	var dueNotifications []*quota.NotificationDueEvent
	for _, notification := range config.Notifications {
		if notification.Percent > usedRel {
	for _, notification := range notifications.Configs {
		reachedThreshold := calculateThreshold(usedRel, notification.Percent)
		if !notification.Repeat && notification.Percent < reachedThreshold {
			continue
		}

		threshold := notification.Percent
		if notification.Repeat {
			threshold = uint16(math.Max(1, math.Floor(float64(usedRel)/float64(notification.Percent)))) * notification.Percent
		}

		if wm.latestDueThresholds[notification.ID] < threshold {
			dueNotifications = append(
				dueNotifications,
				quota.NewNotificationDueEvent(
					ctx,
					&aggregate,
					config.Unit,
					&quota.NewAggregate(qu.ID, instanceID).Aggregate,
					unit,
					notification.ID,
					notification.CallURL,
					periodStart,
					threshold,
					reachedThreshold,
					usedAbs,
				),
			)
		}
	}

	return dueNotifications, nil
}

func (q *Queries) getQuotaNotificationsReadModel(ctx context.Context, aggregate eventstore.Aggregate, periodStart time.Time) (*quotaNotificationsReadModel, error) {
	wm := newQuotaNotificationsReadModel(aggregate.ID, aggregate.InstanceID, aggregate.ResourceOwner, periodStart)
	return wm, q.eventstore.FilterToQueryReducer(ctx, wm)
type QuotaNotification struct {
	ID               string
	CallURL          string
	Percent          uint16
	Repeat           bool
	NextDueThreshold uint16
}

type QuotaNotifications struct {
	SearchResponse
	Configs []*QuotaNotification
}

// calculateThreshold calculates the nearest reached threshold.
// It makes sure that the percent configured on the notification is calculated within the "current" 100%,
// e.g. when configuring 80%, the thresholds are 80, 180, 280, ...
// so 170% use is always 70% of the current bucket, with the above config, the reached threshold would be 80.
func calculateThreshold(usedRel, notificationPercent uint16) uint16 {
	// check how many times we reached 100%
	times := math.Floor(float64(usedRel) / 100)
	// check how many times we reached the percent configured with the "current" 100%
	percent := math.Floor(float64(usedRel%100) / float64(notificationPercent))
	// If neither is reached, directly return 0.
	// This way we don't end up in some wrong uint16 range in the calculation below.
	if times == 0 && percent == 0 {
		return 0
	}
	return uint16(times+percent-1)*100 + notificationPercent
}

func prepareQuotaNotificationsQuery(ctx context.Context, db prepareDatabase) (sq.SelectBuilder, func(*sql.Rows) (*QuotaNotifications, error)) {
	return sq.Select(
			QuotaNotificationColumnID.identifier(),
			QuotaNotificationColumnCallURL.identifier(),
			QuotaNotificationColumnPercent.identifier(),
			QuotaNotificationColumnRepeat.identifier(),
			QuotaNotificationColumnNextDueThreshold.identifier(),
		).
		From(quotaNotificationsTable.identifier() + db.Timetravel(call.Took(ctx))).
		PlaceholderFormat(sq.Dollar), func(rows *sql.Rows) (*QuotaNotifications, error) {
			cfgs := &QuotaNotifications{Configs: []*QuotaNotification{}}
			for rows.Next() {
				cfg := new(QuotaNotification)
				var nextDueThreshold sql.NullInt16
				err := rows.Scan(&cfg.ID, &cfg.CallURL, &cfg.Percent, &cfg.Repeat, &nextDueThreshold)
				if err != nil {
					if errs.Is(err, sql.ErrNoRows) {
						return nil, zitadel_errors.ThrowNotFound(err, "QUERY-bbqWb", "Errors.QuotaNotification.NotExisting")
					}
					return nil, zitadel_errors.ThrowInternal(err, "QUERY-8copS", "Errors.Internal")
				}
				if nextDueThreshold.Valid {
					cfg.NextDueThreshold = uint16(nextDueThreshold.Int16)
				}
				cfgs.Configs = append(cfgs.Configs, cfg)
			}
			return cfgs, nil
		}
}
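To make the threshold arithmetic above concrete, a standalone sketch that copies calculateThreshold and evaluates the cases described in its doc comment:

package main

import (
	"fmt"
	"math"
)

// Copy of calculateThreshold above, so this example is self-contained.
func calculateThreshold(usedRel, notificationPercent uint16) uint16 {
	times := math.Floor(float64(usedRel) / 100)
	percent := math.Floor(float64(usedRel%100) / float64(notificationPercent))
	if times == 0 && percent == 0 {
		return 0
	}
	return uint16(times+percent-1)*100 + notificationPercent
}

func main() {
	// With an 80% notification, the thresholds are 80, 180, 280, ...
	fmt.Println(calculateThreshold(70, 80))  // 0: no threshold reached yet
	fmt.Println(calculateThreshold(90, 80))  // 80
	fmt.Println(calculateThreshold(170, 80)) // 80: 70% into the second bucket
	fmt.Println(calculateThreshold(190, 80)) // 180
}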
@ -1,46 +0,0 @@
package query

import (
	"time"

	"github.com/zitadel/zitadel/internal/eventstore"
	"github.com/zitadel/zitadel/internal/repository/quota"
)

type quotaNotificationsReadModel struct {
	eventstore.ReadModel
	periodStart         time.Time
	latestDueThresholds map[string]uint16
}

func newQuotaNotificationsReadModel(aggregateId, instanceId, resourceOwner string, periodStart time.Time) *quotaNotificationsReadModel {
	return &quotaNotificationsReadModel{
		ReadModel: eventstore.ReadModel{
			AggregateID:   aggregateId,
			InstanceID:    instanceId,
			ResourceOwner: resourceOwner,
		},
		periodStart:         periodStart,
		latestDueThresholds: make(map[string]uint16),
	}
}

func (rm *quotaNotificationsReadModel) Query() *eventstore.SearchQueryBuilder {
	return eventstore.NewSearchQueryBuilder(eventstore.ColumnsEvent).
		ResourceOwner(rm.ResourceOwner).
		AllowTimeTravel().
		AddQuery().
		InstanceID(rm.InstanceID).
		AggregateTypes(quota.AggregateType).
		AggregateIDs(rm.AggregateID).
		CreationDateAfter(rm.periodStart).
		EventTypes(quota.NotificationDueEventType).Builder()
}

func (rm *quotaNotificationsReadModel) Reduce() error {
	for _, event := range rm.Events {
		e := event.(*quota.NotificationDueEvent)
		rm.latestDueThresholds[e.ID] = e.Threshold
	}
	return rm.ReadModel.Reduce()
}
internal/query/quota_notifications_test.go (new file, 181 lines)
@ -0,0 +1,181 @@
package query

import (
	"database/sql"
	"database/sql/driver"
	"errors"
	"fmt"
	"regexp"
	"testing"

	"github.com/stretchr/testify/assert"
)

func Test_calculateThreshold(t *testing.T) {
	type args struct {
		usedRel             uint16
		notificationPercent uint16
	}
	tests := []struct {
		name string
		args args
		want uint16
	}{
		{
			name: "80 - below configuration",
			args: args{
				usedRel:             70,
				notificationPercent: 80,
			},
			want: 0,
		},
		{
			name: "80 - below 100 percent use",
			args: args{
				usedRel:             90,
				notificationPercent: 80,
			},
			want: 80,
		},
		{
			name: "80 - above 100 percent use",
			args: args{
				usedRel:             120,
				notificationPercent: 80,
			},
			want: 80,
		},
		{
			name: "80 - more than twice the use",
			args: args{
				usedRel:             190,
				notificationPercent: 80,
			},
			want: 180,
		},
		{
			name: "100 - below 100 percent use",
			args: args{
				usedRel:             90,
				notificationPercent: 100,
			},
			want: 0,
		},
		{
			name: "100 - above 100 percent use",
			args: args{
				usedRel:             120,
				notificationPercent: 100,
			},
			want: 100,
		},
		{
			name: "100 - more than twice the use",
			args: args{
				usedRel:             210,
				notificationPercent: 100,
			},
			want: 200,
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			got := calculateThreshold(tt.args.usedRel, tt.args.notificationPercent)
			assert.Equal(t, int(tt.want), int(got))
		})
	}
}

var (
	expectedQuotaNotificationsQuery = regexp.QuoteMeta(`SELECT projections.quotas_notifications.id,` +
		` projections.quotas_notifications.call_url,` +
		` projections.quotas_notifications.percent,` +
		` projections.quotas_notifications.repeat,` +
		` projections.quotas_notifications.next_due_threshold` +
		` FROM projections.quotas_notifications` +
		` AS OF SYSTEM TIME '-1 ms'`)

	quotaNotificationsCols = []string{
		"id",
		"call_url",
		"percent",
		"repeat",
		"next_due_threshold",
	}
)

func Test_prepareQuotaNotificationsQuery(t *testing.T) {
	type want struct {
		sqlExpectations sqlExpectation
		err             checkErr
	}
	tests := []struct {
		name    string
		prepare interface{}
		want    want
		object  interface{}
	}{
		{
			name:    "prepareQuotaNotificationsQuery no result",
			prepare: prepareQuotaNotificationsQuery,
			want: want{
				sqlExpectations: mockQueries(
					expectedQuotaNotificationsQuery,
					nil,
					nil,
				),
			},
			object: &QuotaNotifications{Configs: []*QuotaNotification{}},
		},
		{
			name:    "prepareQuotaNotificationsQuery",
			prepare: prepareQuotaNotificationsQuery,
			want: want{
				sqlExpectations: mockQuery(
					expectedQuotaNotificationsQuery,
					quotaNotificationsCols,
					[]driver.Value{
						"quota-id",
						"url",
						uint16(100),
						true,
						uint16(100),
					},
				),
			},
			object: &QuotaNotifications{
				Configs: []*QuotaNotification{
					{
						ID:               "quota-id",
						CallURL:          "url",
						Percent:          100,
						Repeat:           true,
						NextDueThreshold: 100,
					},
				},
			},
		},
		{
			name:    "prepareQuotaNotificationsQuery sql err",
			prepare: prepareQuotaNotificationsQuery,
			want: want{
				sqlExpectations: mockQueryErr(
					expectedQuotaNotificationsQuery,
					sql.ErrConnDone,
				),
				err: func(err error) (error, bool) {
					if !errors.Is(err, sql.ErrConnDone) {
						return fmt.Errorf("err should be sql.ErrConnDone got: %w", err), false
					}
					return nil, true
				},
			},
			object: (*Quota)(nil),
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			assertPrepare(t, tt.prepare, tt.object, tt.want.sqlExpectations, tt.want.err, defaultPrepareArgs...)
		})
	}
}
@ -1,30 +0,0 @@
package query

import (
	"context"
	"time"

	"github.com/zitadel/zitadel/internal/repository/quota"
)

func (q *Queries) GetCurrentQuotaPeriod(ctx context.Context, instanceID string, unit quota.Unit) (*quota.AddedEvent, time.Time, error) {
	rm, err := q.getQuotaReadModel(ctx, instanceID, instanceID, unit)
	if err != nil || !rm.active {
		return nil, time.Time{}, err
	}

	return rm.config, pushPeriodStart(rm.config.From, rm.config.ResetInterval, time.Now()), nil
}

func pushPeriodStart(from time.Time, interval time.Duration, now time.Time) time.Time {
	next := from.Add(interval)
	if next.After(now) {
		return from
	}
	return pushPeriodStart(next, interval, now)
}

func (q *Queries) getQuotaReadModel(ctx context.Context, instanceId, resourceOwner string, unit quota.Unit) (*quotaReadModel, error) {
	rm := newQuotaReadModel(instanceId, resourceOwner, unit)
	return rm, q.eventstore.FilterToQueryReducer(ctx, rm)
}
internal/query/quota_periods.go (new file, 86 lines)
@ -0,0 +1,86 @@
package query

import (
	"context"
	"database/sql"
	"errors"

	sq "github.com/Masterminds/squirrel"

	"github.com/zitadel/zitadel/internal/api/call"
	zitadel_errors "github.com/zitadel/zitadel/internal/errors"
	"github.com/zitadel/zitadel/internal/query/projection"
	"github.com/zitadel/zitadel/internal/repository/quota"
	"github.com/zitadel/zitadel/internal/telemetry/tracing"
)

var (
	quotaPeriodsTable = table{
		name:          projection.QuotaPeriodsProjectionTable,
		instanceIDCol: projection.QuotaColumnInstanceID,
	}
	QuotaPeriodColumnInstanceID = Column{
		name:  projection.QuotaPeriodColumnInstanceID,
		table: quotaPeriodsTable,
	}
	QuotaPeriodColumnUnit = Column{
		name:  projection.QuotaPeriodColumnUnit,
		table: quotaPeriodsTable,
	}
	QuotaPeriodColumnStart = Column{
		name:  projection.QuotaPeriodColumnStart,
		table: quotaPeriodsTable,
	}
	QuotaPeriodColumnUsage = Column{
		name:  projection.QuotaPeriodColumnUsage,
		table: quotaPeriodsTable,
	}
)

func (q *Queries) GetRemainingQuotaUsage(ctx context.Context, instanceID string, unit quota.Unit) (remaining *uint64, err error) {
	ctx, span := tracing.NewSpan(ctx)
	defer func() { span.EndWithError(err) }()
	stmt, scan := prepareRemainingQuotaUsageQuery(ctx, q.client)
	query, args, err := stmt.Where(
		sq.And{
			sq.Eq{
				QuotaPeriodColumnInstanceID.identifier(): instanceID,
				QuotaPeriodColumnUnit.identifier():       unit,
				QuotaColumnLimit.identifier():            true,
			},
			sq.Expr("age(" + QuotaPeriodColumnStart.identifier() + ") < " + QuotaColumnInterval.identifier()),
			sq.Expr(QuotaPeriodColumnStart.identifier() + " < now()"),
		}).
		ToSql()
	if err != nil {
		return nil, zitadel_errors.ThrowInternal(err, "QUERY-FSA3g", "Errors.Query.SQLStatement")
	}
	err = q.client.QueryRowContext(ctx, func(row *sql.Row) error {
		remaining, err = scan(row)
		return err
	}, query, args...)
	if zitadel_errors.IsNotFound(err) {
		return nil, nil
	}
	return remaining, err
}

func prepareRemainingQuotaUsageQuery(ctx context.Context, db prepareDatabase) (sq.SelectBuilder, func(*sql.Row) (*uint64, error)) {
	return sq.
			Select(
				"greatest(0, " + QuotaColumnAmount.identifier() + "-" + QuotaPeriodColumnUsage.identifier() + ")",
			).
			From(quotaPeriodsTable.identifier()).
			Join(join(QuotaColumnUnit, QuotaPeriodColumnUnit) + db.Timetravel(call.Took(ctx))).
			PlaceholderFormat(sq.Dollar), func(row *sql.Row) (*uint64, error) {
			usage := new(uint64)
			err := row.Scan(usage)
			if err != nil {
				if errors.Is(err, sql.ErrNoRows) {
					return nil, zitadel_errors.ThrowNotFound(err, "QUERY-quiowi2", "Errors.Internal")
				}
				return nil, zitadel_errors.ThrowInternal(err, "QUERY-81j1jn2", "Errors.Internal")
			}
			return usage, nil
		}
}
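The prepare/scan split above follows the package's usual pattern: the prepare function returns a squirrel builder plus a row scanner, and the caller appends WHERE clauses before calling ToSql. A hypothetical standalone demo of that builder mechanics (table and column names are stand-ins; the real identifiers come from the projection package via Column.identifier()):

package main

import (
	"fmt"

	sq "github.com/Masterminds/squirrel"
)

func main() {
	// Chained Where calls are ANDed together; PlaceholderFormat(sq.Dollar)
	// rewrites the ? placeholders to $1, $2, ... for Postgres/CockroachDB.
	stmt, args, err := sq.
		Select("greatest(0, amount - usage)").
		From("projections.quotas_periods").
		Where(sq.Eq{"instance_id": "instance-1"}).
		Where(sq.Eq{"unit": 1}).
		PlaceholderFormat(sq.Dollar).
		ToSql()
	if err != nil {
		panic(err)
	}
	fmt.Println(stmt) // SELECT greatest(0, amount - usage) FROM projections.quotas_periods WHERE instance_id = $1 AND unit = $2
	fmt.Println(args) // [instance-1 1]
}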
internal/query/quota_periods_test.go (new file, 95 lines)
@ -0,0 +1,95 @@
package query

import (
	"database/sql"
	"database/sql/driver"
	"errors"
	"fmt"
	"regexp"
	"testing"

	errs "github.com/zitadel/zitadel/internal/errors"
)

var (
	expectedRemainingQuotaUsageQuery = regexp.QuoteMeta(`SELECT greatest(0, projections.quotas.amount-projections.quotas_periods.usage)` +
		` FROM projections.quotas_periods` +
		` JOIN projections.quotas ON projections.quotas_periods.unit = projections.quotas.unit AND projections.quotas_periods.instance_id = projections.quotas.instance_id` +
		` AS OF SYSTEM TIME '-1 ms'`)
	remainingQuotaUsageCols = []string{
		"usage",
	}
)

func Test_prepareRemainingQuotaUsageQuery(t *testing.T) {
	type want struct {
		sqlExpectations sqlExpectation
		err             checkErr
	}
	tests := []struct {
		name    string
		prepare interface{}
		want    want
		object  interface{}
	}{
		{
			name:    "prepareRemainingQuotaUsageQuery no result",
			prepare: prepareRemainingQuotaUsageQuery,
			want: want{
				sqlExpectations: mockQueryScanErr(
					expectedRemainingQuotaUsageQuery,
					nil,
					nil,
				),
				err: func(err error) (error, bool) {
					if !errs.IsNotFound(err) {
						return fmt.Errorf("err should be zitadel.NotFoundError got: %w", err), false
					}
					return nil, true
				},
			},
			object: (*uint64)(nil),
		},
		{
			name:    "prepareRemainingQuotaUsageQuery",
			prepare: prepareRemainingQuotaUsageQuery,
			want: want{
				sqlExpectations: mockQuery(
					expectedRemainingQuotaUsageQuery,
					remainingQuotaUsageCols,
					[]driver.Value{
						uint64(100),
					},
				),
			},
			object: uint64P(100),
		},
		{
			name:    "prepareRemainingQuotaUsageQuery sql err",
			prepare: prepareRemainingQuotaUsageQuery,
			want: want{
				sqlExpectations: mockQueryErr(
					expectedRemainingQuotaUsageQuery,
					sql.ErrConnDone,
				),
				err: func(err error) (error, bool) {
					if !errors.Is(err, sql.ErrConnDone) {
						return fmt.Errorf("err should be sql.ErrConnDone got: %w", err), false
					}
					return nil, true
				},
			},
			object: (*uint64)(nil),
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			assertPrepare(t, tt.prepare, tt.object, tt.want.sqlExpectations, tt.want.err, defaultPrepareArgs...)
		})
	}
}

func uint64P(i int) *uint64 {
	u := uint64(i)
	return &u
}
internal/query/quota_test.go (new file, 127 lines)
@ -0,0 +1,127 @@
package query

import (
	"database/sql"
	"database/sql/driver"
	"errors"
	"fmt"
	"regexp"
	"testing"
	"time"

	"github.com/jackc/pgtype"

	errs "github.com/zitadel/zitadel/internal/errors"
)

var (
	expectedQuotaQuery = regexp.QuoteMeta(`SELECT projections.quotas.id,` +
		` projections.quotas.from_anchor,` +
		` projections.quotas.interval,` +
		` projections.quotas.amount,` +
		` projections.quotas.limit_usage,` +
		` now()` +
		` FROM projections.quotas`)

	quotaCols = []string{
		"id",
		"from_anchor",
		"interval",
		"amount",
		"limit_usage",
		"now",
	}
)

func dayNow() time.Time {
	return time.Now().Truncate(24 * time.Hour)
}

func interval(t *testing.T, src time.Duration) pgtype.Interval {
	interval := pgtype.Interval{}
	err := interval.Set(src)
	if err != nil {
		t.Fatal(err)
	}
	return interval
}

func Test_QuotaPrepare(t *testing.T) {
	type want struct {
		sqlExpectations sqlExpectation
		err             checkErr
	}
	tests := []struct {
		name    string
		prepare interface{}
		want    want
		object  interface{}
	}{
		{
			name:    "prepareQuotaQuery no result",
			prepare: prepareQuotaQuery,
			want: want{
				sqlExpectations: mockQueriesScanErr(
					expectedQuotaQuery,
					nil,
					nil,
				),
				err: func(err error) (error, bool) {
					if !errs.IsNotFound(err) {
						return fmt.Errorf("err should be zitadel.NotFoundError got: %w", err), false
					}
					return nil, true
				},
			},
			object: (*Quota)(nil),
		},
		{
			name:    "prepareQuotaQuery",
			prepare: prepareQuotaQuery,
			want: want{
				sqlExpectations: mockQuery(
					expectedQuotaQuery,
					quotaCols,
					[]driver.Value{
						"quota-id",
						dayNow(),
						interval(t, time.Hour*24),
						uint64(1000),
						true,
						testNow,
					},
				),
			},
			object: &Quota{
				ID:                 "quota-id",
				From:               dayNow(),
				ResetInterval:      time.Hour * 24,
				CurrentPeriodStart: dayNow(),
				Amount:             1000,
				Limit:              true,
			},
		},
		{
			name:    "prepareQuotaQuery sql err",
			prepare: prepareQuotaQuery,
			want: want{
				sqlExpectations: mockQueryErr(
					expectedQuotaQuery,
					sql.ErrConnDone,
				),
				err: func(err error) (error, bool) {
					if !errors.Is(err, sql.ErrConnDone) {
						return fmt.Errorf("err should be sql.ErrConnDone got: %w", err), false
					}
					return nil, true
				},
			},
			object: (*Quota)(nil),
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			assertPrepare(t, tt.prepare, tt.object, tt.want.sqlExpectations, tt.want.err, defaultPrepareArgs...)
		})
	}
}
@ -13,14 +13,14 @@ type Aggregate struct {
	eventstore.Aggregate
}

func NewAggregate(id, instanceId, resourceOwner string) *Aggregate {
func NewAggregate(id, instanceId string) *Aggregate {
	return &Aggregate{
		Aggregate: eventstore.Aggregate{
			Type:          AggregateType,
			Version:       AggregateVersion,
			ID:            id,
			InstanceID:    instanceId,
			ResourceOwner: resourceOwner,
			ResourceOwner: instanceId,
		},
	}
}
@ -15,7 +15,6 @@ type Unit uint

const (
	UniqueQuotaNameType           = "quota_units"
	UniqueQuotaNotificationIDType = "quota_notification"
	eventTypePrefix               = eventstore.EventType("quota.")
	AddedEventType                = eventTypePrefix + "added"
	NotifiedEventType             = eventTypePrefix + "notified"