perf: project quotas and usages (#6441)

* project quota added

* project quota removed

* add periods table

* make log record generic

* accumulate usage

* query usage

* count action run seconds

* fix filter in ReportQuotaUsage

* fix existing tests

* fix logstore tests

* fix typo

* fix: add quota unit tests command side

* fix: add quota unit tests command side

* fix: add quota unit tests command side

* move notifications into debouncer and improve limit querying

* cleanup

* comment

* fix: add quota unit tests command side

* fix remaining quota usage query

* implement InmemLogStorage

* cleanup and linting

* improve test

* fix: add quota unit tests command side

* fix: add quota unit tests command side

* fix: add quota unit tests command side

* fix: add quota unit tests command side

* action notifications and fixes for notifications query

* revert console prefix

* fix: add quota unit tests command side

* fix: add quota integration tests

* improve accountable requests

* improve accountable requests

* fix: add quota integration tests

* fix: add quota integration tests

* fix: add quota integration tests

* comment

* remove ability to store logs in db and other changes requested from review

* changes requested from review

* changes requested from review

* Update internal/api/http/middleware/access_interceptor.go

Co-authored-by: Silvan <silvan.reusser@gmail.com>

* tests: fix quotas integration tests

* improve incrementUsageStatement

* linting

* fix: delete e2e tests as intergation tests cover functionality

* Update internal/api/http/middleware/access_interceptor.go

Co-authored-by: Silvan <silvan.reusser@gmail.com>

* backup

* fix conflict

* create rc

* create prerelease

* remove issue release labeling

* fix tracing

---------

Co-authored-by: Livio Spring <livio.a@gmail.com>
Co-authored-by: Stefan Benz <stefan@caos.ch>
Co-authored-by: adlerhurst <silvan.reusser@gmail.com>
(cherry picked from commit 1a49b7d298690ce64846727f1fcf5a325f77c76e)
This commit is contained in:
Elio Bischof 2023-09-15 16:58:45 +02:00 committed by Livio Spring
parent b688d6f842
commit 5823fdbef9
No known key found for this signature in database
GPG Key ID: 26BB1C2FA5952CF0
66 changed files with 3423 additions and 1413 deletions

View File

@ -452,54 +452,29 @@ Actions:
LogStore:
Access:
Database:
# If enabled, all access logs are stored in the database table logstore.access
Enabled: false # ZITADEL_LOGSTORE_ACCESS_DATABASE_ENABLED
# Logs that are older than the keep duration are cleaned up continuously
# 2160h are 90 days, 3 months
Keep: 2160h # ZITADEL_LOGSTORE_ACCESS_DATABASE_KEEP
# CleanupInterval defines the time between cleanup iterations
CleanupInterval: 4h # ZITADEL_LOGSTORE_ACCESS_DATABASE_CLEANUPINTERVAL
# Debouncing enables to asynchronously emit log entries, so the normal execution performance is not impaired
# Log entries are held in memory until one of the conditions MinFrequency or MaxBulkSize meets.
Debounce:
MinFrequency: 2m # ZITADEL_LOGSTORE_ACCESS_DATABASE_DEBOUNCE_MINFREQUENCY
MaxBulkSize: 100 # ZITADEL_LOGSTORE_ACCESS_DATABASE_DEBOUNCE_MAXBULKSIZE
Stdout:
# If enabled, all access logs are printed to the binary's standard output
Enabled: false # ZITADEL_LOGSTORE_ACCESS_STDOUT_ENABLED
# Debouncing enables to asynchronously emit log entries, so the normal execution performance is not impaired
# Log entries are held in memory until one of the conditions MinFrequency or MaxBulkSize meets.
Debounce:
MinFrequency: 0s # ZITADEL_LOGSTORE_ACCESS_STDOUT_DEBOUNCE_MINFREQUENCY
MaxBulkSize: 0 # ZITADEL_LOGSTORE_ACCESS_STDOUT_DEBOUNCE_MAXBULKSIZE
Execution:
Database:
# If enabled, all action execution logs are stored in the database table logstore.execution
Enabled: false # ZITADEL_LOGSTORE_EXECUTION_DATABASE_ENABLED
# Logs that are older than the keep duration are cleaned up continuously
# 2160h are 90 days, 3 months
Keep: 2160h # ZITADEL_LOGSTORE_EXECUTION_DATABASE_KEEP
# CleanupInterval defines the time between cleanup iterations
CleanupInterval: 4h # ZITADEL_LOGSTORE_EXECUTION_DATABASE_CLEANUPINTERVAL
# Debouncing enables to asynchronously emit log entries, so the normal execution performance is not impaired
# Log entries are held in memory until one of the conditions MinFrequency or MaxBulkSize meets.
Debounce:
MinFrequency: 0s # ZITADEL_LOGSTORE_EXECUTION_DATABASE_DEBOUNCE_MINFREQUENCY
MaxBulkSize: 0 # ZITADEL_LOGSTORE_EXECUTION_DATABASE_DEBOUNCE_MAXBULKSIZE
Stdout:
# If enabled, all execution logs are printed to the binary's standard output
Enabled: true # ZITADEL_LOGSTORE_EXECUTION_STDOUT_ENABLED
# Debouncing enables to asynchronously emit log entries, so the normal execution performance is not impaired
# Log entries are held in memory until one of the conditions MinFrequency or MaxBulkSize meets.
Debounce:
MinFrequency: 0s # ZITADEL_LOGSTORE_EXECUTION_STDOUT_DEBOUNCE_MINFREQUENCY
MaxBulkSize: 0 # ZITADEL_LOGSTORE_EXECUTION_STDOUT_DEBOUNCE_MAXBULKSIZE
Quotas:
Access:
# If enabled, authenticated requests are counted and potentially limited depending on the configured quota of the instance
Enabled: false # ZITADEL_QUOTAS_ACCESS_ENABLED
Debounce:
MinFrequency: 0s # ZITADEL_QUOTAS_ACCESS_DEBOUNCE_MINFREQUENCY
MaxBulkSize: 0 # ZITADEL_QUOTAS_ACCESS_DEBOUNCE_MAXBULKSIZE
ExhaustedCookieKey: "zitadel.quota.exhausted" # ZITADEL_QUOTAS_ACCESS_EXHAUSTEDCOOKIEKEY
ExhaustedCookieMaxAge: "300s" # ZITADEL_QUOTAS_ACCESS_EXHAUSTEDCOOKIEMAXAGE
Execution:
# If enabled, all action executions are counted and potentially limited depending on the configured quota of the instance
Enabled: false # ZITADEL_QUOTAS_EXECUTION_DATABASE_ENABLED
Debounce:
MinFrequency: 0s # ZITADEL_QUOTAS_EXECUTION_DEBOUNCE_MINFREQUENCY
MaxBulkSize: 0 # ZITADEL_QUOTAS_EXECUTION_DEBOUNCE_MAXBULKSIZE
Eventstore:
PushTimeout: 15s # ZITADEL_EVENTSTORE_PUSHTIMEOUT

View File

@ -72,7 +72,11 @@ type Config struct {
}
type QuotasConfig struct {
Access *middleware.AccessConfig
Access struct {
logstore.EmitterConfig `mapstructure:",squash"`
middleware.AccessConfig `mapstructure:",squash"`
}
Execution *logstore.EmitterConfig
}
func MustNewConfig(v *viper.Viper) *Config {

View File

@ -64,6 +64,8 @@ import (
"github.com/zitadel/zitadel/internal/logstore/emitters/access"
"github.com/zitadel/zitadel/internal/logstore/emitters/execution"
"github.com/zitadel/zitadel/internal/logstore/emitters/stdout"
"github.com/zitadel/zitadel/internal/logstore/record"
"github.com/zitadel/zitadel/internal/net"
"github.com/zitadel/zitadel/internal/notification"
"github.com/zitadel/zitadel/internal/query"
"github.com/zitadel/zitadel/internal/static"
@ -108,7 +110,6 @@ type Server struct {
AuthzRepo authz_repo.Repository
Storage static.Storage
Commands *command.Commands
LogStore *logstore.Service
Router *mux.Router
TLSConfig *tls.Config
Shutdown chan<- os.Signal
@ -209,17 +210,16 @@ func startZitadel(config *Config, masterKey string, server chan<- *Server) error
}
clock := clockpkg.New()
actionsExecutionStdoutEmitter, err := logstore.NewEmitter(ctx, clock, config.LogStore.Execution.Stdout, stdout.NewStdoutEmitter())
actionsExecutionStdoutEmitter, err := logstore.NewEmitter[*record.ExecutionLog](ctx, clock, &logstore.EmitterConfig{Enabled: config.LogStore.Execution.Stdout.Enabled}, stdout.NewStdoutEmitter[*record.ExecutionLog]())
if err != nil {
return err
}
actionsExecutionDBEmitter, err := logstore.NewEmitter(ctx, clock, config.LogStore.Execution.Database, execution.NewDatabaseLogStorage(dbClient))
actionsExecutionDBEmitter, err := logstore.NewEmitter[*record.ExecutionLog](ctx, clock, config.Quotas.Execution, execution.NewDatabaseLogStorage(dbClient, commands, queries))
if err != nil {
return err
}
usageReporter := logstore.UsageReporterFunc(commands.ReportQuotaUsage)
actionsLogstoreSvc := logstore.New(queries, usageReporter, actionsExecutionDBEmitter, actionsExecutionStdoutEmitter)
actionsLogstoreSvc := logstore.New(queries, actionsExecutionDBEmitter, actionsExecutionStdoutEmitter)
actions.SetLogstoreService(actionsLogstoreSvc)
notification.Start(
@ -259,8 +259,6 @@ func startZitadel(config *Config, masterKey string, server chan<- *Server) error
storage,
authZRepo,
keys,
queries,
usageReporter,
permissionCheck,
)
if err != nil {
@ -281,7 +279,6 @@ func startZitadel(config *Config, masterKey string, server chan<- *Server) error
AuthzRepo: authZRepo,
Storage: storage,
Commands: commands,
LogStore: actionsLogstoreSvc,
Router: router,
TLSConfig: tlsConfig,
Shutdown: shutdown,
@ -304,8 +301,6 @@ func startAPIs(
store static.Storage,
authZRepo authz_repo.Repository,
keys *encryptionKeys,
quotaQuerier logstore.QuotaQuerier,
usageReporter logstore.UsageReporter,
permissionCheck domain.PermissionCheck,
) error {
repo := struct {
@ -321,22 +316,22 @@ func startAPIs(
return err
}
accessStdoutEmitter, err := logstore.NewEmitter(ctx, clock, config.LogStore.Access.Stdout, stdout.NewStdoutEmitter())
accessStdoutEmitter, err := logstore.NewEmitter[*record.AccessLog](ctx, clock, &logstore.EmitterConfig{Enabled: config.LogStore.Access.Stdout.Enabled}, stdout.NewStdoutEmitter[*record.AccessLog]())
if err != nil {
return err
}
accessDBEmitter, err := logstore.NewEmitter(ctx, clock, config.LogStore.Access.Database, access.NewDatabaseLogStorage(dbClient))
accessDBEmitter, err := logstore.NewEmitter[*record.AccessLog](ctx, clock, &config.Quotas.Access.EmitterConfig, access.NewDatabaseLogStorage(dbClient, commands, queries))
if err != nil {
return err
}
accessSvc := logstore.New(quotaQuerier, usageReporter, accessDBEmitter, accessStdoutEmitter)
accessSvc := logstore.New[*record.AccessLog](queries, accessDBEmitter, accessStdoutEmitter)
exhaustedCookieHandler := http_util.NewCookieHandler(
http_util.WithUnsecure(),
http_util.WithNonHttpOnly(),
http_util.WithMaxAge(int(math.Floor(config.Quotas.Access.ExhaustedCookieMaxAge.Seconds()))),
)
limitingAccessInterceptor := middleware.NewAccessInterceptor(accessSvc, exhaustedCookieHandler, config.Quotas.Access)
limitingAccessInterceptor := middleware.NewAccessInterceptor(accessSvc, exhaustedCookieHandler, &config.Quotas.Access.AccessConfig)
apis, err := api.New(ctx, config.Port, router, queries, verifier, config.InternalAuthZ, tlsConfig, config.HTTP2HostHeader, config.HTTP1HostHeader, limitingAccessInterceptor)
if err != nil {
return fmt.Errorf("error creating api %w", err)
@ -399,13 +394,13 @@ func startAPIs(
}
apis.RegisterHandlerOnPrefix(openapi.HandlerPrefix, openAPIHandler)
oidcProvider, err := oidc.NewProvider(config.OIDC, login.DefaultLoggedOutPath, config.ExternalSecure, commands, queries, authRepo, keys.OIDC, keys.OIDCKey, eventstore, dbClient, userAgentInterceptor, instanceInterceptor.Handler, limitingAccessInterceptor.Handle)
oidcProvider, err := oidc.NewProvider(config.OIDC, login.DefaultLoggedOutPath, config.ExternalSecure, commands, queries, authRepo, keys.OIDC, keys.OIDCKey, eventstore, dbClient, userAgentInterceptor, instanceInterceptor.Handler, limitingAccessInterceptor)
if err != nil {
return fmt.Errorf("unable to start oidc provider: %w", err)
}
apis.RegisterHandlerPrefixes(oidcProvider.HttpHandler(), "/.well-known/openid-configuration", "/oidc/v1", "/oauth/v2")
samlProvider, err := saml.NewProvider(config.SAML, config.ExternalSecure, commands, queries, authRepo, keys.OIDC, keys.SAML, eventstore, dbClient, instanceInterceptor.Handler, userAgentInterceptor, limitingAccessInterceptor.Handle)
samlProvider, err := saml.NewProvider(config.SAML, config.ExternalSecure, commands, queries, authRepo, keys.OIDC, keys.SAML, eventstore, dbClient, instanceInterceptor.Handler, userAgentInterceptor, limitingAccessInterceptor)
if err != nil {
return fmt.Errorf("unable to start saml provider: %w", err)
}
@ -417,7 +412,7 @@ func startAPIs(
}
apis.RegisterHandlerOnPrefix(console.HandlerPrefix, c)
l, err := login.CreateLogin(config.Login, commands, queries, authRepo, store, console.HandlerPrefix+"/", op.AuthCallbackURL(oidcProvider), provider.AuthCallbackURL(samlProvider), config.ExternalSecure, userAgentInterceptor, op.NewIssuerInterceptor(oidcProvider.IssuerFromRequest).Handler, provider.NewIssuerInterceptor(samlProvider.IssuerFromRequest).Handler, instanceInterceptor.Handler, assetsCache.Handler, limitingAccessInterceptor.Handle, keys.User, keys.IDPConfig, keys.CSRFCookieKey)
l, err := login.CreateLogin(config.Login, commands, queries, authRepo, store, console.HandlerPrefix+"/", op.AuthCallbackURL(oidcProvider), provider.AuthCallbackURL(samlProvider), config.ExternalSecure, userAgentInterceptor, op.NewIssuerInterceptor(oidcProvider.IssuerFromRequest).Handler, provider.NewIssuerInterceptor(samlProvider.IssuerFromRequest).Handler, instanceInterceptor.Handler, assetsCache.Handler, limitingAccessInterceptor.WithoutLimiting().Handle, keys.User, keys.IDPConfig, keys.CSRFCookieKey)
if err != nil {
return fmt.Errorf("unable to start login: %w", err)
}
@ -438,7 +433,7 @@ func listen(ctx context.Context, router *mux.Router, port uint16, tlsConfig *tls
http2Server := &http2.Server{}
http1Server := &http.Server{Handler: h2c.NewHandler(router, http2Server), TLSConfig: tlsConfig}
lc := listenConfig()
lc := net.ListenConfig()
lis, err := lc.Listen(ctx, "tcp", fmt.Sprintf(":%d", port))
if err != nil {
return fmt.Errorf("tcp listener on %d failed: %w", port, err)

View File

@ -28,12 +28,7 @@ LogStore:
Database:
Enabled: true
Stdout:
Enabled: false
Quotas:
Access:
ExhaustedCookieKey: "zitadel.quota.limiting"
ExhaustedCookieMaxAge: "600s"
Enabled: true
Console:
InstanceManagementURL: "https://example.com/instances/{{.InstanceID}}"
@ -43,9 +38,25 @@ Projections:
NotificationsQuotas:
RequeueEvery: 1s
Quotas:
Access:
ExhaustedCookieKey: "zitadel.quota.limiting"
ExhaustedCookieMaxAge: "600s"
DefaultInstance:
LoginPolicy:
MfaInitSkipLifetime: "0"
Quotas:
Items:
- Unit: "actions.all.runs.seconds"
From: "2023-01-01T00:00:00Z"
ResetInterval: 5m
Amount: 20
Limit: false
Notifications:
- Percent: 100
Repeat: true
CallURL: "https://httpbin.org/post"
SystemAPIUsers:
- cypress:

View File

@ -1,324 +0,0 @@
import { addQuota, ensureQuotaIsAdded, ensureQuotaIsRemoved, removeQuota, Unit } from 'support/api/quota';
import { createHumanUser, ensureUserDoesntExist } from 'support/api/users';
import { Context } from 'support/commands';
import { ZITADELWebhookEvent } from 'support/types';
import { textChangeRangeIsUnchanged } from 'typescript';
beforeEach(() => {
cy.context().as('ctx');
});
describe('quotas', () => {
describe('management', () => {
describe('add one quota', () => {
it('should add a quota only once per unit', () => {
cy.get<Context>('@ctx').then((ctx) => {
addQuota(ctx, Unit.AuthenticatedRequests, true, 1);
addQuota(ctx, Unit.AuthenticatedRequests, true, 1, undefined, undefined, undefined, false).then((res) => {
expect(res.status).to.equal(409);
});
});
});
describe('add two quotas', () => {
it('should add a quota for each unit', () => {
cy.get<Context>('@ctx').then((ctx) => {
addQuota(ctx, Unit.AuthenticatedRequests, true, 1);
addQuota(ctx, Unit.ExecutionSeconds, true, 1);
});
});
});
});
describe('edit', () => {
describe('remove one quota', () => {
beforeEach(() => {
cy.get<Context>('@ctx').then((ctx) => {
ensureQuotaIsAdded(ctx, Unit.AuthenticatedRequests, true, 1);
});
});
it('should remove a quota only once per unit', () => {
cy.get<Context>('@ctx').then((ctx) => {
removeQuota(ctx, Unit.AuthenticatedRequests);
});
cy.get<Context>('@ctx').then((ctx) => {
removeQuota(ctx, Unit.AuthenticatedRequests, false).then((res) => {
expect(res.status).to.equal(404);
});
});
});
describe('remove two quotas', () => {
beforeEach(() => {
cy.get<Context>('@ctx').then((ctx) => {
ensureQuotaIsAdded(ctx, Unit.AuthenticatedRequests, true, 1);
ensureQuotaIsAdded(ctx, Unit.ExecutionSeconds, true, 1);
});
});
it('should remove a quota for each unit', () => {
cy.get<Context>('@ctx').then((ctx) => {
removeQuota(ctx, Unit.AuthenticatedRequests);
removeQuota(ctx, Unit.ExecutionSeconds);
});
});
});
});
});
});
describe('usage', () => {
beforeEach(() => {
cy.get<Context>('@ctx')
.then((ctx) => {
return [
`${ctx.api.oidcBaseURL}/userinfo`,
`${ctx.api.authBaseURL}/users/me`,
`${ctx.api.mgmtBaseURL}/iam`,
`${ctx.api.adminBaseURL}/instances/me`,
`${ctx.api.oauthBaseURL}/keys`,
`${ctx.api.samlBaseURL}/certificate`,
];
})
.as('authenticatedUrls');
});
describe('authenticated requests', () => {
const testUserName = 'shouldNotBeCreated';
beforeEach(() => {
cy.get<Array<string>>('@authenticatedUrls').then((urls) => {
cy.get<Context>('@ctx').then((ctx) => {
ensureUserDoesntExist(ctx.api, testUserName);
ensureQuotaIsAdded(ctx, Unit.AuthenticatedRequests, true, urls.length);
cy.task('runSQL', `TRUNCATE logstore.access;`);
});
});
});
it('only authenticated requests are limited', () => {
cy.get<Array<string>>('@authenticatedUrls').then((urls) => {
cy.get<Context>('@ctx').then((ctx) => {
const start = new Date();
urls.forEach((url) => {
cy.request({
url: url,
method: 'GET',
auth: {
bearer: ctx.api.token,
},
});
});
expectCookieDoesntExist();
const expiresMax = new Date();
expiresMax.setMinutes(expiresMax.getMinutes() + 20);
cy.request({
url: urls[1],
method: 'GET',
auth: {
bearer: ctx.api.token,
},
failOnStatusCode: false,
}).then((res) => {
expect(res.status).to.equal(429);
});
cy.getCookie('zitadel.quota.limiting').then((cookie) => {
expect(cookie.value).to.equal('true');
const cookieExpiry = new Date();
cookieExpiry.setTime(cookie.expiry * 1000);
expect(cookieExpiry).to.be.within(start, expiresMax);
});
createHumanUser(ctx.api, testUserName, false).then((res) => {
expect(res.status).to.equal(429);
});
// visit limited console
// cy.visit('/users/me');
// cy.contains('#authenticated-requests-exhausted-dialog button', 'Continue').click();
// const upgradeInstancePage = `https://example.com/instances/${ctx.instanceId}`;
// cy.origin(upgradeInstancePage, { args: { upgradeInstancePage } }, ({ upgradeInstancePage }) => {
// cy.location('href').should('equal', upgradeInstancePage);
// });
// upgrade instance
ensureQuotaIsRemoved(ctx, Unit.AuthenticatedRequests);
// visit upgraded console again
cy.visit('/users/me');
cy.get('[data-e2e="top-view-title"]');
expectCookieDoesntExist();
createHumanUser(ctx.api, testUserName);
expectCookieDoesntExist();
});
});
});
});
describe.skip('notifications', () => {
const callURL = `http://${Cypress.env('WEBHOOK_HANDLER_HOST')}:${Cypress.env('WEBHOOK_HANDLER_PORT')}/do_something`;
beforeEach(() => cy.task('resetWebhookEvents'));
const amount = 100;
const percent = 10;
const usage = 35;
describe('without repetition', () => {
beforeEach(() => {
cy.get<Context>('@ctx').then((ctx) => {
ensureQuotaIsAdded(ctx, Unit.AuthenticatedRequests, false, amount, [
{
callUrl: callURL,
percent: percent,
repeat: false,
},
]);
cy.task('runSQL', `TRUNCATE logstore.access;`);
});
});
it('fires at least once with the expected payload', () => {
cy.get<Array<string>>('@authenticatedUrls').then((urls) => {
cy.get<Context>('@ctx').then((ctx) => {
for (let i = 0; i < usage; i++) {
cy.request({
url: urls[0],
method: 'GET',
auth: {
bearer: ctx.api.token,
},
});
}
});
cy.waitUntil(
() =>
cy.task<Array<ZITADELWebhookEvent>>('handledWebhookEvents').then((events) => {
if (events.length < 1) {
return false;
}
return Cypress._.matches(<ZITADELWebhookEvent>{
sentStatus: 200,
payload: {
callURL: callURL,
threshold: percent,
unit: 1,
usage: percent,
},
})(events[0]);
}),
{ timeout: 60_000 },
);
});
});
it('fires until the webhook returns a successful message', () => {
cy.task('failWebhookEvents', 8);
cy.get<Array<string>>('@authenticatedUrls').then((urls) => {
cy.get<Context>('@ctx').then((ctx) => {
for (let i = 0; i < usage; i++) {
cy.request({
url: urls[0],
method: 'GET',
auth: {
bearer: ctx.api.token,
},
});
}
});
cy.waitUntil(
() =>
cy.task<Array<ZITADELWebhookEvent>>('handledWebhookEvents').then((events) => {
if (events.length != 9) {
return false;
}
return events.reduce<boolean>((a, b, i) => {
return !a
? a
: i < 8
? Cypress._.matches(<ZITADELWebhookEvent>{
sentStatus: 500,
payload: {
callURL: callURL,
threshold: percent,
unit: 1,
usage: percent,
},
})(b)
: Cypress._.matches(<ZITADELWebhookEvent>{
sentStatus: 200,
payload: {
callURL: callURL,
threshold: percent,
unit: 1,
usage: percent,
},
})(b);
}, true);
}),
{ timeout: 60_000 },
);
});
});
});
describe('with repetition', () => {
beforeEach(() => {
cy.get<Array<string>>('@authenticatedUrls').then((urls) => {
cy.get<Context>('@ctx').then((ctx) => {
ensureQuotaIsAdded(ctx, Unit.AuthenticatedRequests, false, amount, [
{
callUrl: callURL,
percent: percent,
repeat: true,
},
]);
cy.task('runSQL', `TRUNCATE logstore.access;`);
});
});
});
it('fires repeatedly with the expected payloads', () => {
cy.get<Array<string>>('@authenticatedUrls').then((urls) => {
cy.get<Context>('@ctx').then((ctx) => {
for (let i = 0; i < usage; i++) {
cy.request({
url: urls[0],
method: 'GET',
auth: {
bearer: ctx.api.token,
},
});
}
});
});
cy.waitUntil(
() =>
cy.task<Array<ZITADELWebhookEvent>>('handledWebhookEvents').then((events) => {
let foundExpected = 0;
for (let i = 0; i < events.length; i++) {
for (let expect = 10; expect <= 30; expect += 10) {
if (
Cypress._.matches(<ZITADELWebhookEvent>{
sentStatus: 200,
payload: {
callURL: callURL,
threshold: expect,
unit: 1,
usage: expect,
},
})(events[i])
) {
foundExpected++;
}
}
}
return foundExpected >= 3;
}),
{ timeout: 60_000 },
);
});
});
});
});
});
function expectCookieDoesntExist() {
cy.getCookie('zitadel.quota.limiting').then((cookie) => {
expect(cookie).to.be.null;
});
}

View File

@ -7,11 +7,13 @@ import (
"time"
"github.com/dop251/goja"
"github.com/zitadel/zitadel/internal/logstore"
"github.com/zitadel/zitadel/internal/logstore/record"
)
func TestRun(t *testing.T) {
SetLogstoreService(logstore.New(nil, nil, nil))
SetLogstoreService(logstore.New[*record.ExecutionLog](nil, nil))
type args struct {
timeout time.Duration
api apiFields

View File

@ -5,11 +5,13 @@ import (
"testing"
"github.com/dop251/goja"
"github.com/zitadel/zitadel/internal/logstore"
"github.com/zitadel/zitadel/internal/logstore/record"
)
func TestSetFields(t *testing.T) {
SetLogstoreService(logstore.New(nil, nil, nil))
SetLogstoreService(logstore.New[*record.ExecutionLog](nil, nil))
primitveFn := func(a string) { fmt.Println(a) }
complexFn := func(*FieldConfig) interface{} {
return primitveFn

View File

@ -10,12 +10,14 @@ import (
"testing"
"github.com/dop251/goja"
"github.com/zitadel/zitadel/internal/errors"
"github.com/zitadel/zitadel/internal/logstore"
"github.com/zitadel/zitadel/internal/logstore/record"
)
func Test_isHostBlocked(t *testing.T) {
SetLogstoreService(logstore.New(nil, nil, nil))
SetLogstoreService(logstore.New[*record.ExecutionLog](nil, nil))
var denyList = []AddressChecker{
mustNewIPChecker(t, "192.168.5.0/24"),
mustNewIPChecker(t, "127.0.0.1"),

View File

@ -10,15 +10,15 @@ import (
"github.com/zitadel/zitadel/internal/api/authz"
"github.com/zitadel/zitadel/internal/logstore"
"github.com/zitadel/zitadel/internal/logstore/emitters/execution"
"github.com/zitadel/zitadel/internal/logstore/record"
)
var (
logstoreService *logstore.Service
logstoreService *logstore.Service[*record.ExecutionLog]
_ console.Printer = (*logger)(nil)
)
func SetLogstoreService(svc *logstore.Service) {
func SetLogstoreService(svc *logstore.Service[*record.ExecutionLog]) {
logstoreService = svc
}
@ -55,19 +55,16 @@ func (l *logger) log(msg string, level logrus.Level, last bool) {
if l.started.IsZero() {
l.started = ts
}
record := &execution.Record{
r := &record.ExecutionLog{
LogDate: ts,
InstanceID: l.instanceID,
Message: msg,
LogLevel: level,
}
if last {
record.Took = ts.Sub(l.started)
r.Took = ts.Sub(l.started)
}
logstoreService.Handle(l.ctx, record)
logstoreService.Handle(l.ctx, r)
}
func withLogger(ctx context.Context) Option {

View File

@ -10,11 +10,11 @@ import (
"github.com/zitadel/zitadel/internal/api/authz"
"github.com/zitadel/zitadel/internal/logstore"
"github.com/zitadel/zitadel/internal/logstore/emitters/access"
"github.com/zitadel/zitadel/internal/logstore/record"
"github.com/zitadel/zitadel/internal/telemetry/tracing"
)
func AccessStorageInterceptor(svc *logstore.Service) grpc.UnaryServerInterceptor {
func AccessStorageInterceptor(svc *logstore.Service[*record.AccessLog]) grpc.UnaryServerInterceptor {
return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (_ interface{}, err error) {
if !svc.Enabled() {
return handler(ctx, req)
@ -36,9 +36,9 @@ func AccessStorageInterceptor(svc *logstore.Service) grpc.UnaryServerInterceptor
resMd, _ := metadata.FromOutgoingContext(ctx)
instance := authz.GetInstance(ctx)
record := &access.Record{
r := &record.AccessLog{
LogDate: time.Now(),
Protocol: access.GRPC,
Protocol: record.GRPC,
RequestURL: info.FullMethod,
ResponseStatus: respStatus,
RequestHeaders: reqMd,
@ -49,7 +49,7 @@ func AccessStorageInterceptor(svc *logstore.Service) grpc.UnaryServerInterceptor
RequestedHost: instance.RequestedHost(),
}
svc.Handle(interceptorCtx, record)
svc.Handle(interceptorCtx, r)
return resp, handlerErr
}
}

View File

@ -9,19 +9,16 @@ import (
"github.com/zitadel/zitadel/internal/api/authz"
"github.com/zitadel/zitadel/internal/errors"
"github.com/zitadel/zitadel/internal/logstore"
"github.com/zitadel/zitadel/internal/logstore/record"
"github.com/zitadel/zitadel/internal/telemetry/tracing"
)
func QuotaExhaustedInterceptor(svc *logstore.Service, ignoreService ...string) grpc.UnaryServerInterceptor {
prunedIgnoredServices := make([]string, len(ignoreService))
func QuotaExhaustedInterceptor(svc *logstore.Service[*record.AccessLog], ignoreService ...string) grpc.UnaryServerInterceptor {
for idx, service := range ignoreService {
if !strings.HasPrefix(service, "/") {
service = "/" + service
ignoreService[idx] = "/" + service
}
prunedIgnoredServices[idx] = service
}
return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (_ interface{}, err error) {
if !svc.Enabled() {
return handler(ctx, req)
@ -29,7 +26,13 @@ func QuotaExhaustedInterceptor(svc *logstore.Service, ignoreService ...string) g
interceptorCtx, span := tracing.NewServerInterceptorSpan(ctx)
defer func() { span.EndWithError(err) }()
for _, service := range prunedIgnoredServices {
// The auth interceptor will ensure that only authorized or public requests are allowed.
// So if there's no authorization context, we don't need to check for limitation
if authz.GetCtxData(ctx).IsZero() {
return handler(ctx, req)
}
for _, service := range ignoreService {
if strings.HasPrefix(info.FullMethod, service) {
return handler(ctx, req)
}

View File

@ -12,6 +12,7 @@ import (
grpc_api "github.com/zitadel/zitadel/internal/api/grpc"
"github.com/zitadel/zitadel/internal/api/grpc/server/middleware"
"github.com/zitadel/zitadel/internal/logstore"
"github.com/zitadel/zitadel/internal/logstore/record"
"github.com/zitadel/zitadel/internal/query"
"github.com/zitadel/zitadel/internal/telemetry/metrics"
system_pb "github.com/zitadel/zitadel/pkg/grpc/system"
@ -39,7 +40,7 @@ func CreateServer(
queries *query.Queries,
hostHeaderName string,
tlsConfig *tls.Config,
accessSvc *logstore.Service,
accessSvc *logstore.Service[*record.AccessLog],
) *grpc.Server {
metricTypes := []metrics.MetricType{metrics.MetricTypeTotalCount, metrics.MetricTypeRequestCount, metrics.MetricTypeStatusCode}
serverOptions := []grpc.ServerOption{
@ -49,14 +50,14 @@ func CreateServer(
middleware.DefaultTracingServer(),
middleware.MetricsHandler(metricTypes, grpc_api.Probes...),
middleware.NoCacheInterceptor(),
middleware.ErrorHandler(),
middleware.InstanceInterceptor(queries, hostHeaderName, system_pb.SystemService_ServiceDesc.ServiceName, healthpb.Health_ServiceDesc.ServiceName),
middleware.AccessStorageInterceptor(accessSvc),
middleware.ErrorHandler(),
middleware.AuthorizationInterceptor(verifier, authConfig),
middleware.QuotaExhaustedInterceptor(accessSvc, system_pb.SystemService_ServiceDesc.ServiceName),
middleware.TranslationHandler(),
middleware.ValidationHandler(),
middleware.ServiceHandler(),
middleware.QuotaExhaustedInterceptor(accessSvc, system_pb.SystemService_ServiceDesc.ServiceName),
),
),
}

View File

@ -0,0 +1,203 @@
//go:build integration
package system_test
import (
"bytes"
"encoding/json"
"strconv"
"testing"
"time"
"github.com/stretchr/testify/require"
"google.golang.org/protobuf/types/known/durationpb"
"google.golang.org/protobuf/types/known/timestamppb"
"github.com/zitadel/zitadel/internal/integration"
"github.com/zitadel/zitadel/internal/repository/quota"
"github.com/zitadel/zitadel/pkg/grpc/admin"
quota_pb "github.com/zitadel/zitadel/pkg/grpc/quota"
"github.com/zitadel/zitadel/pkg/grpc/system"
)
var callURL = "http://localhost:" + integration.PortQuotaServer
func TestServer_QuotaNotification_Limit(t *testing.T) {
_, instanceID, iamOwnerCtx := Tester.UseIsolatedInstance(CTX, SystemCTX)
amount := 10
percent := 50
percentAmount := amount * percent / 100
_, err := Tester.Client.System.AddQuota(SystemCTX, &system.AddQuotaRequest{
InstanceId: instanceID,
Unit: quota_pb.Unit_UNIT_REQUESTS_ALL_AUTHENTICATED,
From: timestamppb.Now(),
ResetInterval: durationpb.New(time.Minute * 5),
Amount: uint64(amount),
Limit: true,
Notifications: []*quota_pb.Notification{
{
Percent: uint32(percent),
Repeat: true,
CallUrl: callURL,
},
{
Percent: 100,
Repeat: true,
CallUrl: callURL,
},
},
})
require.NoError(t, err)
for i := 0; i < percentAmount; i++ {
_, err := Tester.Client.Admin.GetDefaultOrg(iamOwnerCtx, &admin.GetDefaultOrgRequest{})
require.NoErrorf(t, err, "error in %d call of %d", i, percentAmount)
}
awaitNotification(t, Tester.QuotaNotificationChan, quota.RequestsAllAuthenticated, percent)
for i := 0; i < (amount - percentAmount); i++ {
_, err := Tester.Client.Admin.GetDefaultOrg(iamOwnerCtx, &admin.GetDefaultOrgRequest{})
require.NoErrorf(t, err, "error in %d call of %d", i, percentAmount)
}
awaitNotification(t, Tester.QuotaNotificationChan, quota.RequestsAllAuthenticated, 100)
_, limitErr := Tester.Client.Admin.GetDefaultOrg(iamOwnerCtx, &admin.GetDefaultOrgRequest{})
require.Error(t, limitErr)
}
func TestServer_QuotaNotification_NoLimit(t *testing.T) {
_, instanceID, iamOwnerCtx := Tester.UseIsolatedInstance(CTX, SystemCTX)
amount := 10
percent := 50
percentAmount := amount * percent / 100
_, err := Tester.Client.System.AddQuota(SystemCTX, &system.AddQuotaRequest{
InstanceId: instanceID,
Unit: quota_pb.Unit_UNIT_REQUESTS_ALL_AUTHENTICATED,
From: timestamppb.Now(),
ResetInterval: durationpb.New(time.Minute * 5),
Amount: uint64(amount),
Limit: false,
Notifications: []*quota_pb.Notification{
{
Percent: uint32(percent),
Repeat: false,
CallUrl: callURL,
},
{
Percent: 100,
Repeat: true,
CallUrl: callURL,
},
},
})
require.NoError(t, err)
for i := 0; i < percentAmount; i++ {
_, err := Tester.Client.Admin.GetDefaultOrg(iamOwnerCtx, &admin.GetDefaultOrgRequest{})
require.NoErrorf(t, err, "error in %d call of %d", i, percentAmount)
}
awaitNotification(t, Tester.QuotaNotificationChan, quota.RequestsAllAuthenticated, percent)
for i := 0; i < (amount - percentAmount); i++ {
_, err := Tester.Client.Admin.GetDefaultOrg(iamOwnerCtx, &admin.GetDefaultOrgRequest{})
require.NoErrorf(t, err, "error in %d call of %d", i, percentAmount)
}
awaitNotification(t, Tester.QuotaNotificationChan, quota.RequestsAllAuthenticated, 100)
for i := 0; i < amount; i++ {
_, err := Tester.Client.Admin.GetDefaultOrg(iamOwnerCtx, &admin.GetDefaultOrgRequest{})
require.NoErrorf(t, err, "error in %d call of %d", i, percentAmount)
}
awaitNotification(t, Tester.QuotaNotificationChan, quota.RequestsAllAuthenticated, 200)
_, limitErr := Tester.Client.Admin.GetDefaultOrg(iamOwnerCtx, &admin.GetDefaultOrgRequest{})
require.NoError(t, limitErr)
}
func awaitNotification(t *testing.T, bodies chan []byte, unit quota.Unit, percent int) {
for {
select {
case body := <-bodies:
plain := new(bytes.Buffer)
if err := json.Indent(plain, body, "", " "); err != nil {
t.Fatal(err)
}
t.Log("received notificationDueEvent", plain.String())
event := struct {
Unit quota.Unit `json:"unit"`
ID string `json:"id"`
CallURL string `json:"callURL"`
PeriodStart time.Time `json:"periodStart"`
Threshold uint16 `json:"threshold"`
Usage uint64 `json:"usage"`
}{}
if err := json.Unmarshal(body, &event); err != nil {
t.Error(err)
}
if event.ID == "" {
continue
}
if event.Unit == unit && event.Threshold == uint16(percent) {
return
}
case <-time.After(60 * time.Second):
t.Fatalf("timed out waiting for unit %s and percent %d", strconv.Itoa(int(unit)), percent)
}
}
}
func TestServer_AddAndRemoveQuota(t *testing.T) {
_, instanceID, _ := Tester.UseIsolatedInstance(CTX, SystemCTX)
got, err := Tester.Client.System.AddQuota(SystemCTX, &system.AddQuotaRequest{
InstanceId: instanceID,
Unit: quota_pb.Unit_UNIT_REQUESTS_ALL_AUTHENTICATED,
From: timestamppb.Now(),
ResetInterval: durationpb.New(time.Minute),
Amount: 10,
Limit: true,
Notifications: []*quota_pb.Notification{
{
Percent: 20,
Repeat: true,
CallUrl: callURL,
},
},
})
require.NoError(t, err)
require.Equal(t, got.Details.ResourceOwner, instanceID)
gotAlreadyExisting, errAlreadyExisting := Tester.Client.System.AddQuota(SystemCTX, &system.AddQuotaRequest{
InstanceId: instanceID,
Unit: quota_pb.Unit_UNIT_REQUESTS_ALL_AUTHENTICATED,
From: timestamppb.Now(),
ResetInterval: durationpb.New(time.Minute),
Amount: 10,
Limit: true,
Notifications: []*quota_pb.Notification{
{
Percent: 20,
Repeat: true,
CallUrl: callURL,
},
},
})
require.Error(t, errAlreadyExisting)
require.Nil(t, gotAlreadyExisting)
gotRemove, errRemove := Tester.Client.System.RemoveQuota(SystemCTX, &system.RemoveQuotaRequest{
InstanceId: instanceID,
Unit: quota_pb.Unit_UNIT_REQUESTS_ALL_AUTHENTICATED,
})
require.NoError(t, errRemove)
require.Equal(t, gotRemove.Details.ResourceOwner, instanceID)
gotRemoveAlready, errRemoveAlready := Tester.Client.System.RemoveQuota(SystemCTX, &system.RemoveQuotaRequest{
InstanceId: instanceID,
Unit: quota_pb.Unit_UNIT_REQUESTS_ALL_AUTHENTICATED,
})
require.Error(t, errRemoveAlready)
require.Nil(t, gotRemoveAlready)
}

View File

@ -0,0 +1,32 @@
//go:build integration
package system_test
import (
"context"
"os"
"testing"
"time"
"github.com/zitadel/zitadel/internal/integration"
)
var (
CTX context.Context
SystemCTX context.Context
Tester *integration.Tester
)
func TestMain(m *testing.M) {
os.Exit(func() int {
ctx, _, cancel := integration.Contexts(5 * time.Minute)
defer cancel()
CTX = ctx
Tester = integration.NewTester(ctx)
defer Tester.Done()
SystemCTX = Tester.WithAuthorization(ctx, integration.SystemUser)
return m.Run()
}())
}

View File

@ -14,12 +14,12 @@ import (
"github.com/zitadel/zitadel/internal/api/grpc/server/middleware"
http_utils "github.com/zitadel/zitadel/internal/api/http"
"github.com/zitadel/zitadel/internal/logstore"
"github.com/zitadel/zitadel/internal/logstore/emitters/access"
"github.com/zitadel/zitadel/internal/logstore/record"
"github.com/zitadel/zitadel/internal/telemetry/tracing"
)
type AccessInterceptor struct {
svc *logstore.Service
svc *logstore.Service[*record.AccessLog]
cookieHandler *http_utils.CookieHandler
limitConfig *AccessConfig
storeOnly bool
@ -33,7 +33,7 @@ type AccessConfig struct {
// NewAccessInterceptor intercepts all requests and stores them to the logstore.
// If storeOnly is false, it also checks if requests are exhausted.
// If requests are exhausted, it also returns http.StatusTooManyRequests and sets a cookie
func NewAccessInterceptor(svc *logstore.Service, cookieHandler *http_utils.CookieHandler, cookieConfig *AccessConfig) *AccessInterceptor {
func NewAccessInterceptor(svc *logstore.Service[*record.AccessLog], cookieHandler *http_utils.CookieHandler, cookieConfig *AccessConfig) *AccessInterceptor {
return &AccessInterceptor{
svc: svc,
cookieHandler: cookieHandler,
@ -50,7 +50,7 @@ func (a *AccessInterceptor) WithoutLimiting() *AccessInterceptor {
}
}
func (a *AccessInterceptor) AccessService() *logstore.Service {
func (a *AccessInterceptor) AccessService() *logstore.Service[*record.AccessLog] {
return a.svc
}
@ -81,14 +81,32 @@ func (a *AccessInterceptor) DeleteExhaustedCookie(writer http.ResponseWriter) {
a.cookieHandler.DeleteCookie(writer, a.limitConfig.ExhaustedCookieKey)
}
func (a *AccessInterceptor) HandleIgnorePathPrefixes(ignoredPathPrefixes []string) func(next http.Handler) http.Handler {
return a.handle(ignoredPathPrefixes...)
}
func (a *AccessInterceptor) Handle(next http.Handler) http.Handler {
return a.handle()(next)
}
func (a *AccessInterceptor) handle(ignoredPathPrefixes ...string) func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
if !a.svc.Enabled() {
return next
}
return http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) {
ctx := request.Context()
tracingCtx, checkSpan := tracing.NewNamedSpan(ctx, "checkAccess")
tracingCtx, checkSpan := tracing.NewNamedSpan(ctx, "checkAccessQuota")
wrappedWriter := &statusRecorder{ResponseWriter: writer, status: 0}
for _, ignoredPathPrefix := range ignoredPathPrefixes {
if !strings.HasPrefix(request.RequestURI, ignoredPathPrefix) {
continue
}
checkSpan.End()
next.ServeHTTP(wrappedWriter, request)
a.writeLog(tracingCtx, wrappedWriter, writer, request, true)
return
}
limited := a.Limit(tracingCtx)
checkSpan.End()
if limited {
@ -101,17 +119,23 @@ func (a *AccessInterceptor) Handle(next http.Handler) http.Handler {
if !limited {
next.ServeHTTP(wrappedWriter, request)
}
tracingCtx, writeSpan := tracing.NewNamedSpan(tracingCtx, "writeAccess")
a.writeLog(tracingCtx, wrappedWriter, writer, request, a.storeOnly)
})
}
}
func (a *AccessInterceptor) writeLog(ctx context.Context, wrappedWriter *statusRecorder, writer http.ResponseWriter, request *http.Request, notCountable bool) {
ctx, writeSpan := tracing.NewNamedSpan(ctx, "writeAccess")
defer writeSpan.End()
requestURL := request.RequestURI
unescapedURL, err := url.QueryUnescape(requestURL)
if err != nil {
logging.WithError(err).WithField("url", requestURL).Warning("failed to unescape request url")
}
instance := authz.GetInstance(tracingCtx)
a.svc.Handle(tracingCtx, &access.Record{
instance := authz.GetInstance(ctx)
a.svc.Handle(ctx, &record.AccessLog{
LogDate: time.Now(),
Protocol: access.HTTP,
Protocol: record.HTTP,
RequestURL: unescapedURL,
ResponseStatus: uint32(wrappedWriter.status),
RequestHeaders: request.Header,
@ -120,7 +144,7 @@ func (a *AccessInterceptor) Handle(next http.Handler) http.Handler {
ProjectID: instance.ProjectID(),
RequestedDomain: instance.RequestedDomain(),
RequestedHost: instance.RequestedHost(),
})
NotCountable: notCountable,
})
}

View File

@ -7,6 +7,7 @@ import (
"time"
"github.com/rakyll/statik/fs"
"github.com/zitadel/oidc/v2/pkg/oidc"
"github.com/zitadel/oidc/v2/pkg/op"
"golang.org/x/text/language"
@ -79,13 +80,32 @@ type OPStorage struct {
assetAPIPrefix func(ctx context.Context) string
}
func NewProvider(config Config, defaultLogoutRedirectURI string, externalSecure bool, command *command.Commands, query *query.Queries, repo repository.Repository, encryptionAlg crypto.EncryptionAlgorithm, cryptoKey []byte, es *eventstore.Eventstore, projections *database.DB, userAgentCookie, instanceHandler, accessHandler func(http.Handler) http.Handler) (op.OpenIDProvider, error) {
func NewProvider(
config Config,
defaultLogoutRedirectURI string,
externalSecure bool,
command *command.Commands,
query *query.Queries,
repo repository.Repository,
encryptionAlg crypto.EncryptionAlgorithm,
cryptoKey []byte,
es *eventstore.Eventstore,
projections *database.DB,
userAgentCookie, instanceHandler func(http.Handler) http.Handler,
accessHandler *middleware.AccessInterceptor,
) (op.OpenIDProvider, error) {
opConfig, err := createOPConfig(config, defaultLogoutRedirectURI, cryptoKey)
if err != nil {
return nil, caos_errs.ThrowInternal(err, "OIDC-EGrqd", "cannot create op config: %w")
}
storage := newStorage(config, command, query, repo, encryptionAlg, es, projections, externalSecure)
options, err := createOptions(config, externalSecure, userAgentCookie, instanceHandler, accessHandler)
options, err := createOptions(
config,
externalSecure,
userAgentCookie,
instanceHandler,
accessHandler.HandleIgnorePathPrefixes(ignoredQuotaLimitEndpoint(config.CustomEndpoints)),
)
if err != nil {
return nil, caos_errs.ThrowInternal(err, "OIDC-D3gq1", "cannot create options: %w")
}
@ -101,6 +121,21 @@ func NewProvider(config Config, defaultLogoutRedirectURI string, externalSecure
return provider, nil
}
func ignoredQuotaLimitEndpoint(endpoints *EndpointConfig) []string {
authURL := op.DefaultEndpoints.Authorization.Relative()
keysURL := op.DefaultEndpoints.JwksURI.Relative()
if endpoints == nil {
return []string{oidc.DiscoveryEndpoint, authURL, keysURL}
}
if endpoints.Auth != nil && endpoints.Auth.Path != "" {
authURL = endpoints.Auth.Path
}
if endpoints.Keys != nil && endpoints.Keys.Path != "" {
keysURL = endpoints.Keys.Path
}
return []string{oidc.DiscoveryEndpoint, authURL, keysURL}
}
func createOPConfig(config Config, defaultLogoutRedirectURI string, cryptoKey []byte) (*op.Config, error) {
supportedLanguages, err := getSupportedLanguages()
if err != nil {

View File

@ -38,8 +38,8 @@ func NewProvider(
es *eventstore.Eventstore,
projections *database.DB,
instanceHandler,
userAgentCookie,
accessHandler func(http.Handler) http.Handler,
userAgentCookie func(http.Handler) http.Handler,
accessHandler *middleware.AccessInterceptor,
) (*provider.Provider, error) {
metricTypes := []metrics.MetricType{metrics.MetricTypeRequestCount, metrics.MetricTypeStatusCode, metrics.MetricTypeTotalCount}
@ -63,7 +63,7 @@ func NewProvider(
middleware.NoCacheInterceptor().Handler,
instanceHandler,
userAgentCookie,
accessHandler,
accessHandler.HandleIgnorePathPrefixes(ignoredQuotaLimitEndpoint(conf.ProviderConfig)),
http_utils.CopyHeadersToContext,
),
provider.WithCustomTimeFormat("2006-01-02T15:04:05.999Z"),
@ -100,3 +100,22 @@ func newStorage(
defaultLoginURL: fmt.Sprintf("%s%s?%s=", login.HandlerPrefix, login.EndpointLogin, login.QueryAuthRequestID),
}, nil
}
func ignoredQuotaLimitEndpoint(config *provider.Config) []string {
metadataEndpoint := HandlerPrefix + provider.DefaultMetadataEndpoint
certificateEndpoint := HandlerPrefix + provider.DefaultCertificateEndpoint
ssoEndpoint := HandlerPrefix + provider.DefaultSingleSignOnEndpoint
if config.MetadataConfig != nil && config.MetadataConfig.Path != "" {
metadataEndpoint = HandlerPrefix + config.MetadataConfig.Path
}
if config.IDPConfig == nil || config.IDPConfig.Endpoints == nil {
return []string{metadataEndpoint, certificateEndpoint, ssoEndpoint}
}
if config.IDPConfig.Endpoints.Certificate != nil && config.IDPConfig.Endpoints.Certificate.Relative() != "" {
certificateEndpoint = HandlerPrefix + config.IDPConfig.Endpoints.Certificate.Relative()
}
if config.IDPConfig.Endpoints.SingleSignOn != nil && config.IDPConfig.Endpoints.SingleSignOn.Relative() != "" {
ssoEndpoint = HandlerPrefix + config.IDPConfig.Endpoints.SingleSignOn.Relative()
}
return []string{metadataEndpoint, certificateEndpoint, ssoEndpoint}
}

View File

@ -283,10 +283,7 @@ func (c *Commands) SetUpInstance(ctx context.Context, setup *InstanceSetup) (str
if err != nil {
return "", "", nil, nil, err
}
quotaAggregate := quota.NewAggregate(quotaId, instanceID, instanceID)
validations = append(validations, c.AddQuotaCommand(quotaAggregate, q))
validations = append(validations, c.AddQuotaCommand(quota.NewAggregate(quotaId, instanceID), q))
}
}

View File

@ -26,6 +26,7 @@ import (
"github.com/zitadel/zitadel/internal/repository/oidcsession"
"github.com/zitadel/zitadel/internal/repository/org"
proj_repo "github.com/zitadel/zitadel/internal/repository/project"
quota_repo "github.com/zitadel/zitadel/internal/repository/quota"
"github.com/zitadel/zitadel/internal/repository/session"
usr_repo "github.com/zitadel/zitadel/internal/repository/user"
"github.com/zitadel/zitadel/internal/repository/usergrant"
@ -50,6 +51,7 @@ func eventstoreExpect(t *testing.T, expects ...expect) *eventstore.Eventstore {
idpintent.RegisterEventMappers(es)
authrequest.RegisterEventMappers(es)
oidcsession.RegisterEventMappers(es)
quota_repo.RegisterEventMappers(es)
return es
}

View File

@ -21,8 +21,8 @@ const (
QuotaActionsAllRunsSeconds QuotaUnit = "actions.all.runs.seconds"
)
func (q *QuotaUnit) Enum() quota.Unit {
switch *q {
func (q QuotaUnit) Enum() quota.Unit {
switch q {
case QuotaRequestsAllAuthenticated:
return quota.RequestsAllAuthenticated
case QuotaActionsAllRunsSeconds:
@ -46,14 +46,11 @@ func (c *Commands) AddQuota(
if wm.active {
return nil, errors.ThrowAlreadyExists(nil, "COMMAND-WDfFf", "Errors.Quota.AlreadyExists")
}
aggregateId, err := c.idGenerator.Next()
if err != nil {
return nil, err
}
aggregate := quota.NewAggregate(aggregateId, instanceId, instanceId)
cmds, err := preparation.PrepareCommands(ctx, c.eventstore.Filter, c.AddQuotaCommand(aggregate, q))
cmds, err := preparation.PrepareCommands(ctx, c.eventstore.Filter, c.AddQuotaCommand(quota.NewAggregate(aggregateId, instanceId), q))
if err != nil {
return nil, err
}
@ -80,7 +77,7 @@ func (c *Commands) RemoveQuota(ctx context.Context, unit QuotaUnit) (*domain.Obj
return nil, errors.ThrowNotFound(nil, "COMMAND-WDfFf", "Errors.Quota.NotFound")
}
aggregate := quota.NewAggregate(wm.AggregateID, instanceId, instanceId)
aggregate := quota.NewAggregate(wm.AggregateID, instanceId)
events := []eventstore.Command{
quota.NewRemovedEvent(ctx, &aggregate.Aggregate, unit.Enum()),
@ -109,6 +106,22 @@ type QuotaNotification struct {
type QuotaNotifications []*QuotaNotification
func (q *QuotaNotification) validate() error {
u, err := url.Parse(q.CallURL)
if err != nil {
return errors.ThrowInvalidArgument(err, "QUOTA-bZ0Fj", "Errors.Quota.Invalid.CallURL")
}
if !u.IsAbs() || u.Host == "" {
return errors.ThrowInvalidArgument(nil, "QUOTA-HAYmN", "Errors.Quota.Invalid.CallURL")
}
if q.Percent < 1 {
return errors.ThrowInvalidArgument(nil, "QUOTA-pBfjq", "Errors.Quota.Invalid.Percent")
}
return nil
}
func (q *QuotaNotifications) toAddedEventNotifications(idGenerator id.Generator) ([]*quota.AddedEventNotification, error) {
if q == nil {
return nil, nil
@ -144,17 +157,8 @@ type AddQuota struct {
func (q *AddQuota) validate() error {
for _, notification := range q.Notifications {
u, err := url.Parse(notification.CallURL)
if err != nil {
return errors.ThrowInvalidArgument(err, "QUOTA-bZ0Fj", "Errors.Quota.Invalid.CallURL")
}
if !u.IsAbs() || u.Host == "" {
return errors.ThrowInvalidArgument(nil, "QUOTA-HAYmN", "Errors.Quota.Invalid.CallURL")
}
if notification.Percent < 1 {
return errors.ThrowInvalidArgument(nil, "QUOTA-pBfjq", "Errors.Quota.Invalid.Percent")
if err := notification.validate(); err != nil {
return err
}
}
@ -169,11 +173,6 @@ func (q *AddQuota) validate() error {
if q.ResetInterval < time.Minute {
return errors.ThrowInvalidArgument(nil, "QUOTA-R5otd", "Errors.Quota.Invalid.ResetInterval")
}
if !q.Limit && len(q.Notifications) == 0 {
return errors.ThrowInvalidArgument(nil, "QUOTA-4Nv68", "Errors.Quota.Invalid.Noop")
}
return nil
}

View File

@ -41,6 +41,7 @@ func (wm *quotaWriteModel) Reduce() error {
switch e := event.(type) {
case *quota.AddedEvent:
wm.AggregateID = e.Aggregate().ID
wm.ChangeDate = e.CreationDate()
wm.active = true
case *quota.RemovedEvent:
wm.AggregateID = e.Aggregate().ID

View File

@ -5,16 +5,47 @@ import (
"github.com/zitadel/zitadel/internal/eventstore"
"github.com/zitadel/zitadel/internal/repository/quota"
"github.com/zitadel/zitadel/internal/telemetry/tracing"
)
// ReportQuotaUsage writes a slice of *quota.NotificationDueEvent directly to the eventstore
func (c *Commands) ReportQuotaUsage(ctx context.Context, dueNotifications []*quota.NotificationDueEvent) error {
cmds := make([]eventstore.Command, len(dueNotifications))
for idx, notification := range dueNotifications {
cmds[idx] = notification
func (c *Commands) ReportQuotaUsage(ctx context.Context, dueNotifications []*quota.NotificationDueEvent) (err error) {
ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }()
cmds := make([]eventstore.Command, 0, len(dueNotifications))
for _, notification := range dueNotifications {
ctxFilter, spanFilter := tracing.NewNamedSpan(ctx, "filterNotificationDueEvents")
events, errFilter := c.eventstore.Filter(
ctxFilter,
eventstore.NewSearchQueryBuilder(eventstore.ColumnsEvent).
InstanceID(notification.Aggregate().InstanceID).
AddQuery().
AggregateTypes(quota.AggregateType).
AggregateIDs(notification.Aggregate().ID).
EventTypes(quota.NotificationDueEventType).
EventData(map[string]interface{}{
"id": notification.ID,
"periodStart": notification.PeriodStart,
"threshold": notification.Threshold,
}).Builder(),
)
spanFilter.EndWithError(errFilter)
if errFilter != nil {
return errFilter
}
_, err := c.eventstore.Push(ctx, cmds...)
return err
if len(events) > 0 {
continue
}
cmds = append(cmds, notification)
}
if len(cmds) == 0 {
return nil
}
ctxPush, spanPush := tracing.NewNamedSpan(ctx, "pushNotificationDueEvents")
_, errPush := c.eventstore.Push(ctxPush, cmds...)
spanPush.EndWithError(errPush)
return errPush
}
func (c *Commands) UsageNotificationSent(ctx context.Context, dueEvent *quota.NotificationDueEvent) error {

View File

@ -0,0 +1,307 @@
package command
import (
"context"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/zitadel/zitadel/internal/api/authz"
"github.com/zitadel/zitadel/internal/eventstore"
"github.com/zitadel/zitadel/internal/eventstore/repository"
"github.com/zitadel/zitadel/internal/id"
id_mock "github.com/zitadel/zitadel/internal/id/mock"
"github.com/zitadel/zitadel/internal/repository/quota"
)
func TestQuotaReport_ReportQuotaUsage(t *testing.T) {
type fields struct {
eventstore *eventstore.Eventstore
}
type args struct {
ctx context.Context
dueNotifications []*quota.NotificationDueEvent
}
type res struct {
err func(error) bool
}
tests := []struct {
name string
fields fields
args args
res res
}{
{
name: "no due events",
fields: fields{
eventstore: eventstoreExpect(
t,
),
},
args: args{
ctx: authz.WithInstanceID(context.Background(), "INSTANCE"),
},
res: res{},
},
{
name: "due event already reported",
fields: fields{
eventstore: eventstoreExpect(
t,
expectFilter(
eventFromEventPusherWithInstanceID(
"INSTANCE",
quota.NewNotificationDueEvent(context.Background(),
&quota.NewAggregate("quota1", "INSTANCE").Aggregate,
QuotaRequestsAllAuthenticated.Enum(),
"id",
"url",
time.Date(2023, 9, 1, 0, 0, 0, 0, time.UTC),
1000,
200,
),
),
),
),
},
args: args{
ctx: authz.WithInstanceID(context.Background(), "INSTANCE"),
dueNotifications: []*quota.NotificationDueEvent{
{
Unit: QuotaRequestsAllAuthenticated.Enum(),
ID: "id",
CallURL: "url",
PeriodStart: time.Date(2023, 9, 1, 0, 0, 0, 0, time.UTC),
Threshold: 1000,
Usage: 250,
},
},
},
res: res{},
},
{
name: "due event not reported",
fields: fields{
eventstore: eventstoreExpect(
t,
expectFilter(),
expectPush(
[]*repository.Event{
eventFromEventPusherWithInstanceID(
"INSTANCE",
quota.NewNotificationDueEvent(context.Background(),
&quota.NewAggregate("quota1", "INSTANCE").Aggregate,
QuotaRequestsAllAuthenticated.Enum(),
"id",
"url",
time.Date(2023, 9, 1, 0, 0, 0, 0, time.UTC),
1000,
250,
),
),
},
),
),
},
args: args{
ctx: authz.WithInstanceID(context.Background(), "INSTANCE"),
dueNotifications: []*quota.NotificationDueEvent{
quota.NewNotificationDueEvent(
context.Background(),
&quota.NewAggregate("quota1", "INSTANCE").Aggregate,
QuotaRequestsAllAuthenticated.Enum(),
"id",
"url",
time.Date(2023, 9, 1, 0, 0, 0, 0, time.UTC),
1000,
250,
),
},
},
res: res{},
},
{
name: "due events",
fields: fields{
eventstore: eventstoreExpect(
t,
expectFilter(),
expectFilter(
eventFromEventPusherWithInstanceID(
"INSTANCE",
quota.NewNotificationDueEvent(context.Background(),
&quota.NewAggregate("quota1", "INSTANCE").Aggregate,
QuotaRequestsAllAuthenticated.Enum(),
"id2",
"url",
time.Date(2023, 9, 1, 0, 0, 0, 0, time.UTC),
1000,
250,
),
),
),
expectFilter(),
expectPush(
[]*repository.Event{
eventFromEventPusherWithInstanceID(
"INSTANCE",
quota.NewNotificationDueEvent(context.Background(),
&quota.NewAggregate("quota1", "INSTANCE").Aggregate,
QuotaRequestsAllAuthenticated.Enum(),
"id1",
"url",
time.Date(2023, 9, 1, 0, 0, 0, 0, time.UTC),
1000,
250,
),
),
eventFromEventPusherWithInstanceID(
"INSTANCE",
quota.NewNotificationDueEvent(context.Background(),
&quota.NewAggregate("quota1", "INSTANCE").Aggregate,
QuotaRequestsAllAuthenticated.Enum(),
"id3",
"url",
time.Date(2023, 9, 1, 0, 0, 0, 0, time.UTC),
1000,
250,
),
),
},
),
),
},
args: args{
ctx: authz.WithInstanceID(context.Background(), "INSTANCE"),
dueNotifications: []*quota.NotificationDueEvent{
quota.NewNotificationDueEvent(
context.Background(),
&quota.NewAggregate("quota1", "INSTANCE").Aggregate,
QuotaRequestsAllAuthenticated.Enum(),
"id1",
"url",
time.Date(2023, 9, 1, 0, 0, 0, 0, time.UTC),
1000,
250,
),
quota.NewNotificationDueEvent(
context.Background(),
&quota.NewAggregate("quota1", "INSTANCE").Aggregate,
QuotaRequestsAllAuthenticated.Enum(),
"id2",
"url",
time.Date(2023, 9, 1, 0, 0, 0, 0, time.UTC),
1000,
250,
),
quota.NewNotificationDueEvent(
context.Background(),
&quota.NewAggregate("quota1", "INSTANCE").Aggregate,
QuotaRequestsAllAuthenticated.Enum(),
"id3",
"url",
time.Date(2023, 9, 1, 0, 0, 0, 0, time.UTC),
1000,
250,
),
},
},
res: res{},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
r := &Commands{
eventstore: tt.fields.eventstore,
}
err := r.ReportQuotaUsage(tt.args.ctx, tt.args.dueNotifications)
if tt.res.err == nil {
assert.NoError(t, err)
}
if tt.res.err != nil && !tt.res.err(err) {
t.Errorf("got wrong err: %v ", err)
}
})
}
}
func TestQuotaReport_UsageNotificationSent(t *testing.T) {
type fields struct {
eventstore *eventstore.Eventstore
idGenerator id.Generator
}
type args struct {
ctx context.Context
dueNotification *quota.NotificationDueEvent
}
type res struct {
err func(error) bool
}
tests := []struct {
name string
fields fields
args args
res res
}{
{
name: "usage notification sent, ok",
fields: fields{
eventstore: eventstoreExpect(
t,
expectPush(
[]*repository.Event{
eventFromEventPusherWithInstanceID(
"INSTANCE",
quota.NewNotifiedEvent(
context.Background(),
"quota1",
quota.NewNotificationDueEvent(
context.Background(),
&quota.NewAggregate("quota1", "INSTANCE").Aggregate,
QuotaRequestsAllAuthenticated.Enum(),
"id1",
"url",
time.Date(2023, 9, 1, 0, 0, 0, 0, time.UTC),
1000,
250,
),
),
),
},
),
),
idGenerator: id_mock.NewIDGeneratorExpectIDs(t, "quota1"),
},
args: args{
ctx: authz.WithInstanceID(context.Background(), "INSTANCE"),
dueNotification: quota.NewNotificationDueEvent(
context.Background(),
&quota.NewAggregate("quota1", "INSTANCE").Aggregate,
QuotaRequestsAllAuthenticated.Enum(),
"id1",
"url",
time.Date(2023, 9, 1, 0, 0, 0, 0, time.UTC),
1000,
250,
),
},
res: res{},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
r := &Commands{
eventstore: tt.fields.eventstore,
idGenerator: tt.fields.idGenerator,
}
err := r.UsageNotificationSent(tt.args.ctx, tt.args.dueNotification)
if tt.res.err == nil {
assert.NoError(t, err)
}
if tt.res.err != nil && !tt.res.err(err) {
t.Errorf("got wrong err: %v ", err)
}
})
}
}

View File

@ -0,0 +1,638 @@
package command
import (
"context"
"errors"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/zitadel/zitadel/internal/api/authz"
"github.com/zitadel/zitadel/internal/domain"
caos_errors "github.com/zitadel/zitadel/internal/errors"
"github.com/zitadel/zitadel/internal/eventstore"
"github.com/zitadel/zitadel/internal/eventstore/repository"
"github.com/zitadel/zitadel/internal/id"
id_mock "github.com/zitadel/zitadel/internal/id/mock"
"github.com/zitadel/zitadel/internal/repository/quota"
)
func TestQuota_AddQuota(t *testing.T) {
type fields struct {
eventstore *eventstore.Eventstore
idGenerator id.Generator
}
type args struct {
ctx context.Context
addQuota *AddQuota
}
type res struct {
want *domain.ObjectDetails
err func(error) bool
}
tests := []struct {
name string
fields fields
args args
res res
}{
{
name: "already existing",
fields: fields{
eventstore: eventstoreExpect(
t,
expectFilter(
eventFromEventPusher(
quota.NewAddedEvent(context.Background(),
&quota.NewAggregate("quota1", "INSTANCE").Aggregate,
QuotaRequestsAllAuthenticated.Enum(),
time.Now(),
30*24*time.Hour,
1000,
false,
nil,
),
),
),
),
},
args: args{
ctx: authz.WithInstanceID(context.Background(), "INSTANCE"),
addQuota: &AddQuota{
Unit: QuotaRequestsAllAuthenticated,
From: time.Time{},
ResetInterval: 0,
Amount: 0,
Limit: false,
Notifications: nil,
},
},
res: res{
err: caos_errors.IsErrorAlreadyExists,
},
},
{
name: "create quota, validation fail",
fields: fields{
eventstore: eventstoreExpect(
t,
expectFilter(),
),
idGenerator: id_mock.NewIDGeneratorExpectIDs(t, "quota1"),
},
args: args{
ctx: authz.WithInstanceID(context.Background(), "INSTANCE"),
addQuota: &AddQuota{
Unit: "unimplemented",
From: time.Time{},
ResetInterval: 0,
Amount: 0,
Limit: false,
Notifications: nil,
},
},
res: res{
err: func(err error) bool {
return errors.Is(err, caos_errors.ThrowInvalidArgument(nil, "QUOTA-OTeSh", ""))
},
},
},
{
name: "create quota, ok",
fields: fields{
eventstore: eventstoreExpect(
t,
expectFilter(),
expectPush(
[]*repository.Event{
eventFromEventPusherWithInstanceID(
"INSTANCE",
quota.NewAddedEvent(context.Background(),
&quota.NewAggregate("quota1", "INSTANCE").Aggregate,
QuotaRequestsAllAuthenticated.Enum(),
time.Date(2023, 9, 1, 0, 0, 0, 0, time.UTC),
30*24*time.Hour,
1000,
true,
nil,
),
),
},
uniqueConstraintsFromEventConstraintWithInstanceID("INSTANCE", quota.NewAddQuotaUnitUniqueConstraint(quota.RequestsAllAuthenticated)),
),
),
idGenerator: id_mock.NewIDGeneratorExpectIDs(t, "quota1"),
},
args: args{
ctx: authz.WithInstanceID(context.Background(), "INSTANCE"),
addQuota: &AddQuota{
Unit: QuotaRequestsAllAuthenticated,
From: time.Date(2023, 9, 1, 0, 0, 0, 0, time.UTC),
ResetInterval: 30 * 24 * time.Hour,
Amount: 1000,
Limit: true,
Notifications: nil,
},
},
res: res{
want: &domain.ObjectDetails{
ResourceOwner: "INSTANCE",
},
},
},
{
name: "removed, ok",
fields: fields{
eventstore: eventstoreExpect(
t,
expectFilter(
eventFromEventPusherWithInstanceID(
"INSTANCE",
quota.NewAddedEvent(context.Background(),
&quota.NewAggregate("quota1", "INSTANCE").Aggregate,
QuotaRequestsAllAuthenticated.Enum(),
time.Now(),
30*24*time.Hour,
1000,
true,
nil,
),
),
eventFromEventPusherWithInstanceID(
"INSTANCE",
quota.NewRemovedEvent(context.Background(),
&quota.NewAggregate("quota1", "INSTANCE").Aggregate,
QuotaRequestsAllAuthenticated.Enum(),
),
),
),
expectPush(
[]*repository.Event{
eventFromEventPusherWithInstanceID(
"INSTANCE",
quota.NewAddedEvent(context.Background(),
&quota.NewAggregate("quota1", "INSTANCE").Aggregate,
QuotaRequestsAllAuthenticated.Enum(),
time.Date(2023, 9, 1, 0, 0, 0, 0, time.UTC),
30*24*time.Hour,
1000,
true,
nil,
),
),
},
uniqueConstraintsFromEventConstraintWithInstanceID("INSTANCE", quota.NewAddQuotaUnitUniqueConstraint(quota.RequestsAllAuthenticated)),
),
),
idGenerator: id_mock.NewIDGeneratorExpectIDs(t, "quota1"),
},
args: args{
ctx: authz.WithInstanceID(context.Background(), "INSTANCE"),
addQuota: &AddQuota{
Unit: QuotaRequestsAllAuthenticated,
From: time.Date(2023, 9, 1, 0, 0, 0, 0, time.UTC),
ResetInterval: 30 * 24 * time.Hour,
Amount: 1000,
Limit: true,
Notifications: nil,
},
},
res: res{
want: &domain.ObjectDetails{
ResourceOwner: "INSTANCE",
},
},
},
{
name: "create quota with notifications, ok",
fields: fields{
eventstore: eventstoreExpect(
t,
expectFilter(),
expectPush(
[]*repository.Event{
eventFromEventPusherWithInstanceID(
"INSTANCE",
quota.NewAddedEvent(context.Background(),
&quota.NewAggregate("quota1", "INSTANCE").Aggregate,
QuotaRequestsAllAuthenticated.Enum(),
time.Date(2023, 9, 1, 0, 0, 0, 0, time.UTC),
30*24*time.Hour,
1000,
true,
[]*quota.AddedEventNotification{
{
ID: "notification1",
Percent: 20,
Repeat: false,
CallURL: "https://url.com",
},
},
),
),
},
uniqueConstraintsFromEventConstraintWithInstanceID("INSTANCE", quota.NewAddQuotaUnitUniqueConstraint(quota.RequestsAllAuthenticated)),
),
),
idGenerator: id_mock.NewIDGeneratorExpectIDs(t, "quota1", "notification1"),
},
args: args{
ctx: authz.WithInstanceID(context.Background(), "INSTANCE"),
addQuota: &AddQuota{
Unit: QuotaRequestsAllAuthenticated,
From: time.Date(2023, 9, 1, 0, 0, 0, 0, time.UTC),
ResetInterval: 30 * 24 * time.Hour,
Amount: 1000,
Limit: true,
Notifications: QuotaNotifications{
{
Percent: 20,
Repeat: false,
CallURL: "https://url.com",
},
},
},
},
res: res{
want: &domain.ObjectDetails{
ResourceOwner: "INSTANCE",
},
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
r := &Commands{
eventstore: tt.fields.eventstore,
idGenerator: tt.fields.idGenerator,
}
got, err := r.AddQuota(tt.args.ctx, tt.args.addQuota)
if tt.res.err == nil {
assert.NoError(t, err)
}
if tt.res.err != nil && !tt.res.err(err) {
t.Errorf("got wrong err: %v ", err)
}
if tt.res.err == nil {
assert.Equal(t, tt.res.want, got)
}
})
}
}
func TestQuota_RemoveQuota(t *testing.T) {
type fields struct {
eventstore *eventstore.Eventstore
}
type args struct {
ctx context.Context
unit QuotaUnit
}
type res struct {
want *domain.ObjectDetails
err func(error) bool
}
tests := []struct {
name string
fields fields
args args
res res
}{
{
name: "not found",
fields: fields{
eventstore: eventstoreExpect(
t,
expectFilter(),
),
},
args: args{
ctx: authz.WithInstanceID(context.Background(), "INSTANCE"),
unit: QuotaRequestsAllAuthenticated,
},
res: res{
err: func(err error) bool {
return errors.Is(err, caos_errors.ThrowNotFound(nil, "COMMAND-WDfFf", ""))
},
},
},
{
name: "already removed",
fields: fields{
eventstore: eventstoreExpect(
t,
expectFilter(
eventFromEventPusherWithInstanceID(
"INSTANCE",
quota.NewAddedEvent(context.Background(),
&quota.NewAggregate("quota1", "INSTANCE").Aggregate,
QuotaRequestsAllAuthenticated.Enum(),
time.Now(),
30*24*time.Hour,
1000,
true,
nil,
),
),
eventFromEventPusherWithInstanceID(
"INSTANCE",
quota.NewRemovedEvent(context.Background(),
&quota.NewAggregate("quota1", "INSTANCE").Aggregate,
QuotaRequestsAllAuthenticated.Enum(),
),
),
),
),
},
args: args{
ctx: authz.WithInstanceID(context.Background(), "INSTANCE"),
unit: QuotaRequestsAllAuthenticated,
},
res: res{
err: func(err error) bool {
return errors.Is(err, caos_errors.ThrowNotFound(nil, "COMMAND-WDfFf", ""))
},
},
},
{
name: "remove quota, ok",
fields: fields{
eventstore: eventstoreExpect(
t,
expectFilter(
eventFromEventPusherWithInstanceID(
"INSTANCE",
quota.NewAddedEvent(context.Background(),
&quota.NewAggregate("quota1", "INSTANCE").Aggregate,
QuotaRequestsAllAuthenticated.Enum(),
time.Now(),
30*24*time.Hour,
1000,
false,
nil,
),
),
),
expectPush(
[]*repository.Event{
eventFromEventPusherWithInstanceID(
"INSTANCE",
quota.NewRemovedEvent(context.Background(),
&quota.NewAggregate("quota1", "INSTANCE").Aggregate,
QuotaRequestsAllAuthenticated.Enum(),
),
),
},
uniqueConstraintsFromEventConstraintWithInstanceID("INSTANCE", quota.NewRemoveQuotaNameUniqueConstraint(quota.RequestsAllAuthenticated)),
),
),
},
args: args{
ctx: authz.WithInstanceID(context.Background(), "INSTANCE"),
unit: QuotaRequestsAllAuthenticated,
},
res: res{
want: &domain.ObjectDetails{
ResourceOwner: "INSTANCE",
},
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
r := &Commands{
eventstore: tt.fields.eventstore,
}
got, err := r.RemoveQuota(tt.args.ctx, tt.args.unit)
if tt.res.err == nil {
assert.NoError(t, err)
}
if tt.res.err != nil && !tt.res.err(err) {
t.Errorf("got wrong err: %v ", err)
}
if tt.res.err == nil {
assert.Equal(t, tt.res.want, got)
}
})
}
}
func TestQuota_QuotaNotification_validate(t *testing.T) {
type args struct {
quotaNotification *QuotaNotification
}
type res struct {
err func(error) bool
}
tests := []struct {
name string
args args
res res
}{
{
name: "notification url parse failed",
args: args{
quotaNotification: &QuotaNotification{
Percent: 20,
Repeat: false,
CallURL: "%",
},
},
res: res{
err: func(err error) bool {
return errors.Is(err, caos_errors.ThrowInvalidArgument(nil, "QUOTA-bZ0Fj", ""))
},
},
},
{
name: "notification url parse empty schema",
args: args{
quotaNotification: &QuotaNotification{
Percent: 20,
Repeat: false,
CallURL: "localhost:8080",
},
},
res: res{
err: func(err error) bool {
return errors.Is(err, caos_errors.ThrowInvalidArgument(nil, "QUOTA-HAYmN", ""))
},
},
},
{
name: "notification url parse empty host",
args: args{
quotaNotification: &QuotaNotification{
Percent: 20,
Repeat: false,
CallURL: "https://",
},
},
res: res{
err: func(err error) bool {
return errors.Is(err, caos_errors.ThrowInvalidArgument(nil, "QUOTA-HAYmN", ""))
},
},
},
{
name: "notification url parse percent 0",
args: args{
quotaNotification: &QuotaNotification{
Percent: 0,
Repeat: false,
CallURL: "https://localhost:8080",
},
},
res: res{
err: func(err error) bool {
return errors.Is(err, caos_errors.ThrowInvalidArgument(nil, "QUOTA-pBfjq", ""))
},
},
},
{
name: "notification, ok",
args: args{
quotaNotification: &QuotaNotification{
Percent: 20,
Repeat: false,
CallURL: "https://localhost:8080",
},
},
res: res{
err: nil,
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := tt.args.quotaNotification.validate()
if tt.res.err == nil {
assert.NoError(t, err)
}
if tt.res.err != nil && !tt.res.err(err) {
t.Errorf("got wrong err: %v ", err)
}
})
}
}
func TestQuota_AddQuota_validate(t *testing.T) {
type args struct {
addQuota *AddQuota
}
type res struct {
err func(error) bool
}
tests := []struct {
name string
args args
res res
}{
{
name: "notification url parse failed",
args: args{
addQuota: &AddQuota{
Unit: QuotaRequestsAllAuthenticated,
From: time.Now(),
ResetInterval: time.Minute * 10,
Amount: 100,
Limit: true,
Notifications: QuotaNotifications{
{
Percent: 20,
Repeat: false,
CallURL: "%",
},
},
},
},
res: res{
err: func(err error) bool {
return errors.Is(err, caos_errors.ThrowInvalidArgument(nil, "QUOTA-bZ0Fj", ""))
},
},
},
{
name: "unit unimplemented",
args: args{
addQuota: &AddQuota{
Unit: "unimplemented",
From: time.Now(),
ResetInterval: time.Minute * 10,
Amount: 100,
Limit: true,
Notifications: nil,
},
},
res: res{
err: func(err error) bool {
return errors.Is(err, caos_errors.ThrowInvalidArgument(nil, "QUOTA-OTeSh", ""))
},
},
},
{
name: "amount 0",
args: args{
addQuota: &AddQuota{
Unit: QuotaRequestsAllAuthenticated,
From: time.Now(),
ResetInterval: time.Minute * 10,
Amount: 0,
Limit: true,
Notifications: nil,
},
},
res: res{
err: func(err error) bool {
return errors.Is(err, caos_errors.ThrowInvalidArgument(nil, "QUOTA-hOKSJ", ""))
},
},
},
{
name: "reset interval under 1 min",
args: args{
addQuota: &AddQuota{
Unit: QuotaRequestsAllAuthenticated,
From: time.Now(),
ResetInterval: time.Second * 10,
Amount: 100,
Limit: true,
Notifications: nil,
},
},
res: res{
err: func(err error) bool {
return errors.Is(err, caos_errors.ThrowInvalidArgument(nil, "QUOTA-R5otd", ""))
},
},
},
{
name: "validate, ok",
args: args{
addQuota: &AddQuota{
Unit: QuotaRequestsAllAuthenticated,
From: time.Now(),
ResetInterval: time.Minute * 10,
Amount: 100,
Limit: false,
Notifications: nil,
},
},
res: res{
err: nil,
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := tt.args.addQuota.validate()
if tt.res.err == nil {
assert.NoError(t, err)
}
if tt.res.err != nil && !tt.res.err(err) {
t.Errorf("got wrong err: %v ", err)
}
})
}
}

View File

@ -3,6 +3,7 @@ package database
import (
"database/sql/driver"
"encoding/json"
"time"
"github.com/jackc/pgtype"
)
@ -93,3 +94,15 @@ func (m Map[V]) Value() (driver.Value, error) {
}
return json.Marshal(m)
}
type Duration time.Duration
// Scan implements the [database/sql.Scanner] interface.
func (d *Duration) Scan(src any) error {
interval := new(pgtype.Interval)
if err := interval.Scan(src); err != nil {
return err
}
*d = Duration(time.Duration(interval.Microseconds*1000) + time.Duration(interval.Days)*24*time.Hour + time.Duration(interval.Months)*30*24*time.Hour)
return nil
}

View File

@ -24,8 +24,8 @@ type DetailsMsg interface {
//
// The resource owner is compared with expected and is
// therefore the only value that has to be set.
func AssertDetails[D DetailsMsg](t testing.TB, exptected, actual D) {
wantDetails, gotDetails := exptected.GetDetails(), actual.GetDetails()
func AssertDetails[D DetailsMsg](t testing.TB, expected, actual D) {
wantDetails, gotDetails := expected.GetDetails(), actual.GetDetails()
if wantDetails == nil {
assert.Nil(t, gotDetails)
return

View File

@ -22,22 +22,17 @@ FirstInstance:
PasswordChangeRequired: false
LogStore:
Access:
Database:
Enabled: true
Debounce:
MinFrequency: 0s
MaxBulkSize: 0
Execution:
Database:
Enabled: true
Stdout:
Enabled: true
Quotas:
Access:
Enabled: true
ExhaustedCookieKey: "zitadel.quota.limiting"
ExhaustedCookieMaxAge: "60s"
Execution:
Enabled: true
Projections:
Customizations:

View File

@ -8,7 +8,11 @@ import (
_ "embed"
"errors"
"fmt"
"io"
"net/http"
"net/http/httptest"
"os"
"reflect"
"strings"
"sync"
"time"
@ -30,6 +34,7 @@ import (
"github.com/zitadel/zitadel/internal/domain"
caos_errs "github.com/zitadel/zitadel/internal/errors"
"github.com/zitadel/zitadel/internal/eventstore/v1/models"
"github.com/zitadel/zitadel/internal/net"
"github.com/zitadel/zitadel/internal/query"
"github.com/zitadel/zitadel/internal/webauthn"
"github.com/zitadel/zitadel/pkg/grpc/admin"
@ -71,6 +76,11 @@ const (
UserPassword = "VeryS3cret!"
)
const (
PortMilestoneServer = "8081"
PortQuotaServer = "8082"
)
// User information with a Personal Access Token.
type User struct {
*query.User
@ -101,6 +111,11 @@ type Tester struct {
Organisation *query.Org
Users InstanceUserMap
MilestoneChan chan []byte
milestoneServer *httptest.Server
QuotaNotificationChan chan []byte
quotaNotificationServer *httptest.Server
Client Client
WebAuthN *webauthn.Client
wg sync.WaitGroup // used for shutdown
@ -271,6 +286,8 @@ func (s *Tester) Done() {
s.Shutdown <- os.Interrupt
s.wg.Wait()
s.milestoneServer.Close()
s.quotaNotificationServer.Close()
}
// NewTester start a new Zitadel server by passing the default commandline.
@ -279,13 +296,13 @@ func (s *Tester) Done() {
// INTEGRATION_DB_FLAVOR environment variable and can have the values "cockroach"
// or "postgres". Defaults to "cockroach".
//
// The deault Instance and Organisation are read from the DB and system
// The default Instance and Organisation are read from the DB and system
// users are created as needed.
//
// After the server is started, a [grpc.ClientConn] will be created and
// the server is polled for it's health status.
//
// Note: the database must already be setup and intialized before
// Note: the database must already be setup and initialized before
// using NewTester. See the CONTRIBUTING.md document for details.
func NewTester(ctx context.Context) *Tester {
args := strings.Split(commandLine, " ")
@ -311,6 +328,13 @@ func NewTester(ctx context.Context) *Tester {
tester := Tester{
Users: make(InstanceUserMap),
}
tester.MilestoneChan = make(chan []byte, 100)
tester.milestoneServer, err = runMilestoneServer(ctx, tester.MilestoneChan)
logging.OnError(err).Fatal()
tester.QuotaNotificationChan = make(chan []byte, 100)
tester.quotaNotificationServer, err = runQuotaServer(ctx, tester.QuotaNotificationChan)
logging.OnError(err).Fatal()
tester.wg.Add(1)
go func(wg *sync.WaitGroup) {
logging.OnError(cmd.Execute()).Fatal()
@ -328,7 +352,6 @@ func NewTester(ctx context.Context) *Tester {
tester.createMachineUserOrgOwner(ctx)
tester.createMachineUserInstanceOwner(ctx)
tester.WebAuthN = webauthn.NewClient(tester.Config.WebAuthNName, tester.Config.ExternalDomain, "https://"+tester.Host())
return &tester
}
@ -338,3 +361,51 @@ func Contexts(timeout time.Duration) (ctx, errCtx context.Context, cancel contex
ctx, cancel = context.WithTimeout(context.Background(), timeout)
return ctx, errCtx, cancel
}
func runMilestoneServer(ctx context.Context, bodies chan []byte) (*httptest.Server, error) {
mockServer := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
body, err := io.ReadAll(r.Body)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
if r.Header.Get("single-value") != "single-value" {
http.Error(w, "single-value header not set", http.StatusInternalServerError)
return
}
if reflect.DeepEqual(r.Header.Get("multi-value"), "multi-value-1,multi-value-2") {
http.Error(w, "single-value header not set", http.StatusInternalServerError)
return
}
bodies <- body
w.WriteHeader(http.StatusOK)
}))
config := net.ListenConfig()
listener, err := config.Listen(ctx, "tcp", ":"+PortMilestoneServer)
if err != nil {
return nil, err
}
mockServer.Listener = listener
mockServer.Start()
return mockServer, nil
}
func runQuotaServer(ctx context.Context, bodies chan []byte) (*httptest.Server, error) {
mockServer := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
body, err := io.ReadAll(r.Body)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
bodies <- body
w.WriteHeader(http.StatusOK)
}))
config := net.ListenConfig()
listener, err := config.Listen(ctx, "tcp", ":"+PortQuotaServer)
if err != nil {
return nil, err
}
mockServer.Listener = listener
mockServer.Start()
return mockServer, nil
}

View File

@ -6,6 +6,9 @@ type Configs struct {
}
type Config struct {
Database *EmitterConfig
Stdout *EmitterConfig
Stdout *StdConfig
}
type StdConfig struct {
Enabled bool
}

View File

@ -9,19 +9,17 @@ import (
"github.com/zitadel/logging"
)
type bulkSink interface {
sendBulk(ctx context.Context, bulk []LogRecord) error
type bulkSink[T LogRecord[T]] interface {
SendBulk(ctx context.Context, bulk []T) error
}
var _ bulkSink = bulkSinkFunc(nil)
type bulkSinkFunc[T LogRecord[T]] func(ctx context.Context, bulk []T) error
type bulkSinkFunc func(ctx context.Context, items []LogRecord) error
func (s bulkSinkFunc) sendBulk(ctx context.Context, items []LogRecord) error {
return s(ctx, items)
func (s bulkSinkFunc[T]) SendBulk(ctx context.Context, bulk []T) error {
return s(ctx, bulk)
}
type debouncer struct {
type debouncer[T LogRecord[T]] struct {
// Storing context.Context in a struct is generally bad practice
// https://go.dev/blog/context-and-structs
// However, debouncer starts a go routine that triggers side effects itself.
@ -33,8 +31,8 @@ type debouncer struct {
ticker *clock.Ticker
mux sync.Mutex
cfg DebouncerConfig
storage bulkSink
cache []LogRecord
storage bulkSink[T]
cache []T
cacheLen uint
}
@ -43,8 +41,8 @@ type DebouncerConfig struct {
MaxBulkSize uint
}
func newDebouncer(binarySignaledCtx context.Context, cfg DebouncerConfig, clock clock.Clock, ship bulkSink) *debouncer {
a := &debouncer{
func newDebouncer[T LogRecord[T]](binarySignaledCtx context.Context, cfg DebouncerConfig, clock clock.Clock, ship bulkSink[T]) *debouncer[T] {
a := &debouncer[T]{
binarySignaledCtx: binarySignaledCtx,
clock: clock,
cfg: cfg,
@ -58,7 +56,7 @@ func newDebouncer(binarySignaledCtx context.Context, cfg DebouncerConfig, clock
return a
}
func (d *debouncer) add(item LogRecord) {
func (d *debouncer[T]) add(item T) {
d.mux.Lock()
defer d.mux.Unlock()
d.cache = append(d.cache, item)
@ -69,13 +67,13 @@ func (d *debouncer) add(item LogRecord) {
}
}
func (d *debouncer) ship() {
func (d *debouncer[T]) ship() {
if d.cacheLen == 0 {
return
}
d.mux.Lock()
defer d.mux.Unlock()
if err := d.storage.sendBulk(d.binarySignaledCtx, d.cache); err != nil {
if err := d.storage.SendBulk(d.binarySignaledCtx, d.cache); err != nil {
logging.WithError(err).WithField("size", len(d.cache)).Error("storing bulk failed")
}
d.cache = nil
@ -85,7 +83,7 @@ func (d *debouncer) ship() {
}
}
func (d *debouncer) shipOnTicks() {
func (d *debouncer[T]) shipOnTicks() {
for range d.ticker.C {
d.ship()
}

View File

@ -2,56 +2,52 @@ package logstore
import (
"context"
"fmt"
"time"
"github.com/benbjohnson/clock"
"github.com/zitadel/logging"
)
type EmitterConfig struct {
Enabled bool
Keep time.Duration
CleanupInterval time.Duration
Debounce *DebouncerConfig
}
type emitter struct {
type emitter[T LogRecord[T]] struct {
enabled bool
ctx context.Context
debouncer *debouncer
emitter LogEmitter
debouncer *debouncer[T]
emitter LogEmitter[T]
clock clock.Clock
}
type LogRecord interface {
Normalize() LogRecord
type LogRecord[T any] interface {
Normalize() T
}
type LogRecordFunc func() LogRecord
type LogRecordFunc[T any] func() T
func (r LogRecordFunc) Normalize() LogRecord {
func (r LogRecordFunc[T]) Normalize() T {
return r()
}
type LogEmitter interface {
Emit(ctx context.Context, bulk []LogRecord) error
type LogEmitter[T LogRecord[T]] interface {
Emit(ctx context.Context, bulk []T) error
}
type LogEmitterFunc func(ctx context.Context, bulk []LogRecord) error
type LogEmitterFunc[T LogRecord[T]] func(ctx context.Context, bulk []T) error
func (l LogEmitterFunc) Emit(ctx context.Context, bulk []LogRecord) error {
func (l LogEmitterFunc[T]) Emit(ctx context.Context, bulk []T) error {
return l(ctx, bulk)
}
type LogCleanupper interface {
LogEmitter
type LogCleanupper[T LogRecord[T]] interface {
Cleanup(ctx context.Context, keep time.Duration) error
LogEmitter[T]
}
// NewEmitter accepts Clock from github.com/benbjohnson/clock so we can control timers and tickers in the unit tests
func NewEmitter(ctx context.Context, clock clock.Clock, cfg *EmitterConfig, logger LogEmitter) (*emitter, error) {
svc := &emitter{
func NewEmitter[T LogRecord[T]](ctx context.Context, clock clock.Clock, cfg *EmitterConfig, logger LogEmitter[T]) (*emitter[T], error) {
svc := &emitter[T]{
enabled: cfg != nil && cfg.Enabled,
ctx: ctx,
emitter: logger,
@ -63,36 +59,12 @@ func NewEmitter(ctx context.Context, clock clock.Clock, cfg *EmitterConfig, logg
}
if cfg.Debounce != nil && (cfg.Debounce.MinFrequency > 0 || cfg.Debounce.MaxBulkSize > 0) {
svc.debouncer = newDebouncer(ctx, *cfg.Debounce, clock, newStorageBulkSink(svc.emitter))
}
cleanupper, ok := logger.(LogCleanupper)
if !ok {
if cfg.Keep != 0 {
return nil, fmt.Errorf("cleaning up for this storage type is not supported, so keep duration must be 0, but is %d", cfg.Keep)
}
if cfg.CleanupInterval != 0 {
return nil, fmt.Errorf("cleaning up for this storage type is not supported, so cleanup interval duration must be 0, but is %d", cfg.Keep)
}
return svc, nil
}
if cfg.Keep != 0 && cfg.CleanupInterval != 0 {
go svc.startCleanupping(cleanupper, cfg.CleanupInterval, cfg.Keep)
svc.debouncer = newDebouncer[T](ctx, *cfg.Debounce, clock, newStorageBulkSink(svc.emitter))
}
return svc, nil
}
func (s *emitter) startCleanupping(cleanupper LogCleanupper, cleanupInterval, keep time.Duration) {
for range s.clock.Tick(cleanupInterval) {
if err := cleanupper.Cleanup(s.ctx, keep); err != nil {
logging.WithError(err).Error("cleaning up logs failed")
}
}
}
func (s *emitter) Emit(ctx context.Context, record LogRecord) (err error) {
func (s *emitter[T]) Emit(ctx context.Context, record T) (err error) {
if !s.enabled {
return nil
}
@ -102,11 +74,11 @@ func (s *emitter) Emit(ctx context.Context, record LogRecord) (err error) {
return nil
}
return s.emitter.Emit(ctx, []LogRecord{record})
return s.emitter.Emit(ctx, []T{record})
}
func newStorageBulkSink(emitter LogEmitter) bulkSinkFunc {
return func(ctx context.Context, bulk []LogRecord) error {
func newStorageBulkSink[T LogRecord[T]](emitter LogEmitter[T]) bulkSinkFunc[T] {
return func(ctx context.Context, bulk []T) error {
return emitter.Emit(ctx, bulk)
}
}

View File

@ -3,167 +3,91 @@ package access
import (
"context"
"database/sql"
"fmt"
"net/http"
"strings"
"errors"
"time"
"github.com/Masterminds/squirrel"
"github.com/zitadel/logging"
"google.golang.org/grpc/codes"
"github.com/zitadel/zitadel/internal/api/call"
zitadel_http "github.com/zitadel/zitadel/internal/api/http"
"github.com/zitadel/zitadel/internal/api/authz"
"github.com/zitadel/zitadel/internal/command"
"github.com/zitadel/zitadel/internal/database"
caos_errors "github.com/zitadel/zitadel/internal/errors"
"github.com/zitadel/zitadel/internal/logstore"
"github.com/zitadel/zitadel/internal/logstore/record"
"github.com/zitadel/zitadel/internal/query"
"github.com/zitadel/zitadel/internal/query/projection"
"github.com/zitadel/zitadel/internal/repository/quota"
"github.com/zitadel/zitadel/internal/telemetry/tracing"
)
const (
accessLogsTable = "logstore.access"
accessTimestampCol = "log_date"
accessProtocolCol = "protocol"
accessRequestURLCol = "request_url"
accessResponseStatusCol = "response_status"
accessRequestHeadersCol = "request_headers"
accessResponseHeadersCol = "response_headers"
accessInstanceIdCol = "instance_id"
accessProjectIdCol = "project_id"
accessRequestedDomainCol = "requested_domain"
accessRequestedHostCol = "requested_host"
)
var _ logstore.UsageQuerier = (*databaseLogStorage)(nil)
var _ logstore.LogCleanupper = (*databaseLogStorage)(nil)
var _ logstore.UsageStorer[*record.AccessLog] = (*databaseLogStorage)(nil)
type databaseLogStorage struct {
dbClient *database.DB
commands *command.Commands
queries *query.Queries
}
func NewDatabaseLogStorage(dbClient *database.DB) *databaseLogStorage {
return &databaseLogStorage{dbClient: dbClient}
func NewDatabaseLogStorage(dbClient *database.DB, commands *command.Commands, queries *query.Queries) *databaseLogStorage {
return &databaseLogStorage{dbClient: dbClient, commands: commands, queries: queries}
}
func (l *databaseLogStorage) QuotaUnit() quota.Unit {
return quota.RequestsAllAuthenticated
}
func (l *databaseLogStorage) Emit(ctx context.Context, bulk []logstore.LogRecord) error {
func (l *databaseLogStorage) Emit(ctx context.Context, bulk []*record.AccessLog) error {
if len(bulk) == 0 {
return nil
}
builder := squirrel.Insert(accessLogsTable).
Columns(
accessTimestampCol,
accessProtocolCol,
accessRequestURLCol,
accessResponseStatusCol,
accessRequestHeadersCol,
accessResponseHeadersCol,
accessInstanceIdCol,
accessProjectIdCol,
accessRequestedDomainCol,
accessRequestedHostCol,
).
PlaceholderFormat(squirrel.Dollar)
for idx := range bulk {
item := bulk[idx].(*Record)
builder = builder.Values(
item.LogDate,
item.Protocol,
item.RequestURL,
item.ResponseStatus,
item.RequestHeaders,
item.ResponseHeaders,
item.InstanceID,
item.ProjectID,
item.RequestedDomain,
item.RequestedHost,
)
}
stmt, args, err := builder.ToSql()
if err != nil {
return caos_errors.ThrowInternal(err, "ACCESS-KOS7I", "Errors.Internal")
}
result, err := l.dbClient.ExecContext(ctx, stmt, args...)
if err != nil {
return caos_errors.ThrowInternal(err, "ACCESS-alnT9", "Errors.Access.StorageFailed")
}
rows, err := result.RowsAffected()
if err != nil {
return caos_errors.ThrowInternal(err, "ACCESS-7KIpL", "Errors.Internal")
}
logging.WithFields("rows", rows).Debug("successfully stored access logs")
return nil
return l.incrementUsage(ctx, bulk)
}
func (l *databaseLogStorage) QueryUsage(ctx context.Context, instanceId string, start time.Time) (uint64, error) {
stmt, args, err := squirrel.Select(
fmt.Sprintf("count(%s)", accessInstanceIdCol),
).
From(accessLogsTable + l.dbClient.Timetravel(call.Took(ctx))).
Where(squirrel.And{
squirrel.Eq{accessInstanceIdCol: instanceId},
squirrel.GtOrEq{accessTimestampCol: start},
squirrel.Expr(fmt.Sprintf(`%s #>> '{%s,0}' = '[REDACTED]'`, accessRequestHeadersCol, strings.ToLower(zitadel_http.Authorization))),
squirrel.NotLike{accessRequestURLCol: "%/zitadel.system.v1.SystemService/%"},
squirrel.NotLike{accessRequestURLCol: "%/system/v1/%"},
squirrel.Or{
squirrel.And{
squirrel.Eq{accessProtocolCol: HTTP},
squirrel.NotEq{accessResponseStatusCol: http.StatusForbidden},
squirrel.NotEq{accessResponseStatusCol: http.StatusInternalServerError},
squirrel.NotEq{accessResponseStatusCol: http.StatusTooManyRequests},
},
squirrel.And{
squirrel.Eq{accessProtocolCol: GRPC},
squirrel.NotEq{accessResponseStatusCol: codes.PermissionDenied},
squirrel.NotEq{accessResponseStatusCol: codes.Internal},
squirrel.NotEq{accessResponseStatusCol: codes.ResourceExhausted},
},
},
}).
PlaceholderFormat(squirrel.Dollar).
ToSql()
func (l *databaseLogStorage) incrementUsage(ctx context.Context, bulk []*record.AccessLog) (err error) {
ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }()
if err != nil {
return 0, caos_errors.ThrowInternal(err, "ACCESS-V9Sde", "Errors.Internal")
byInstance := make(map[string][]*record.AccessLog)
for _, r := range bulk {
if r.InstanceID != "" {
byInstance[r.InstanceID] = append(byInstance[r.InstanceID], r)
}
var count uint64
err = l.dbClient.
QueryRowContext(ctx,
func(row *sql.Row) error {
return row.Scan(&count)
},
stmt, args...,
)
if err != nil {
return 0, caos_errors.ThrowInternal(err, "ACCESS-pBPrM", "Errors.Logstore.Access.ScanFailed")
}
return count, nil
}
func (l *databaseLogStorage) Cleanup(ctx context.Context, keep time.Duration) error {
stmt, args, err := squirrel.Delete(accessLogsTable).
Where(squirrel.LtOrEq{accessTimestampCol: time.Now().Add(-keep)}).
PlaceholderFormat(squirrel.Dollar).
ToSql()
if err != nil {
return caos_errors.ThrowInternal(err, "ACCESS-2oTh6", "Errors.Internal")
for instanceID, instanceBulk := range byInstance {
q, getQuotaErr := l.queries.GetQuota(ctx, instanceID, quota.RequestsAllAuthenticated)
if errors.Is(getQuotaErr, sql.ErrNoRows) {
continue
}
err = errors.Join(err, getQuotaErr)
if getQuotaErr != nil {
continue
}
sum, incrementErr := l.incrementUsageFromAccessLogs(ctx, instanceID, q.CurrentPeriodStart, instanceBulk)
err = errors.Join(err, incrementErr)
if incrementErr != nil {
continue
}
notifications, getNotificationErr := l.queries.GetDueQuotaNotifications(ctx, instanceID, quota.RequestsAllAuthenticated, q, q.CurrentPeriodStart, sum)
err = errors.Join(err, getNotificationErr)
if getNotificationErr != nil || len(notifications) == 0 {
continue
}
ctx = authz.WithInstanceID(ctx, instanceID)
reportErr := l.commands.ReportQuotaUsage(ctx, notifications)
err = errors.Join(err, reportErr)
if reportErr != nil {
continue
}
}
execCtx, cancel := context.WithTimeout(ctx, 10*time.Second)
defer cancel()
_, err = l.dbClient.ExecContext(execCtx, stmt, args...)
return err
}
func (l *databaseLogStorage) incrementUsageFromAccessLogs(ctx context.Context, instanceID string, periodStart time.Time, records []*record.AccessLog) (sum uint64, err error) {
ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }()
var count uint64
for _, r := range records {
if r.IsAuthenticated() {
count++
}
}
return projection.QuotaProjection.IncrementUsage(ctx, quota.RequestsAllAuthenticated, instanceID, periodStart, count)
}

View File

@ -1,97 +0,0 @@
package access
import (
"strings"
"time"
zitadel_http "github.com/zitadel/zitadel/internal/api/http"
"github.com/zitadel/zitadel/internal/logstore"
)
var _ logstore.LogRecord = (*Record)(nil)
type Record struct {
LogDate time.Time `json:"logDate"`
Protocol Protocol `json:"protocol"`
RequestURL string `json:"requestUrl"`
ResponseStatus uint32 `json:"responseStatus"`
// RequestHeaders are plain maps so varying implementation
// between HTTP and gRPC don't interfere with each other
RequestHeaders map[string][]string `json:"requestHeaders"`
// ResponseHeaders are plain maps so varying implementation
// between HTTP and gRPC don't interfere with each other
ResponseHeaders map[string][]string `json:"responseHeaders"`
InstanceID string `json:"instanceId"`
ProjectID string `json:"projectId"`
RequestedDomain string `json:"requestedDomain"`
RequestedHost string `json:"requestedHost"`
}
type Protocol uint8
const (
GRPC Protocol = iota
HTTP
redacted = "[REDACTED]"
)
func (a Record) Normalize() logstore.LogRecord {
a.RequestedDomain = cutString(a.RequestedDomain, 200)
a.RequestURL = cutString(a.RequestURL, 200)
a.RequestHeaders = normalizeHeaders(a.RequestHeaders, strings.ToLower(zitadel_http.Authorization), "grpcgateway-authorization", "cookie", "grpcgateway-cookie")
a.ResponseHeaders = normalizeHeaders(a.ResponseHeaders, "set-cookie")
return &a
}
// normalizeHeaders lowers all header keys and redacts secrets
func normalizeHeaders(header map[string][]string, redactKeysLower ...string) map[string][]string {
return pruneKeys(redactKeys(lowerKeys(header), redactKeysLower...))
}
func lowerKeys(header map[string][]string) map[string][]string {
lower := make(map[string][]string, len(header))
for k, v := range header {
lower[strings.ToLower(k)] = v
}
return lower
}
func redactKeys(header map[string][]string, redactKeysLower ...string) map[string][]string {
redactedKeys := make(map[string][]string, len(header))
for k, v := range header {
redactedKeys[k] = v
}
for _, redactKey := range redactKeysLower {
if _, ok := redactedKeys[redactKey]; ok {
redactedKeys[redactKey] = []string{redacted}
}
}
return redactedKeys
}
const maxValuesPerKey = 10
func pruneKeys(header map[string][]string) map[string][]string {
prunedKeys := make(map[string][]string, len(header))
for key, value := range header {
valueItems := make([]string, 0, maxValuesPerKey)
for i, valueItem := range value {
// Max 10 header values per key
if i > maxValuesPerKey {
break
}
// Max 200 value length
valueItems = append(valueItems, cutString(valueItem, 200))
}
prunedKeys[key] = valueItems
}
return prunedKeys
}
func cutString(str string, pos int) string {
if len(str) <= pos {
return str
}
return str[:pos-1]
}

View File

@ -3,142 +3,84 @@ package execution
import (
"context"
"database/sql"
"fmt"
"errors"
"math"
"time"
"github.com/Masterminds/squirrel"
"github.com/zitadel/logging"
"github.com/zitadel/zitadel/internal/api/call"
"github.com/zitadel/zitadel/internal/api/authz"
"github.com/zitadel/zitadel/internal/command"
"github.com/zitadel/zitadel/internal/database"
caos_errors "github.com/zitadel/zitadel/internal/errors"
"github.com/zitadel/zitadel/internal/logstore"
"github.com/zitadel/zitadel/internal/logstore/record"
"github.com/zitadel/zitadel/internal/query"
"github.com/zitadel/zitadel/internal/query/projection"
"github.com/zitadel/zitadel/internal/repository/quota"
)
const (
executionLogsTable = "logstore.execution"
executionTimestampCol = "log_date"
executionTookCol = "took"
executionMessageCol = "message"
executionLogLevelCol = "loglevel"
executionInstanceIdCol = "instance_id"
executionActionIdCol = "action_id"
executionMetadataCol = "metadata"
)
var _ logstore.UsageQuerier = (*databaseLogStorage)(nil)
var _ logstore.LogCleanupper = (*databaseLogStorage)(nil)
var _ logstore.UsageStorer[*record.ExecutionLog] = (*databaseLogStorage)(nil)
type databaseLogStorage struct {
dbClient *database.DB
commands *command.Commands
queries *query.Queries
}
func NewDatabaseLogStorage(dbClient *database.DB) *databaseLogStorage {
return &databaseLogStorage{dbClient: dbClient}
func NewDatabaseLogStorage(dbClient *database.DB, commands *command.Commands, queries *query.Queries) *databaseLogStorage {
return &databaseLogStorage{dbClient: dbClient, commands: commands, queries: queries}
}
func (l *databaseLogStorage) QuotaUnit() quota.Unit {
return quota.ActionsAllRunsSeconds
}
func (l *databaseLogStorage) Emit(ctx context.Context, bulk []logstore.LogRecord) error {
func (l *databaseLogStorage) Emit(ctx context.Context, bulk []*record.ExecutionLog) error {
if len(bulk) == 0 {
return nil
}
builder := squirrel.Insert(executionLogsTable).
Columns(
executionTimestampCol,
executionTookCol,
executionMessageCol,
executionLogLevelCol,
executionInstanceIdCol,
executionActionIdCol,
executionMetadataCol,
).
PlaceholderFormat(squirrel.Dollar)
for idx := range bulk {
item := bulk[idx].(*Record)
var took interface{}
if item.Took > 0 {
took = item.Took
}
builder = builder.Values(
item.LogDate,
took,
item.Message,
item.LogLevel,
item.InstanceID,
item.ActionID,
item.Metadata,
)
}
stmt, args, err := builder.ToSql()
if err != nil {
return caos_errors.ThrowInternal(err, "EXEC-KOS7I", "Errors.Internal")
}
result, err := l.dbClient.ExecContext(ctx, stmt, args...)
if err != nil {
return caos_errors.ThrowInternal(err, "EXEC-0j6i5", "Errors.Access.StorageFailed")
}
rows, err := result.RowsAffected()
if err != nil {
return caos_errors.ThrowInternal(err, "EXEC-MGchJ", "Errors.Internal")
}
logging.WithFields("rows", rows).Debug("successfully stored execution logs")
return nil
return l.incrementUsage(ctx, bulk)
}
func (l *databaseLogStorage) QueryUsage(ctx context.Context, instanceId string, start time.Time) (uint64, error) {
stmt, args, err := squirrel.Select(
fmt.Sprintf("COALESCE(SUM(%s)::INT,0)", executionTookCol),
).
From(executionLogsTable + l.dbClient.Timetravel(call.Took(ctx))).
Where(squirrel.And{
squirrel.Eq{executionInstanceIdCol: instanceId},
squirrel.GtOrEq{executionTimestampCol: start},
squirrel.NotEq{executionTookCol: nil},
}).
PlaceholderFormat(squirrel.Dollar).
ToSql()
if err != nil {
return 0, caos_errors.ThrowInternal(err, "EXEC-DXtzg", "Errors.Internal")
func (l *databaseLogStorage) incrementUsage(ctx context.Context, bulk []*record.ExecutionLog) (err error) {
byInstance := make(map[string][]*record.ExecutionLog)
for _, r := range bulk {
if r.InstanceID != "" {
byInstance[r.InstanceID] = append(byInstance[r.InstanceID], r)
}
}
for instanceID, instanceBulk := range byInstance {
q, getQuotaErr := l.queries.GetQuota(ctx, instanceID, quota.ActionsAllRunsSeconds)
if errors.Is(getQuotaErr, sql.ErrNoRows) {
continue
}
err = errors.Join(err, getQuotaErr)
if getQuotaErr != nil {
continue
}
sum, incrementErr := l.incrementUsageFromExecutionLogs(ctx, instanceID, q.CurrentPeriodStart, instanceBulk)
err = errors.Join(err, incrementErr)
if incrementErr != nil {
continue
}
var durationSeconds uint64
err = l.dbClient.
QueryRowContext(ctx,
func(row *sql.Row) error {
return row.Scan(&durationSeconds)
},
stmt, args...,
)
if err != nil {
return 0, caos_errors.ThrowInternal(err, "EXEC-Ad8nP", "Errors.Logstore.Execution.ScanFailed")
notifications, getNotificationErr := l.queries.GetDueQuotaNotifications(ctx, instanceID, quota.ActionsAllRunsSeconds, q, q.CurrentPeriodStart, sum)
err = errors.Join(err, getNotificationErr)
if getNotificationErr != nil || len(notifications) == 0 {
continue
}
ctx = authz.WithInstanceID(ctx, instanceID)
reportErr := l.commands.ReportQuotaUsage(ctx, notifications)
err = errors.Join(err, reportErr)
if reportErr != nil {
continue
}
return durationSeconds, nil
}
func (l *databaseLogStorage) Cleanup(ctx context.Context, keep time.Duration) error {
stmt, args, err := squirrel.Delete(executionLogsTable).
Where(squirrel.LtOrEq{executionTimestampCol: time.Now().Add(-keep)}).
PlaceholderFormat(squirrel.Dollar).
ToSql()
if err != nil {
return caos_errors.ThrowInternal(err, "EXEC-Bja8V", "Errors.Internal")
}
execCtx, cancel := context.WithTimeout(ctx, 10*time.Second)
defer cancel()
_, err = l.dbClient.ExecContext(execCtx, stmt, args...)
return err
}
func (l *databaseLogStorage) incrementUsageFromExecutionLogs(ctx context.Context, instanceID string, periodStart time.Time, records []*record.ExecutionLog) (sum uint64, err error) {
var total time.Duration
for _, r := range records {
total += r.Took
}
return projection.QuotaProjection.IncrementUsage(ctx, quota.ActionsAllRunsSeconds, instanceID, periodStart, uint64(math.Floor(total.Seconds())))
}

View File

@ -1,89 +0,0 @@
package mock
import (
"context"
"sync"
"time"
"github.com/benbjohnson/clock"
"github.com/zitadel/zitadel/internal/logstore"
"github.com/zitadel/zitadel/internal/repository/quota"
)
var _ logstore.UsageQuerier = (*InmemLogStorage)(nil)
var _ logstore.LogCleanupper = (*InmemLogStorage)(nil)
type InmemLogStorage struct {
mux sync.Mutex
clock clock.Clock
emitted []*record
bulks []int
}
func NewInMemoryStorage(clock clock.Clock) *InmemLogStorage {
return &InmemLogStorage{
clock: clock,
emitted: make([]*record, 0),
bulks: make([]int, 0),
}
}
func (l *InmemLogStorage) QuotaUnit() quota.Unit {
return quota.Unimplemented
}
func (l *InmemLogStorage) Emit(_ context.Context, bulk []logstore.LogRecord) error {
if len(bulk) == 0 {
return nil
}
l.mux.Lock()
defer l.mux.Unlock()
for idx := range bulk {
l.emitted = append(l.emitted, bulk[idx].(*record))
}
l.bulks = append(l.bulks, len(bulk))
return nil
}
func (l *InmemLogStorage) QueryUsage(_ context.Context, _ string, start time.Time) (uint64, error) {
l.mux.Lock()
defer l.mux.Unlock()
var count uint64
for _, r := range l.emitted {
if r.ts.After(start) {
count++
}
}
return count, nil
}
func (l *InmemLogStorage) Cleanup(_ context.Context, keep time.Duration) error {
l.mux.Lock()
defer l.mux.Unlock()
clean := make([]*record, 0)
from := l.clock.Now().Add(-(keep + 1))
for _, r := range l.emitted {
if r.ts.After(from) {
clean = append(clean, r)
}
}
l.emitted = clean
return nil
}
func (l *InmemLogStorage) Bulks() []int {
l.mux.Lock()
defer l.mux.Unlock()
return l.bulks
}
func (l *InmemLogStorage) Len() int {
l.mux.Lock()
defer l.mux.Unlock()
return len(l.emitted)
}

View File

@ -9,8 +9,8 @@ import (
"github.com/zitadel/zitadel/internal/logstore"
)
func NewStdoutEmitter() logstore.LogEmitter {
return logstore.LogEmitterFunc(func(ctx context.Context, bulk []logstore.LogRecord) error {
func NewStdoutEmitter[T logstore.LogRecord[T]]() logstore.LogEmitter[T] {
return logstore.LogEmitterFunc[T](func(ctx context.Context, bulk []T) error {
for idx := range bulk {
bytes, err := json.Marshal(bulk[idx])
if err != nil {

View File

@ -4,7 +4,7 @@ import (
"time"
"github.com/zitadel/zitadel/internal/logstore"
"github.com/zitadel/zitadel/internal/repository/quota"
"github.com/zitadel/zitadel/internal/query"
)
type emitterOption func(config *logstore.EmitterConfig)
@ -12,8 +12,6 @@ type emitterOption func(config *logstore.EmitterConfig)
func emitterConfig(options ...emitterOption) *logstore.EmitterConfig {
cfg := &logstore.EmitterConfig{
Enabled: true,
Keep: time.Hour,
CleanupInterval: time.Hour,
Debounce: &logstore.DebouncerConfig{
MinFrequency: 0,
MaxBulkSize: 0,
@ -37,17 +35,10 @@ func withDisabled() emitterOption {
}
}
func withCleanupping(keep, interval time.Duration) emitterOption {
return func(c *logstore.EmitterConfig) {
c.Keep = keep
c.CleanupInterval = interval
}
}
type quotaOption func(config *query.Quota)
type quotaOption func(config *quota.AddedEvent)
func quotaConfig(quotaOptions ...quotaOption) quota.AddedEvent {
q := &quota.AddedEvent{
func quotaConfig(quotaOptions ...quotaOption) *query.Quota {
q := &query.Quota{
Amount: 90,
Limit: false,
ResetInterval: 90 * time.Second,
@ -56,18 +47,18 @@ func quotaConfig(quotaOptions ...quotaOption) quota.AddedEvent {
for _, opt := range quotaOptions {
opt(q)
}
return *q
return q
}
func withAmountAndInterval(n uint64) quotaOption {
return func(c *quota.AddedEvent) {
return func(c *query.Quota) {
c.Amount = n
c.ResetInterval = time.Duration(n) * time.Second
}
}
func withLimiting() quotaOption {
return func(c *quota.AddedEvent) {
return func(c *query.Quota) {
c.Limit = true
}
}

View File

@ -0,0 +1,120 @@
package mock
import (
"context"
"sync"
"time"
"github.com/benbjohnson/clock"
"github.com/zitadel/zitadel/internal/logstore"
"github.com/zitadel/zitadel/internal/query"
"github.com/zitadel/zitadel/internal/repository/quota"
)
var _ logstore.UsageStorer[*Record] = (*InmemLogStorage)(nil)
var _ logstore.LogCleanupper[*Record] = (*InmemLogStorage)(nil)
var _ logstore.Queries = (*InmemLogStorage)(nil)
type InmemLogStorage struct {
mux sync.Mutex
clock clock.Clock
emitted []*Record
bulks []int
quota *query.Quota
}
func NewInMemoryStorage(clock clock.Clock, quota *query.Quota) *InmemLogStorage {
return &InmemLogStorage{
clock: clock,
emitted: make([]*Record, 0),
bulks: make([]int, 0),
quota: quota,
}
}
func (l *InmemLogStorage) QuotaUnit() quota.Unit {
return quota.Unimplemented
}
func (l *InmemLogStorage) Emit(_ context.Context, bulk []*Record) error {
if len(bulk) == 0 {
return nil
}
l.mux.Lock()
defer l.mux.Unlock()
l.emitted = append(l.emitted, bulk...)
l.bulks = append(l.bulks, len(bulk))
return nil
}
func (l *InmemLogStorage) QueryUsage(_ context.Context, _ string, start time.Time) (uint64, error) {
l.mux.Lock()
defer l.mux.Unlock()
var count uint64
for _, r := range l.emitted {
if r.ts.After(start) {
count++
}
}
return count, nil
}
func (l *InmemLogStorage) Cleanup(_ context.Context, keep time.Duration) error {
l.mux.Lock()
defer l.mux.Unlock()
clean := make([]*Record, 0)
from := l.clock.Now().Add(-(keep + 1))
for _, r := range l.emitted {
if r.ts.After(from) {
clean = append(clean, r)
}
}
l.emitted = clean
return nil
}
func (l *InmemLogStorage) Bulks() []int {
l.mux.Lock()
defer l.mux.Unlock()
return l.bulks
}
func (l *InmemLogStorage) Len() int {
l.mux.Lock()
defer l.mux.Unlock()
return len(l.emitted)
}
func (l *InmemLogStorage) GetQuota(ctx context.Context, instanceID string, unit quota.Unit) (qu *query.Quota, err error) {
return l.quota, nil
}
func (l *InmemLogStorage) GetQuotaUsage(ctx context.Context, instanceID string, unit quota.Unit, periodStart time.Time) (usage uint64, err error) {
return uint64(l.Len()), nil
}
func (l *InmemLogStorage) GetRemainingQuotaUsage(ctx context.Context, instanceID string, unit quota.Unit) (remaining *uint64, err error) {
if !l.quota.Limit {
return nil, nil
}
var r uint64
used := uint64(l.Len())
if used > l.quota.Amount {
return &r, nil
}
r = l.quota.Amount - used
return &r, nil
}
func (l *InmemLogStorage) GetDueQuotaNotifications(ctx context.Context, instanceID string, unit quota.Unit, qu *query.Quota, periodStart time.Time, usedAbs uint64) (dueNotifications []*quota.NotificationDueEvent, err error) {
return nil, nil
}
func (l *InmemLogStorage) ReportQuotaUsage(ctx context.Context, dueNotifications []*quota.NotificationDueEvent) error {
return nil
}

View File

@ -8,18 +8,18 @@ import (
"github.com/zitadel/zitadel/internal/logstore"
)
var _ logstore.LogRecord = (*record)(nil)
var _ logstore.LogRecord[*Record] = (*Record)(nil)
func NewRecord(clock clock.Clock) *record {
return &record{ts: clock.Now()}
func NewRecord(clock clock.Clock) *Record {
return &Record{ts: clock.Now()}
}
type record struct {
type Record struct {
ts time.Time
redacted bool
}
func (r record) Normalize() logstore.LogRecord {
func (r Record) Normalize() *Record {
r.redacted = true
return &r
}

View File

@ -1,28 +0,0 @@
package mock
import (
"context"
"time"
"github.com/zitadel/zitadel/internal/logstore"
"github.com/zitadel/zitadel/internal/repository/quota"
)
var _ logstore.QuotaQuerier = (*inmemReporter)(nil)
type inmemReporter struct {
config *quota.AddedEvent
startPeriod time.Time
}
func NewNoopQuerier(quota *quota.AddedEvent, startPeriod time.Time) *inmemReporter {
return &inmemReporter{config: quota, startPeriod: startPeriod}
}
func (i *inmemReporter) GetCurrentQuotaPeriod(context.Context, string, quota.Unit) (*quota.AddedEvent, time.Time, error) {
return i.config, i.startPeriod, nil
}
func (*inmemReporter) GetDueQuotaNotifications(context.Context, *quota.AddedEvent, time.Time, uint64) ([]*quota.NotificationDueEvent, error) {
return nil, nil
}

View File

@ -0,0 +1,135 @@
package record
import (
"net/http"
"strings"
"time"
"google.golang.org/grpc/codes"
zitadel_http "github.com/zitadel/zitadel/internal/api/http"
)
type AccessLog struct {
LogDate time.Time `json:"logDate"`
Protocol AccessProtocol `json:"protocol"`
RequestURL string `json:"requestUrl"`
ResponseStatus uint32 `json:"responseStatus"`
// RequestHeaders and ResponseHeaders are plain maps so varying implementations
// between HTTP and gRPC don't interfere with each other
RequestHeaders map[string][]string `json:"requestHeaders"`
ResponseHeaders map[string][]string `json:"responseHeaders"`
InstanceID string `json:"instanceId"`
ProjectID string `json:"projectId"`
RequestedDomain string `json:"requestedDomain"`
RequestedHost string `json:"requestedHost"`
// NotCountable can be used by the logging service to explicitly stating,
// that the request must not increase the amount of countable (authenticated) requests
NotCountable bool `json:"-"`
normalized bool `json:"-"`
}
type AccessProtocol uint8
const (
GRPC AccessProtocol = iota
HTTP
redacted = "[REDACTED]"
)
var (
unaccountableEndpoints = []string{
"/zitadel.system.v1.SystemService/",
"/zitadel.admin.v1.AdminService/Healthz",
"/zitadel.management.v1.ManagementService/Healthz",
"/zitadel.management.v1.ManagementService/GetOIDCInformation",
"/zitadel.auth.v1.AuthService/Healthz",
}
)
func (a AccessLog) IsAuthenticated() bool {
if a.NotCountable {
return false
}
if !a.normalized {
panic("access log not normalized, Normalize() must be called before IsAuthenticated()")
}
_, hasHTTPAuthHeader := a.RequestHeaders[strings.ToLower(zitadel_http.Authorization)]
// ignore requests, which were unauthorized or do not require an authorization (even if one was sent)
// also ignore if the limit was already reached or if the server returned an internal error
// not that endpoints paths are only checked with the gRPC representation as HTTP (gateway) will not log them
return hasHTTPAuthHeader &&
(a.Protocol == HTTP &&
a.ResponseStatus != http.StatusInternalServerError &&
a.ResponseStatus != http.StatusTooManyRequests &&
a.ResponseStatus != http.StatusUnauthorized) ||
(a.Protocol == GRPC &&
a.ResponseStatus != uint32(codes.Internal) &&
a.ResponseStatus != uint32(codes.ResourceExhausted) &&
a.ResponseStatus != uint32(codes.Unauthenticated) &&
!a.isUnaccountableEndpoint())
}
func (a AccessLog) isUnaccountableEndpoint() bool {
for _, endpoint := range unaccountableEndpoints {
if strings.HasPrefix(a.RequestURL, endpoint) {
return true
}
}
return false
}
func (a AccessLog) Normalize() *AccessLog {
a.RequestedDomain = cutString(a.RequestedDomain, 200)
a.RequestURL = cutString(a.RequestURL, 200)
a.RequestHeaders = normalizeHeaders(a.RequestHeaders, strings.ToLower(zitadel_http.Authorization), "grpcgateway-authorization", "cookie", "grpcgateway-cookie")
a.ResponseHeaders = normalizeHeaders(a.ResponseHeaders, "set-cookie")
a.normalized = true
return &a
}
// normalizeHeaders lowers all header keys and redacts secrets
func normalizeHeaders(header map[string][]string, redactKeysLower ...string) map[string][]string {
return pruneKeys(redactKeys(lowerKeys(header), redactKeysLower...))
}
func lowerKeys(header map[string][]string) map[string][]string {
lower := make(map[string][]string, len(header))
for k, v := range header {
lower[strings.ToLower(k)] = v
}
return lower
}
func redactKeys(header map[string][]string, redactKeysLower ...string) map[string][]string {
redactedKeys := make(map[string][]string, len(header))
for k, v := range header {
redactedKeys[k] = v
}
for _, redactKey := range redactKeysLower {
if _, ok := redactedKeys[redactKey]; ok {
redactedKeys[redactKey] = []string{redacted}
}
}
return redactedKeys
}
const maxValuesPerKey = 10
func pruneKeys(header map[string][]string) map[string][]string {
prunedKeys := make(map[string][]string, len(header))
for key, value := range header {
valueItems := make([]string, 0, maxValuesPerKey)
for i, valueItem := range value {
// Max 10 header values per key
if i > maxValuesPerKey {
break
}
// Max 200 value length
valueItems = append(valueItems, cutString(valueItem, 200))
}
prunedKeys[key] = valueItems
}
return prunedKeys
}

View File

@ -1,20 +1,18 @@
package access_test
package record
import (
"reflect"
"testing"
"github.com/zitadel/zitadel/internal/logstore/emitters/access"
)
func TestRecord_Normalize(t *testing.T) {
tests := []struct {
name string
record access.Record
want *access.Record
record AccessLog
want *AccessLog
}{{
name: "headers with certain keys should be redacted",
record: access.Record{
record: AccessLog{
RequestHeaders: map[string][]string{
"authorization": {"AValue"},
"grpcgateway-authorization": {"AValue"},
@ -24,7 +22,7 @@ func TestRecord_Normalize(t *testing.T) {
"set-cookie": {"AValue"},
},
},
want: &access.Record{
want: &AccessLog{
RequestHeaders: map[string][]string{
"authorization": {"[REDACTED]"},
"grpcgateway-authorization": {"[REDACTED]"},
@ -36,22 +34,22 @@ func TestRecord_Normalize(t *testing.T) {
},
}, {
name: "header keys should be lower cased",
record: access.Record{
record: AccessLog{
RequestHeaders: map[string][]string{"AKey": {"AValue"}},
ResponseHeaders: map[string][]string{"AKey": {"AValue"}}},
want: &access.Record{
want: &AccessLog{
RequestHeaders: map[string][]string{"akey": {"AValue"}},
ResponseHeaders: map[string][]string{"akey": {"AValue"}}},
}, {
name: "an already prune record should stay unchanged",
record: access.Record{
record: AccessLog{
RequestURL: "https://my.zitadel.cloud/",
RequestHeaders: map[string][]string{
"authorization": {"[REDACTED]"},
},
ResponseHeaders: map[string][]string{},
},
want: &access.Record{
want: &AccessLog{
RequestURL: "https://my.zitadel.cloud/",
RequestHeaders: map[string][]string{
"authorization": {"[REDACTED]"},
@ -60,17 +58,18 @@ func TestRecord_Normalize(t *testing.T) {
},
}, {
name: "empty record should stay empty",
record: access.Record{
record: AccessLog{
RequestHeaders: map[string][]string{},
ResponseHeaders: map[string][]string{},
},
want: &access.Record{
want: &AccessLog{
RequestHeaders: map[string][]string{},
ResponseHeaders: map[string][]string{},
},
}}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
tt.want.normalized = true
if got := tt.record.Normalize(); !reflect.DeepEqual(got, tt.want) {
t.Errorf("Normalize() = %v, want %v", got, tt.want)
}

View File

@ -1,16 +1,12 @@
package execution
package record
import (
"time"
"github.com/sirupsen/logrus"
"github.com/zitadel/zitadel/internal/logstore"
)
var _ logstore.LogRecord = (*Record)(nil)
type Record struct {
type ExecutionLog struct {
LogDate time.Time `json:"logDate"`
Took time.Duration `json:"took"`
Message string `json:"message"`
@ -20,14 +16,7 @@ type Record struct {
Metadata map[string]interface{} `json:"metadata,omitempty"`
}
func (e Record) Normalize() logstore.LogRecord {
func (e ExecutionLog) Normalize() *ExecutionLog {
e.Message = cutString(e.Message, 2000)
return &e
}
func cutString(str string, pos int) string {
if len(str) <= pos {
return str
}
return str[:pos]
}

View File

@ -0,0 +1,8 @@
package record
func cutString(str string, pos int) string {
if len(str) <= pos {
return str
}
return str[:pos-1]
}

View File

@ -2,124 +2,70 @@ package logstore
import (
"context"
"math"
"time"
"github.com/zitadel/logging"
"github.com/zitadel/zitadel/internal/api/authz"
"github.com/zitadel/zitadel/internal/repository/quota"
)
const handleThresholdTimeout = time.Minute
type QuotaQuerier interface {
GetCurrentQuotaPeriod(ctx context.Context, instanceID string, unit quota.Unit) (config *quota.AddedEvent, periodStart time.Time, err error)
GetDueQuotaNotifications(ctx context.Context, config *quota.AddedEvent, periodStart time.Time, used uint64) ([]*quota.NotificationDueEvent, error)
}
type UsageQuerier interface {
LogEmitter
type UsageStorer[T LogRecord[T]] interface {
LogEmitter[T]
QuotaUnit() quota.Unit
QueryUsage(ctx context.Context, instanceId string, start time.Time) (uint64, error)
}
type UsageReporter interface {
Report(ctx context.Context, notifications []*quota.NotificationDueEvent) (err error)
}
type UsageReporterFunc func(context.Context, []*quota.NotificationDueEvent) (err error)
func (u UsageReporterFunc) Report(ctx context.Context, notifications []*quota.NotificationDueEvent) (err error) {
return u(ctx, notifications)
}
type Service struct {
usageQuerier UsageQuerier
quotaQuerier QuotaQuerier
usageReporter UsageReporter
enabledSinks []*emitter
type Service[T LogRecord[T]] struct {
queries Queries
usageStorer UsageStorer[T]
enabledSinks []*emitter[T]
sinkEnabled bool
reportingEnabled bool
}
func New(quotaQuerier QuotaQuerier, usageReporter UsageReporter, usageQuerierSink *emitter, additionalSink ...*emitter) *Service {
var usageQuerier UsageQuerier
type Queries interface {
GetRemainingQuotaUsage(ctx context.Context, instanceID string, unit quota.Unit) (remaining *uint64, err error)
}
func New[T LogRecord[T]](queries Queries, usageQuerierSink *emitter[T], additionalSink ...*emitter[T]) *Service[T] {
var usageStorer UsageStorer[T]
if usageQuerierSink != nil {
usageQuerier = usageQuerierSink.emitter.(UsageQuerier)
usageStorer = usageQuerierSink.emitter.(UsageStorer[T])
}
svc := &Service{
svc := &Service[T]{
queries: queries,
reportingEnabled: usageQuerierSink != nil && usageQuerierSink.enabled,
usageQuerier: usageQuerier,
quotaQuerier: quotaQuerier,
usageReporter: usageReporter,
usageStorer: usageStorer,
}
for _, s := range append([]*emitter{usageQuerierSink}, additionalSink...) {
for _, s := range append([]*emitter[T]{usageQuerierSink}, additionalSink...) {
if s != nil && s.enabled {
svc.enabledSinks = append(svc.enabledSinks, s)
}
}
svc.sinkEnabled = len(svc.enabledSinks) > 0
return svc
}
func (s *Service) Enabled() bool {
func (s *Service[T]) Enabled() bool {
return s.sinkEnabled
}
func (s *Service) Handle(ctx context.Context, record LogRecord) {
func (s *Service[T]) Handle(ctx context.Context, record T) {
for _, sink := range s.enabledSinks {
logging.OnError(sink.Emit(ctx, record.Normalize())).WithField("record", record).Warn("failed to emit log record")
}
}
func (s *Service) Limit(ctx context.Context, instanceID string) *uint64 {
func (s *Service[T]) Limit(ctx context.Context, instanceID string) *uint64 {
var err error
defer func() {
logging.OnError(err).Warn("failed to check is usage should be limited")
logging.OnError(err).Warn("failed to check if usage should be limited")
}()
if !s.reportingEnabled || instanceID == "" {
return nil
}
quota, periodStart, err := s.quotaQuerier.GetCurrentQuotaPeriod(ctx, instanceID, s.usageQuerier.QuotaUnit())
if err != nil || quota == nil {
return nil
}
usage, err := s.usageQuerier.QueryUsage(ctx, instanceID, periodStart)
remaining, err := s.queries.GetRemainingQuotaUsage(ctx, instanceID, s.usageStorer.QuotaUnit())
if err != nil {
// TODO: shouldn't we just limit then or return the error and decide there?
return nil
}
go s.handleThresholds(ctx, quota, periodStart, usage)
var remaining *uint64
if quota.Limit {
r := uint64(math.Max(0, float64(quota.Amount)-float64(usage)))
remaining = &r
}
return remaining
}
func (s *Service) handleThresholds(ctx context.Context, quota *quota.AddedEvent, periodStart time.Time, usage uint64) {
var err error
defer func() {
logging.OnError(err).Warn("handling quota thresholds failed")
}()
detatchedCtx, cancel := context.WithTimeout(authz.Detach(ctx), handleThresholdTimeout)
defer cancel()
notifications, err := s.quotaQuerier.GetDueQuotaNotifications(detatchedCtx, quota, periodStart, usage)
if err != nil || len(notifications) == 0 {
return
}
err = s.usageReporter.Report(detatchedCtx, notifications)
}

View File

@ -14,9 +14,8 @@ import (
"github.com/benbjohnson/clock"
"github.com/zitadel/zitadel/internal/logstore"
emittermock "github.com/zitadel/zitadel/internal/logstore/emitters/mock"
quotaqueriermock "github.com/zitadel/zitadel/internal/logstore/quotaqueriers/mock"
"github.com/zitadel/zitadel/internal/repository/quota"
emittermock "github.com/zitadel/zitadel/internal/logstore/mock"
"github.com/zitadel/zitadel/internal/query"
)
const (
@ -27,7 +26,7 @@ const (
type args struct {
mainSink *logstore.EmitterConfig
secondarySink *logstore.EmitterConfig
config quota.AddedEvent
config *query.Quota
}
type want struct {
@ -137,28 +136,6 @@ func TestService(t *testing.T) {
len: 0,
},
},
}, {
name: "cleanupping works",
args: args{
mainSink: emitterConfig(withCleanupping(17*time.Second, 28*time.Second)),
secondarySink: emitterConfig(withDebouncerConfig(&logstore.DebouncerConfig{
MinFrequency: 0,
MaxBulkSize: 15,
}), withCleanupping(5*time.Second, 47*time.Second)),
config: quotaConfig(),
},
want: want{
enabled: true,
remaining: nil,
mainSink: wantSink{
bulks: repeat(1, 60),
len: 21,
},
secondarySink: wantSink{
bulks: repeat(15, 4),
len: 18,
},
},
}, {
name: "when quota has a limit of 90, 30 are remaining",
args: args{
@ -232,27 +209,24 @@ func runTest(t *testing.T, name string, args args, want want) bool {
})
}
func given(t *testing.T, args args, want want) (context.Context, *clock.Mock, *emittermock.InmemLogStorage, *emittermock.InmemLogStorage, *logstore.Service) {
func given(t *testing.T, args args, want want) (context.Context, *clock.Mock, *emittermock.InmemLogStorage, *emittermock.InmemLogStorage, *logstore.Service[*emittermock.Record]) {
ctx := context.Background()
clock := clock.NewMock()
periodStart := time.Time{}
clock.Set(args.config.From)
mainStorage := emittermock.NewInMemoryStorage(clock)
mainEmitter, err := logstore.NewEmitter(ctx, clock, args.mainSink, mainStorage)
mainStorage := emittermock.NewInMemoryStorage(clock, args.config)
mainEmitter, err := logstore.NewEmitter[*emittermock.Record](ctx, clock, args.mainSink, mainStorage)
if err != nil {
t.Errorf("expected no error but got %v", err)
}
secondaryStorage := emittermock.NewInMemoryStorage(clock)
secondaryEmitter, err := logstore.NewEmitter(ctx, clock, args.secondarySink, secondaryStorage)
secondaryStorage := emittermock.NewInMemoryStorage(clock, args.config)
secondaryEmitter, err := logstore.NewEmitter[*emittermock.Record](ctx, clock, args.secondarySink, secondaryStorage)
if err != nil {
t.Errorf("expected no error but got %v", err)
}
svc := logstore.New(
quotaqueriermock.NewNoopQuerier(&args.config, periodStart),
logstore.UsageReporterFunc(func(context.Context, []*quota.NotificationDueEvent) error { return nil }),
svc := logstore.New[*emittermock.Record](
mainStorage,
mainEmitter,
secondaryEmitter)
@ -262,7 +236,7 @@ func given(t *testing.T, args args, want want) (context.Context, *clock.Mock, *e
return ctx, clock, mainStorage, secondaryStorage, svc
}
func when(svc *logstore.Service, ctx context.Context, clock *clock.Mock) *uint64 {
func when(svc *logstore.Service[*emittermock.Record], ctx context.Context, clock *clock.Mock) *uint64 {
var remaining *uint64
for i := 0; i < ticks; i++ {
svc.Handle(ctx, emittermock.NewRecord(clock))

View File

@ -1,11 +1,11 @@
//go:build !integration
package start
package net
import (
"net"
)
func listenConfig() *net.ListenConfig {
func ListenConfig() *net.ListenConfig {
return &net.ListenConfig{}
}

View File

@ -1,6 +1,6 @@
//go:build integration
package start
package net
import (
"net"
@ -9,7 +9,7 @@ import (
"golang.org/x/sys/unix"
)
func listenConfig() *net.ListenConfig {
func ListenConfig() *net.ListenConfig {
return &net.ListenConfig{
Control: reusePort,
}

View File

@ -20,11 +20,13 @@ var (
func TestMain(m *testing.M) {
os.Exit(func() int {
ctx, _, cancel := integration.Contexts(5 * time.Minute)
CTX = ctx
defer cancel()
CTX = ctx
Tester = integration.NewTester(ctx)
SystemCTX = Tester.WithAuthorization(ctx, integration.SystemUser)
defer Tester.Done()
SystemCTX = Tester.WithAuthorization(ctx, integration.SystemUser)
return m.Run()
}())
}

View File

@ -5,11 +5,6 @@ package handlers_test
import (
"bytes"
"encoding/json"
"io"
"net"
"net/http"
"net/http/httptest"
"reflect"
"testing"
"time"
@ -18,49 +13,27 @@ import (
)
func TestServer_TelemetryPushMilestones(t *testing.T) {
bodies := make(chan []byte, 0)
mockServer := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
body, err := io.ReadAll(r.Body)
if err != nil {
t.Error(err)
}
if r.Header.Get("single-value") != "single-value" {
t.Error("single-value header not set")
}
if reflect.DeepEqual(r.Header.Get("multi-value"), "multi-value-1,multi-value-2") {
t.Error("single-value header not set")
}
bodies <- body
w.WriteHeader(http.StatusOK)
}))
listener, err := net.Listen("tcp", "localhost:8081")
if err != nil {
t.Fatal(err)
}
mockServer.Listener = listener
mockServer.Start()
t.Cleanup(mockServer.Close)
primaryDomain, instanceID, iamOwnerCtx := Tester.UseIsolatedInstance(CTX, SystemCTX)
t.Log("testing against instance with primary domain", primaryDomain)
awaitMilestone(t, bodies, primaryDomain, "InstanceCreated")
awaitMilestone(t, Tester.MilestoneChan, primaryDomain, "InstanceCreated")
project, err := Tester.Client.Mgmt.AddProject(iamOwnerCtx, &management.AddProjectRequest{Name: "integration"})
if err != nil {
t.Fatal(err)
}
awaitMilestone(t, bodies, primaryDomain, "ProjectCreated")
awaitMilestone(t, Tester.MilestoneChan, primaryDomain, "ProjectCreated")
if _, err = Tester.Client.Mgmt.AddOIDCApp(iamOwnerCtx, &management.AddOIDCAppRequest{
ProjectId: project.GetId(),
Name: "integration",
}); err != nil {
t.Fatal(err)
}
awaitMilestone(t, bodies, primaryDomain, "ApplicationCreated")
awaitMilestone(t, Tester.MilestoneChan, primaryDomain, "ApplicationCreated")
// TODO: trigger and await milestone AuthenticationSucceededOnInstance
// TODO: trigger and await milestone AuthenticationSucceededOnApplication
if _, err = Tester.Client.System.RemoveInstance(SystemCTX, &system.RemoveInstanceRequest{InstanceId: instanceID}); err != nil {
t.Fatal(err)
}
awaitMilestone(t, bodies, primaryDomain, "InstanceDeleted")
awaitMilestone(t, Tester.MilestoneChan, primaryDomain, "InstanceDeleted")
}
func awaitMilestone(t *testing.T, bodies chan []byte, primaryDomain, expectMilestoneType string) {

View File

@ -13,6 +13,7 @@ import (
key_repo "github.com/zitadel/zitadel/internal/repository/keypair"
"github.com/zitadel/zitadel/internal/repository/org"
proj_repo "github.com/zitadel/zitadel/internal/repository/project"
quota_repo "github.com/zitadel/zitadel/internal/repository/quota"
usr_repo "github.com/zitadel/zitadel/internal/repository/user"
"github.com/zitadel/zitadel/internal/repository/usergrant"
)
@ -29,6 +30,7 @@ func eventstoreExpect(t *testing.T, expects ...expect) *eventstore.Eventstore {
org.RegisterEventMappers(es)
usr_repo.RegisterEventMappers(es)
proj_repo.RegisterEventMappers(es)
quota_repo.RegisterEventMappers(es)
usergrant.RegisterEventMappers(es)
key_repo.RegisterEventMappers(es)
action_repo.RegisterEventMappers(es)

View File

@ -69,6 +69,7 @@ var (
SessionProjection *sessionProjection
AuthRequestProjection *authRequestProjection
MilestoneProjection *milestoneProjection
QuotaProjection *quotaProjection
)
type projection interface {
@ -148,6 +149,7 @@ func Create(ctx context.Context, sqlClient *database.DB, es *eventstore.Eventsto
SessionProjection = newSessionProjection(ctx, applyCustomConfig(projectionConfig, config.Customizations["sessions"]))
AuthRequestProjection = newAuthRequestProjection(ctx, applyCustomConfig(projectionConfig, config.Customizations["auth_requests"]))
MilestoneProjection = newMilestoneProjection(ctx, applyCustomConfig(projectionConfig, config.Customizations["milestones"]))
QuotaProjection = newQuotaProjection(ctx, applyCustomConfig(projectionConfig, config.Customizations["quotas"]))
newProjectionsList()
return nil
}
@ -247,5 +249,6 @@ func newProjectionsList() {
SessionProjection,
AuthRequestProjection,
MilestoneProjection,
QuotaProjection,
}
}

View File

@ -0,0 +1,285 @@
package projection
import (
"context"
"time"
"github.com/zitadel/zitadel/internal/database"
"github.com/zitadel/zitadel/internal/errors"
"github.com/zitadel/zitadel/internal/eventstore"
"github.com/zitadel/zitadel/internal/eventstore/handler"
"github.com/zitadel/zitadel/internal/eventstore/handler/crdb"
"github.com/zitadel/zitadel/internal/repository/instance"
"github.com/zitadel/zitadel/internal/repository/quota"
)
const (
QuotasProjectionTable = "projections.quotas"
QuotaPeriodsProjectionTable = QuotasProjectionTable + "_" + quotaPeriodsTableSuffix
QuotaNotificationsTable = QuotasProjectionTable + "_" + quotaNotificationsTableSuffix
QuotaColumnID = "id"
QuotaColumnInstanceID = "instance_id"
QuotaColumnUnit = "unit"
QuotaColumnAmount = "amount"
QuotaColumnFrom = "from_anchor"
QuotaColumnInterval = "interval"
QuotaColumnLimit = "limit_usage"
quotaPeriodsTableSuffix = "periods"
QuotaPeriodColumnInstanceID = "instance_id"
QuotaPeriodColumnUnit = "unit"
QuotaPeriodColumnStart = "start"
QuotaPeriodColumnUsage = "usage"
quotaNotificationsTableSuffix = "notifications"
QuotaNotificationColumnInstanceID = "instance_id"
QuotaNotificationColumnUnit = "unit"
QuotaNotificationColumnID = "id"
QuotaNotificationColumnCallURL = "call_url"
QuotaNotificationColumnPercent = "percent"
QuotaNotificationColumnRepeat = "repeat"
QuotaNotificationColumnLatestDuePeriodStart = "latest_due_period_start"
QuotaNotificationColumnNextDueThreshold = "next_due_threshold"
)
const (
incrementQuotaStatement = `INSERT INTO projections.quotas_periods` +
` (instance_id, unit, start, usage)` +
` VALUES ($1, $2, $3, $4) ON CONFLICT (instance_id, unit, start)` +
` DO UPDATE SET usage = projections.quotas_periods.usage + excluded.usage RETURNING usage`
)
type quotaProjection struct {
crdb.StatementHandler
client *database.DB
}
func newQuotaProjection(ctx context.Context, config crdb.StatementHandlerConfig) *quotaProjection {
p := new(quotaProjection)
config.ProjectionName = QuotasProjectionTable
config.Reducers = p.reducers()
config.InitCheck = crdb.NewMultiTableCheck(
crdb.NewTable(
[]*crdb.Column{
crdb.NewColumn(QuotaColumnID, crdb.ColumnTypeText),
crdb.NewColumn(QuotaColumnInstanceID, crdb.ColumnTypeText),
crdb.NewColumn(QuotaColumnUnit, crdb.ColumnTypeEnum),
crdb.NewColumn(QuotaColumnAmount, crdb.ColumnTypeInt64),
crdb.NewColumn(QuotaColumnFrom, crdb.ColumnTypeTimestamp),
crdb.NewColumn(QuotaColumnInterval, crdb.ColumnTypeInterval),
crdb.NewColumn(QuotaColumnLimit, crdb.ColumnTypeBool),
},
crdb.NewPrimaryKey(QuotaColumnInstanceID, QuotaColumnUnit),
),
crdb.NewSuffixedTable(
[]*crdb.Column{
crdb.NewColumn(QuotaPeriodColumnInstanceID, crdb.ColumnTypeText),
crdb.NewColumn(QuotaPeriodColumnUnit, crdb.ColumnTypeEnum),
crdb.NewColumn(QuotaPeriodColumnStart, crdb.ColumnTypeTimestamp),
crdb.NewColumn(QuotaPeriodColumnUsage, crdb.ColumnTypeInt64),
},
crdb.NewPrimaryKey(QuotaPeriodColumnInstanceID, QuotaPeriodColumnUnit, QuotaPeriodColumnStart),
quotaPeriodsTableSuffix,
),
crdb.NewSuffixedTable(
[]*crdb.Column{
crdb.NewColumn(QuotaNotificationColumnInstanceID, crdb.ColumnTypeText),
crdb.NewColumn(QuotaNotificationColumnUnit, crdb.ColumnTypeEnum),
crdb.NewColumn(QuotaNotificationColumnID, crdb.ColumnTypeText),
crdb.NewColumn(QuotaNotificationColumnCallURL, crdb.ColumnTypeText),
crdb.NewColumn(QuotaNotificationColumnPercent, crdb.ColumnTypeInt64),
crdb.NewColumn(QuotaNotificationColumnRepeat, crdb.ColumnTypeBool),
crdb.NewColumn(QuotaNotificationColumnLatestDuePeriodStart, crdb.ColumnTypeTimestamp, crdb.Nullable()),
crdb.NewColumn(QuotaNotificationColumnNextDueThreshold, crdb.ColumnTypeInt64, crdb.Nullable()),
},
crdb.NewPrimaryKey(QuotaNotificationColumnInstanceID, QuotaNotificationColumnUnit, QuotaNotificationColumnID),
quotaNotificationsTableSuffix,
),
)
p.StatementHandler = crdb.NewStatementHandler(ctx, config)
p.client = config.Client
return p
}
func (q *quotaProjection) reducers() []handler.AggregateReducer {
return []handler.AggregateReducer{
{
Aggregate: instance.AggregateType,
EventRedusers: []handler.EventReducer{
{
Event: instance.InstanceRemovedEventType,
Reduce: q.reduceInstanceRemoved,
},
},
},
{
Aggregate: quota.AggregateType,
EventRedusers: []handler.EventReducer{
{
Event: quota.AddedEventType,
Reduce: q.reduceQuotaAdded,
},
},
},
{
Aggregate: quota.AggregateType,
EventRedusers: []handler.EventReducer{
{
Event: quota.RemovedEventType,
Reduce: q.reduceQuotaRemoved,
},
},
},
{
Aggregate: quota.AggregateType,
EventRedusers: []handler.EventReducer{
{
Event: quota.NotificationDueEventType,
Reduce: q.reduceQuotaNotificationDue,
},
},
},
{
Aggregate: quota.AggregateType,
EventRedusers: []handler.EventReducer{
{
Event: quota.NotifiedEventType,
Reduce: q.reduceQuotaNotified,
},
},
},
}
}
func (q *quotaProjection) reduceQuotaNotified(event eventstore.Event) (*handler.Statement, error) {
return crdb.NewNoOpStatement(event), nil
}
func (q *quotaProjection) reduceQuotaAdded(event eventstore.Event) (*handler.Statement, error) {
e, err := assertEvent[*quota.AddedEvent](event)
if err != nil {
return nil, err
}
createStatements := make([]func(e eventstore.Event) crdb.Exec, len(e.Notifications)+1)
createStatements[0] = crdb.AddCreateStatement(
[]handler.Column{
handler.NewCol(QuotaColumnID, e.Aggregate().ID),
handler.NewCol(QuotaColumnInstanceID, e.Aggregate().InstanceID),
handler.NewCol(QuotaColumnUnit, e.Unit),
handler.NewCol(QuotaColumnAmount, e.Amount),
handler.NewCol(QuotaColumnFrom, e.From),
handler.NewCol(QuotaColumnInterval, e.ResetInterval),
handler.NewCol(QuotaColumnLimit, e.Limit),
})
for i := range e.Notifications {
notification := e.Notifications[i]
createStatements[i+1] = crdb.AddCreateStatement(
[]handler.Column{
handler.NewCol(QuotaNotificationColumnInstanceID, e.Aggregate().InstanceID),
handler.NewCol(QuotaNotificationColumnUnit, e.Unit),
handler.NewCol(QuotaNotificationColumnID, notification.ID),
handler.NewCol(QuotaNotificationColumnCallURL, notification.CallURL),
handler.NewCol(QuotaNotificationColumnPercent, notification.Percent),
handler.NewCol(QuotaNotificationColumnRepeat, notification.Repeat),
},
crdb.WithTableSuffix(quotaNotificationsTableSuffix),
)
}
return crdb.NewMultiStatement(e, createStatements...), nil
}
func (q *quotaProjection) reduceQuotaNotificationDue(event eventstore.Event) (*handler.Statement, error) {
e, err := assertEvent[*quota.NotificationDueEvent](event)
if err != nil {
return nil, err
}
return crdb.NewUpdateStatement(e,
[]handler.Column{
handler.NewCol(QuotaNotificationColumnLatestDuePeriodStart, e.PeriodStart),
handler.NewCol(QuotaNotificationColumnNextDueThreshold, e.Threshold+100), // next due_threshold is always the reached + 100 => percent (e.g. 90) in the next bucket (e.g. 190)
},
[]handler.Condition{
handler.NewCond(QuotaNotificationColumnInstanceID, e.Aggregate().InstanceID),
handler.NewCond(QuotaNotificationColumnUnit, e.Unit),
handler.NewCond(QuotaNotificationColumnID, e.ID),
},
crdb.WithTableSuffix(quotaNotificationsTableSuffix),
), nil
}
func (q *quotaProjection) reduceQuotaRemoved(event eventstore.Event) (*handler.Statement, error) {
e, err := assertEvent[*quota.RemovedEvent](event)
if err != nil {
return nil, err
}
return crdb.NewMultiStatement(
e,
crdb.AddDeleteStatement(
[]handler.Condition{
handler.NewCond(QuotaPeriodColumnInstanceID, e.Aggregate().InstanceID),
handler.NewCond(QuotaPeriodColumnUnit, e.Unit),
},
crdb.WithTableSuffix(quotaPeriodsTableSuffix),
),
crdb.AddDeleteStatement(
[]handler.Condition{
handler.NewCond(QuotaNotificationColumnInstanceID, e.Aggregate().InstanceID),
handler.NewCond(QuotaNotificationColumnUnit, e.Unit),
},
crdb.WithTableSuffix(quotaNotificationsTableSuffix),
),
crdb.AddDeleteStatement(
[]handler.Condition{
handler.NewCond(QuotaColumnInstanceID, e.Aggregate().InstanceID),
handler.NewCond(QuotaColumnUnit, e.Unit),
},
),
), nil
}
func (q *quotaProjection) reduceInstanceRemoved(event eventstore.Event) (*handler.Statement, error) {
// we only assert the event to make sure it is the correct type
e, err := assertEvent[*instance.InstanceRemovedEvent](event)
if err != nil {
return nil, err
}
return crdb.NewMultiStatement(
e,
crdb.AddDeleteStatement(
[]handler.Condition{
handler.NewCond(QuotaPeriodColumnInstanceID, e.Aggregate().InstanceID),
},
crdb.WithTableSuffix(quotaPeriodsTableSuffix),
),
crdb.AddDeleteStatement(
[]handler.Condition{
handler.NewCond(QuotaNotificationColumnInstanceID, e.Aggregate().InstanceID),
},
crdb.WithTableSuffix(quotaNotificationsTableSuffix),
),
crdb.AddDeleteStatement(
[]handler.Condition{
handler.NewCond(QuotaColumnInstanceID, e.Aggregate().InstanceID),
},
),
), nil
}
func (q *quotaProjection) IncrementUsage(ctx context.Context, unit quota.Unit, instanceID string, periodStart time.Time, count uint64) (sum uint64, err error) {
if count == 0 {
return 0, nil
}
err = q.client.DB.QueryRowContext(
ctx,
incrementQuotaStatement,
instanceID, unit, periodStart, count,
).Scan(&sum)
if err != nil {
return 0, errors.ThrowInternalf(err, "PROJ-SJL3h", "incrementing usage for unit %d failed for at least one quota period", unit)
}
return sum, err
}

View File

@ -0,0 +1,321 @@
package projection
import (
"context"
"regexp"
"testing"
"time"
"github.com/DATA-DOG/go-sqlmock"
"github.com/stretchr/testify/assert"
"github.com/zitadel/zitadel/internal/database"
"github.com/zitadel/zitadel/internal/errors"
"github.com/zitadel/zitadel/internal/eventstore"
"github.com/zitadel/zitadel/internal/eventstore/handler"
"github.com/zitadel/zitadel/internal/eventstore/repository"
"github.com/zitadel/zitadel/internal/repository/instance"
"github.com/zitadel/zitadel/internal/repository/quota"
)
func TestQuotasProjection_reduces(t *testing.T) {
type args struct {
event func(t *testing.T) eventstore.Event
}
tests := []struct {
name string
args args
reduce func(event eventstore.Event) (*handler.Statement, error)
want wantReduce
}{
{
name: "reduceQuotaAdded",
args: args{
event: getEvent(testEvent(
repository.EventType(quota.AddedEventType),
quota.AggregateType,
[]byte(`{
"unit": 1,
"amount": 10,
"limit": true,
"from": "2023-01-01T00:00:00Z",
"interval": 300000000000
}`),
), quota.AddedEventMapper),
},
reduce: (&quotaProjection{}).reduceQuotaAdded,
want: wantReduce{
aggregateType: eventstore.AggregateType("quota"),
sequence: 15,
previousSequence: 10,
executer: &testExecuter{
executions: []execution{
{
expectedStmt: "INSERT INTO projections.quotas (id, instance_id, unit, amount, from_anchor, interval, limit_usage) VALUES ($1, $2, $3, $4, $5, $6, $7)",
expectedArgs: []interface{}{
"agg-id",
"instance-id",
quota.RequestsAllAuthenticated,
uint64(10),
time.Date(2023, 1, 1, 0, 0, 0, 0, time.UTC),
time.Minute * 5,
true,
},
},
},
},
},
},
{
name: "reduceQuotaAdded with notification",
args: args{
event: getEvent(testEvent(
repository.EventType(quota.AddedEventType),
quota.AggregateType,
[]byte(`{
"unit": 1,
"amount": 10,
"limit": true,
"from": "2023-01-01T00:00:00Z",
"interval": 300000000000,
"notifications": [
{
"id": "id",
"percent": 100,
"repeat": true,
"callURL": "url"
}
]
}`),
), quota.AddedEventMapper),
},
reduce: (&quotaProjection{}).reduceQuotaAdded,
want: wantReduce{
aggregateType: eventstore.AggregateType("quota"),
sequence: 15,
previousSequence: 10,
executer: &testExecuter{
executions: []execution{
{
expectedStmt: "INSERT INTO projections.quotas (id, instance_id, unit, amount, from_anchor, interval, limit_usage) VALUES ($1, $2, $3, $4, $5, $6, $7)",
expectedArgs: []interface{}{
"agg-id",
"instance-id",
quota.RequestsAllAuthenticated,
uint64(10),
time.Date(2023, 1, 1, 0, 0, 0, 0, time.UTC),
time.Minute * 5,
true,
},
},
{
expectedStmt: "INSERT INTO projections.quotas_notifications (instance_id, unit, id, call_url, percent, repeat) VALUES ($1, $2, $3, $4, $5, $6)",
expectedArgs: []interface{}{
"instance-id",
quota.RequestsAllAuthenticated,
"id",
"url",
uint16(100),
true,
},
},
},
},
},
},
{
name: "reduceQuotaNotificationDue",
args: args{
event: getEvent(testEvent(
repository.EventType(quota.NotificationDueEventType),
quota.AggregateType,
[]byte(`{
"id": "id",
"unit": 1,
"callURL": "url",
"periodStart": "2023-01-01T00:00:00Z",
"threshold": 200,
"usage": 100
}`),
), quota.NotificationDueEventMapper),
},
reduce: (&quotaProjection{}).reduceQuotaNotificationDue,
want: wantReduce{
aggregateType: eventstore.AggregateType("quota"),
sequence: 15,
previousSequence: 10,
executer: &testExecuter{
executions: []execution{
{
expectedStmt: "UPDATE projections.quotas_notifications SET (latest_due_period_start, next_due_threshold) = ($1, $2) WHERE (instance_id = $3) AND (unit = $4) AND (id = $5)",
expectedArgs: []interface{}{
time.Date(2023, 1, 1, 0, 0, 0, 0, time.UTC),
uint16(300),
"instance-id",
quota.RequestsAllAuthenticated,
"id",
},
},
},
},
},
},
{
name: "reduceQuotaRemoved",
args: args{
event: getEvent(testEvent(
repository.EventType(quota.RemovedEventType),
quota.AggregateType,
[]byte(`{
"unit": 1
}`),
), quota.RemovedEventMapper),
},
reduce: (&quotaProjection{}).reduceQuotaRemoved,
want: wantReduce{
aggregateType: eventstore.AggregateType("quota"),
sequence: 15,
previousSequence: 10,
executer: &testExecuter{
executions: []execution{
{
expectedStmt: "DELETE FROM projections.quotas_periods WHERE (instance_id = $1) AND (unit = $2)",
expectedArgs: []interface{}{
"instance-id",
quota.RequestsAllAuthenticated,
},
},
{
expectedStmt: "DELETE FROM projections.quotas_notifications WHERE (instance_id = $1) AND (unit = $2)",
expectedArgs: []interface{}{
"instance-id",
quota.RequestsAllAuthenticated,
},
},
{
expectedStmt: "DELETE FROM projections.quotas WHERE (instance_id = $1) AND (unit = $2)",
expectedArgs: []interface{}{
"instance-id",
quota.RequestsAllAuthenticated,
},
},
},
},
},
}, {
name: "reduceInstanceRemoved",
args: args{
event: getEvent(testEvent(
repository.EventType(instance.InstanceRemovedEventType),
instance.AggregateType,
[]byte(`{
"name": "name"
}`),
), instance.InstanceRemovedEventMapper),
},
reduce: (&quotaProjection{}).reduceInstanceRemoved,
want: wantReduce{
aggregateType: eventstore.AggregateType("instance"),
sequence: 15,
previousSequence: 10,
executer: &testExecuter{
executions: []execution{
{
expectedStmt: "DELETE FROM projections.quotas_periods WHERE (instance_id = $1)",
expectedArgs: []interface{}{
"instance-id",
},
},
{
expectedStmt: "DELETE FROM projections.quotas_notifications WHERE (instance_id = $1)",
expectedArgs: []interface{}{
"instance-id",
},
},
{
expectedStmt: "DELETE FROM projections.quotas WHERE (instance_id = $1)",
expectedArgs: []interface{}{
"instance-id",
},
},
},
},
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
event := baseEvent(t)
got, err := tt.reduce(event)
if !errors.IsErrorInvalidArgument(err) {
t.Errorf("no wrong event mapping: %v, got: %v", err, got)
}
event = tt.args.event(t)
got, err = tt.reduce(event)
assertReduce(t, got, err, QuotasProjectionTable, tt.want)
})
}
}
func Test_quotaProjection_IncrementUsage(t *testing.T) {
testNow := time.Now()
type fields struct {
client *database.DB
}
type args struct {
ctx context.Context
unit quota.Unit
instanceID string
periodStart time.Time
count uint64
}
type res struct {
sum uint64
err error
}
tests := []struct {
name string
fields fields
args args
res res
}{
{
name: "",
fields: fields{
client: func() *database.DB {
db, mock, _ := sqlmock.New()
mock.ExpectQuery(regexp.QuoteMeta(incrementQuotaStatement)).
WithArgs(
"instance_id",
1,
testNow,
2,
).
WillReturnRows(sqlmock.NewRows([]string{"key"}).
AddRow(3))
return &database.DB{DB: db}
}(),
},
args: args{
ctx: context.Background(),
unit: quota.RequestsAllAuthenticated,
instanceID: "instance_id",
periodStart: testNow,
count: 2,
},
res: res{
sum: 3,
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
q := &quotaProjection{
client: tt.fields.client,
}
gotSum, err := q.IncrementUsage(tt.args.ctx, tt.args.unit, tt.args.instanceID, tt.args.periodStart, tt.args.count)
assert.Equal(t, tt.res.sum, gotSum)
assert.ErrorIs(t, err, tt.res.err)
})
}
}

View File

@ -26,6 +26,7 @@ import (
"github.com/zitadel/zitadel/internal/repository/oidcsession"
"github.com/zitadel/zitadel/internal/repository/org"
"github.com/zitadel/zitadel/internal/repository/project"
"github.com/zitadel/zitadel/internal/repository/quota"
"github.com/zitadel/zitadel/internal/repository/session"
usr_repo "github.com/zitadel/zitadel/internal/repository/user"
"github.com/zitadel/zitadel/internal/repository/usergrant"
@ -93,6 +94,7 @@ func StartQueries(
idpintent.RegisterEventMappers(repo.eventstore)
authrequest.RegisterEventMappers(repo.eventstore)
oidcsession.RegisterEventMappers(repo.eventstore)
quota.RegisterEventMappers(repo.eventstore)
repo.idpConfigEncryption = idpConfigEncryption
repo.multifactors = domain.MultifactorConfigs{

121
internal/query/quota.go Normal file
View File

@ -0,0 +1,121 @@
package query
import (
"context"
"database/sql"
errs "errors"
"time"
sq "github.com/Masterminds/squirrel"
"github.com/zitadel/zitadel/internal/database"
"github.com/zitadel/zitadel/internal/errors"
"github.com/zitadel/zitadel/internal/query/projection"
"github.com/zitadel/zitadel/internal/repository/quota"
"github.com/zitadel/zitadel/internal/telemetry/tracing"
)
var (
quotasTable = table{
name: projection.QuotasProjectionTable,
instanceIDCol: projection.QuotaColumnInstanceID,
}
QuotaColumnID = Column{
name: projection.QuotaColumnID,
table: quotasTable,
}
QuotaColumnInstanceID = Column{
name: projection.QuotaColumnInstanceID,
table: quotasTable,
}
QuotaColumnUnit = Column{
name: projection.QuotaColumnUnit,
table: quotasTable,
}
QuotaColumnAmount = Column{
name: projection.QuotaColumnAmount,
table: quotasTable,
}
QuotaColumnLimit = Column{
name: projection.QuotaColumnLimit,
table: quotasTable,
}
QuotaColumnInterval = Column{
name: projection.QuotaColumnInterval,
table: quotasTable,
}
QuotaColumnFrom = Column{
name: projection.QuotaColumnFrom,
table: quotasTable,
}
)
type Quota struct {
ID string
From time.Time
ResetInterval time.Duration
Amount uint64
Limit bool
CurrentPeriodStart time.Time
}
func (q *Queries) GetQuota(ctx context.Context, instanceID string, unit quota.Unit) (qu *Quota, err error) {
ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }()
query, scan := prepareQuotaQuery(ctx, q.client)
stmt, args, err := query.Where(
sq.Eq{
QuotaColumnInstanceID.identifier(): instanceID,
QuotaColumnUnit.identifier(): unit,
},
).ToSql()
if err != nil {
return nil, errors.ThrowInternal(err, "QUERY-XmYn9", "Errors.Query.SQLStatement")
}
err = q.client.QueryRowContext(ctx, func(row *sql.Row) error {
qu, err = scan(row)
return err
}, stmt, args...)
return qu, err
}
func prepareQuotaQuery(ctx context.Context, db prepareDatabase) (sq.SelectBuilder, func(*sql.Row) (*Quota, error)) {
return sq.
Select(
QuotaColumnID.identifier(),
QuotaColumnFrom.identifier(),
QuotaColumnInterval.identifier(),
QuotaColumnAmount.identifier(),
QuotaColumnLimit.identifier(),
"now()",
).
From(quotasTable.identifier()).
PlaceholderFormat(sq.Dollar), func(row *sql.Row) (*Quota, error) {
q := new(Quota)
var interval database.Duration
var now time.Time
err := row.Scan(&q.ID, &q.From, &interval, &q.Amount, &q.Limit, &now)
if err != nil {
if errs.Is(err, sql.ErrNoRows) {
return nil, errors.ThrowNotFound(err, "QUERY-rDTM6", "Errors.Quota.NotExisting")
}
return nil, errors.ThrowInternal(err, "QUERY-LqySK", "Errors.Internal")
}
q.ResetInterval = time.Duration(interval)
q.CurrentPeriodStart = pushPeriodStart(q.From, q.ResetInterval, now)
return q, nil
}
}
func pushPeriodStart(from time.Time, interval time.Duration, now time.Time) time.Time {
if now.IsZero() {
now = time.Now()
}
for {
next := from.Add(interval)
if next.After(now) {
return from
}
from = next
}
}

View File

@ -1,55 +0,0 @@
package query
import (
"github.com/zitadel/zitadel/internal/eventstore"
"github.com/zitadel/zitadel/internal/repository/quota"
)
type quotaReadModel struct {
eventstore.ReadModel
unit quota.Unit
active bool
config *quota.AddedEvent
}
// newQuotaReadModel aggregateId is filled by reducing unit matching events
func newQuotaReadModel(instanceId, resourceOwner string, unit quota.Unit) *quotaReadModel {
return &quotaReadModel{
ReadModel: eventstore.ReadModel{
InstanceID: instanceId,
ResourceOwner: resourceOwner,
},
unit: unit,
}
}
func (rm *quotaReadModel) Query() *eventstore.SearchQueryBuilder {
query := eventstore.NewSearchQueryBuilder(eventstore.ColumnsEvent).
ResourceOwner(rm.ResourceOwner).
AllowTimeTravel().
AddQuery().
InstanceID(rm.InstanceID).
AggregateTypes(quota.AggregateType).
EventTypes(
quota.AddedEventType,
quota.RemovedEventType,
).EventData(map[string]interface{}{"unit": rm.unit})
return query.Builder()
}
func (rm *quotaReadModel) Reduce() error {
for _, event := range rm.Events {
switch e := event.(type) {
case *quota.AddedEvent:
rm.AggregateID = e.Aggregate().ID
rm.active = true
rm.config = e
case *quota.RemovedEvent:
rm.AggregateID = e.Aggregate().ID
rm.active = false
rm.config = nil
}
}
return rm.ReadModel.Reduce()
}

View File

@ -2,58 +2,180 @@ package query
import (
"context"
"database/sql"
errs "errors"
"math"
"time"
"github.com/zitadel/zitadel/internal/eventstore"
sq "github.com/Masterminds/squirrel"
"github.com/pkg/errors"
"github.com/zitadel/zitadel/internal/api/call"
zitadel_errors "github.com/zitadel/zitadel/internal/errors"
"github.com/zitadel/zitadel/internal/query/projection"
"github.com/zitadel/zitadel/internal/repository/quota"
"github.com/zitadel/zitadel/internal/telemetry/tracing"
)
func (q *Queries) GetDueQuotaNotifications(ctx context.Context, config *quota.AddedEvent, periodStart time.Time, usedAbs uint64) ([]*quota.NotificationDueEvent, error) {
if len(config.Notifications) == 0 {
var (
quotaNotificationsTable = table{
name: projection.QuotaNotificationsTable,
instanceIDCol: projection.QuotaNotificationColumnInstanceID,
}
QuotaNotificationColumnInstanceID = Column{
name: projection.QuotaNotificationColumnInstanceID,
table: quotaNotificationsTable,
}
QuotaNotificationColumnUnit = Column{
name: projection.QuotaNotificationColumnUnit,
table: quotaNotificationsTable,
}
QuotaNotificationColumnID = Column{
name: projection.QuotaNotificationColumnID,
table: quotaNotificationsTable,
}
QuotaNotificationColumnCallURL = Column{
name: projection.QuotaNotificationColumnCallURL,
table: quotaNotificationsTable,
}
QuotaNotificationColumnPercent = Column{
name: projection.QuotaNotificationColumnPercent,
table: quotaNotificationsTable,
}
QuotaNotificationColumnRepeat = Column{
name: projection.QuotaNotificationColumnRepeat,
table: quotaNotificationsTable,
}
QuotaNotificationColumnLatestDuePeriodStart = Column{
name: projection.QuotaNotificationColumnLatestDuePeriodStart,
table: quotaNotificationsTable,
}
QuotaNotificationColumnNextDueThreshold = Column{
name: projection.QuotaNotificationColumnNextDueThreshold,
table: quotaNotificationsTable,
}
)
func (q *Queries) GetDueQuotaNotifications(ctx context.Context, instanceID string, unit quota.Unit, qu *Quota, periodStart time.Time, usedAbs uint64) (dueNotifications []*quota.NotificationDueEvent, err error) {
ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }()
usedRel := uint16(math.Floor(float64(usedAbs*100) / float64(qu.Amount)))
query, scan := prepareQuotaNotificationsQuery(ctx, q.client)
stmt, args, err := query.Where(
sq.And{
sq.Eq{
QuotaNotificationColumnInstanceID.identifier(): instanceID,
QuotaNotificationColumnUnit.identifier(): unit,
},
sq.Or{
// If the relative usage is greater than the next due threshold in the current period, it's clear we can notify
sq.And{
sq.Eq{QuotaNotificationColumnLatestDuePeriodStart.identifier(): periodStart},
sq.LtOrEq{QuotaNotificationColumnNextDueThreshold.identifier(): usedRel},
},
// In case we haven't seen a due notification for this quota period, we compare against the configured percent
sq.And{
sq.Or{
sq.Expr(QuotaNotificationColumnLatestDuePeriodStart.identifier() + " IS NULL"),
sq.NotEq{QuotaNotificationColumnLatestDuePeriodStart.identifier(): periodStart},
},
sq.LtOrEq{QuotaNotificationColumnPercent.identifier(): usedRel},
},
},
},
).ToSql()
if err != nil {
return nil, zitadel_errors.ThrowInternal(err, "QUERY-XmYn9", "Errors.Query.SQLStatement")
}
var notifications *QuotaNotifications
err = q.client.QueryContext(ctx, func(rows *sql.Rows) error {
notifications, err = scan(rows)
return err
}, stmt, args...)
if errors.Is(err, sql.ErrNoRows) {
return nil, nil
}
aggregate := config.Aggregate()
wm, err := q.getQuotaNotificationsReadModel(ctx, aggregate, periodStart)
if err != nil {
return nil, err
}
usedRel := uint16(math.Floor(float64(usedAbs*100) / float64(config.Amount)))
var dueNotifications []*quota.NotificationDueEvent
for _, notification := range config.Notifications {
if notification.Percent > usedRel {
for _, notification := range notifications.Configs {
reachedThreshold := calculateThreshold(usedRel, notification.Percent)
if !notification.Repeat && notification.Percent < reachedThreshold {
continue
}
threshold := notification.Percent
if notification.Repeat {
threshold = uint16(math.Max(1, math.Floor(float64(usedRel)/float64(notification.Percent)))) * notification.Percent
}
if wm.latestDueThresholds[notification.ID] < threshold {
dueNotifications = append(
dueNotifications,
quota.NewNotificationDueEvent(
ctx,
&aggregate,
config.Unit,
&quota.NewAggregate(qu.ID, instanceID).Aggregate,
unit,
notification.ID,
notification.CallURL,
periodStart,
threshold,
reachedThreshold,
usedAbs,
),
)
}
}
return dueNotifications, nil
}
func (q *Queries) getQuotaNotificationsReadModel(ctx context.Context, aggregate eventstore.Aggregate, periodStart time.Time) (*quotaNotificationsReadModel, error) {
wm := newQuotaNotificationsReadModel(aggregate.ID, aggregate.InstanceID, aggregate.ResourceOwner, periodStart)
return wm, q.eventstore.FilterToQueryReducer(ctx, wm)
type QuotaNotification struct {
ID string
CallURL string
Percent uint16
Repeat bool
NextDueThreshold uint16
}
type QuotaNotifications struct {
SearchResponse
Configs []*QuotaNotification
}
// calculateThreshold calculates the nearest reached threshold.
// It makes sure that the percent configured on the notification is calculated within the "current" 100%,
// e.g. when configuring 80%, the thresholds are 80, 180, 280, ...
// so 170% use is always 70% of the current bucket, with the above config, the reached threshold would be 80.
func calculateThreshold(usedRel, notificationPercent uint16) uint16 {
// check how many times we reached 100%
times := math.Floor(float64(usedRel) / 100)
// check how many times we reached the percent configured with the "current" 100%
percent := math.Floor(float64(usedRel%100) / float64(notificationPercent))
// If neither is reached, directly return 0.
// This way we don't end up in some wrong uint16 range in the calculation below.
if times == 0 && percent == 0 {
return 0
}
return uint16(times+percent-1)*100 + notificationPercent
}
func prepareQuotaNotificationsQuery(ctx context.Context, db prepareDatabase) (sq.SelectBuilder, func(*sql.Rows) (*QuotaNotifications, error)) {
return sq.Select(
QuotaNotificationColumnID.identifier(),
QuotaNotificationColumnCallURL.identifier(),
QuotaNotificationColumnPercent.identifier(),
QuotaNotificationColumnRepeat.identifier(),
QuotaNotificationColumnNextDueThreshold.identifier(),
).
From(quotaNotificationsTable.identifier() + db.Timetravel(call.Took(ctx))).
PlaceholderFormat(sq.Dollar), func(rows *sql.Rows) (*QuotaNotifications, error) {
cfgs := &QuotaNotifications{Configs: []*QuotaNotification{}}
for rows.Next() {
cfg := new(QuotaNotification)
var nextDueThreshold sql.NullInt16
err := rows.Scan(&cfg.ID, &cfg.CallURL, &cfg.Percent, &cfg.Repeat, &nextDueThreshold)
if err != nil {
if errs.Is(err, sql.ErrNoRows) {
return nil, zitadel_errors.ThrowNotFound(err, "QUERY-bbqWb", "Errors.QuotaNotification.NotExisting")
}
return nil, zitadel_errors.ThrowInternal(err, "QUERY-8copS", "Errors.Internal")
}
if nextDueThreshold.Valid {
cfg.NextDueThreshold = uint16(nextDueThreshold.Int16)
}
cfgs.Configs = append(cfgs.Configs, cfg)
}
return cfgs, nil
}
}

View File

@ -1,46 +0,0 @@
package query
import (
"time"
"github.com/zitadel/zitadel/internal/eventstore"
"github.com/zitadel/zitadel/internal/repository/quota"
)
type quotaNotificationsReadModel struct {
eventstore.ReadModel
periodStart time.Time
latestDueThresholds map[string]uint16
}
func newQuotaNotificationsReadModel(aggregateId, instanceId, resourceOwner string, periodStart time.Time) *quotaNotificationsReadModel {
return &quotaNotificationsReadModel{
ReadModel: eventstore.ReadModel{
AggregateID: aggregateId,
InstanceID: instanceId,
ResourceOwner: resourceOwner,
},
periodStart: periodStart,
latestDueThresholds: make(map[string]uint16),
}
}
func (rm *quotaNotificationsReadModel) Query() *eventstore.SearchQueryBuilder {
return eventstore.NewSearchQueryBuilder(eventstore.ColumnsEvent).
ResourceOwner(rm.ResourceOwner).
AllowTimeTravel().
AddQuery().
InstanceID(rm.InstanceID).
AggregateTypes(quota.AggregateType).
AggregateIDs(rm.AggregateID).
CreationDateAfter(rm.periodStart).
EventTypes(quota.NotificationDueEventType).Builder()
}
func (rm *quotaNotificationsReadModel) Reduce() error {
for _, event := range rm.Events {
e := event.(*quota.NotificationDueEvent)
rm.latestDueThresholds[e.ID] = e.Threshold
}
return rm.ReadModel.Reduce()
}

View File

@ -0,0 +1,181 @@
package query
import (
"database/sql"
"database/sql/driver"
"errors"
"fmt"
"regexp"
"testing"
"github.com/stretchr/testify/assert"
)
func Test_calculateThreshold(t *testing.T) {
type args struct {
usedRel uint16
notificationPercent uint16
}
tests := []struct {
name string
args args
want uint16
}{
{
name: "80 - below configuration",
args: args{
usedRel: 70,
notificationPercent: 80,
},
want: 0,
},
{
name: "80 - below 100 percent use",
args: args{
usedRel: 90,
notificationPercent: 80,
},
want: 80,
},
{
name: "80 - above 100 percent use",
args: args{
usedRel: 120,
notificationPercent: 80,
},
want: 80,
},
{
name: "80 - more than twice the use",
args: args{
usedRel: 190,
notificationPercent: 80,
},
want: 180,
},
{
name: "100 - below 100 percent use",
args: args{
usedRel: 90,
notificationPercent: 100,
},
want: 0,
},
{
name: "100 - above 100 percent use",
args: args{
usedRel: 120,
notificationPercent: 100,
},
want: 100,
},
{
name: "100 - more than twice the use",
args: args{
usedRel: 210,
notificationPercent: 100,
},
want: 200,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := calculateThreshold(tt.args.usedRel, tt.args.notificationPercent)
assert.Equal(t, int(tt.want), int(got))
})
}
}
var (
expectedQuotaNotificationsQuery = regexp.QuoteMeta(`SELECT projections.quotas_notifications.id,` +
` projections.quotas_notifications.call_url,` +
` projections.quotas_notifications.percent,` +
` projections.quotas_notifications.repeat,` +
` projections.quotas_notifications.next_due_threshold` +
` FROM projections.quotas_notifications` +
` AS OF SYSTEM TIME '-1 ms'`)
quotaNotificationsCols = []string{
"id",
"call_url",
"percent",
"repeat",
"next_due_threshold",
}
)
func Test_prepareQuotaNotificationsQuery(t *testing.T) {
type want struct {
sqlExpectations sqlExpectation
err checkErr
}
tests := []struct {
name string
prepare interface{}
want want
object interface{}
}{
{
name: "prepareQuotaNotificationsQuery no result",
prepare: prepareQuotaNotificationsQuery,
want: want{
sqlExpectations: mockQueries(
expectedQuotaNotificationsQuery,
nil,
nil,
),
},
object: &QuotaNotifications{Configs: []*QuotaNotification{}},
},
{
name: "prepareQuotaNotificationsQuery",
prepare: prepareQuotaNotificationsQuery,
want: want{
sqlExpectations: mockQuery(
expectedQuotaNotificationsQuery,
quotaNotificationsCols,
[]driver.Value{
"quota-id",
"url",
uint16(100),
true,
uint16(100),
},
),
},
object: &QuotaNotifications{
Configs: []*QuotaNotification{
{
ID: "quota-id",
CallURL: "url",
Percent: 100,
Repeat: true,
NextDueThreshold: 100,
},
},
},
},
{
name: "prepareQuotaNotificationsQuery sql err",
prepare: prepareQuotaNotificationsQuery,
want: want{
sqlExpectations: mockQueryErr(
expectedQuotaNotificationsQuery,
sql.ErrConnDone,
),
err: func(err error) (error, bool) {
if !errors.Is(err, sql.ErrConnDone) {
return fmt.Errorf("err should be sql.ErrConnDone got: %w", err), false
}
return nil, true
},
},
object: (*Quota)(nil),
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
assertPrepare(t, tt.prepare, tt.object, tt.want.sqlExpectations, tt.want.err, defaultPrepareArgs...)
})
}
}

View File

@ -1,30 +0,0 @@
package query
import (
"context"
"time"
"github.com/zitadel/zitadel/internal/repository/quota"
)
func (q *Queries) GetCurrentQuotaPeriod(ctx context.Context, instanceID string, unit quota.Unit) (*quota.AddedEvent, time.Time, error) {
rm, err := q.getQuotaReadModel(ctx, instanceID, instanceID, unit)
if err != nil || !rm.active {
return nil, time.Time{}, err
}
return rm.config, pushPeriodStart(rm.config.From, rm.config.ResetInterval, time.Now()), nil
}
func pushPeriodStart(from time.Time, interval time.Duration, now time.Time) time.Time {
next := from.Add(interval)
if next.After(now) {
return from
}
return pushPeriodStart(next, interval, now)
}
func (q *Queries) getQuotaReadModel(ctx context.Context, instanceId, resourceOwner string, unit quota.Unit) (*quotaReadModel, error) {
rm := newQuotaReadModel(instanceId, resourceOwner, unit)
return rm, q.eventstore.FilterToQueryReducer(ctx, rm)
}

View File

@ -0,0 +1,86 @@
package query
import (
"context"
"database/sql"
"errors"
sq "github.com/Masterminds/squirrel"
"github.com/zitadel/zitadel/internal/api/call"
zitadel_errors "github.com/zitadel/zitadel/internal/errors"
"github.com/zitadel/zitadel/internal/query/projection"
"github.com/zitadel/zitadel/internal/repository/quota"
"github.com/zitadel/zitadel/internal/telemetry/tracing"
)
var (
quotaPeriodsTable = table{
name: projection.QuotaPeriodsProjectionTable,
instanceIDCol: projection.QuotaColumnInstanceID,
}
QuotaPeriodColumnInstanceID = Column{
name: projection.QuotaPeriodColumnInstanceID,
table: quotaPeriodsTable,
}
QuotaPeriodColumnUnit = Column{
name: projection.QuotaPeriodColumnUnit,
table: quotaPeriodsTable,
}
QuotaPeriodColumnStart = Column{
name: projection.QuotaPeriodColumnStart,
table: quotaPeriodsTable,
}
QuotaPeriodColumnUsage = Column{
name: projection.QuotaPeriodColumnUsage,
table: quotaPeriodsTable,
}
)
func (q *Queries) GetRemainingQuotaUsage(ctx context.Context, instanceID string, unit quota.Unit) (remaining *uint64, err error) {
ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }()
stmt, scan := prepareRemainingQuotaUsageQuery(ctx, q.client)
query, args, err := stmt.Where(
sq.And{
sq.Eq{
QuotaPeriodColumnInstanceID.identifier(): instanceID,
QuotaPeriodColumnUnit.identifier(): unit,
QuotaColumnLimit.identifier(): true,
},
sq.Expr("age(" + QuotaPeriodColumnStart.identifier() + ") < " + QuotaColumnInterval.identifier()),
sq.Expr(QuotaPeriodColumnStart.identifier() + " < now()"),
}).
ToSql()
if err != nil {
return nil, zitadel_errors.ThrowInternal(err, "QUERY-FSA3g", "Errors.Query.SQLStatement")
}
err = q.client.QueryRowContext(ctx, func(row *sql.Row) error {
remaining, err = scan(row)
return err
}, query, args...)
if zitadel_errors.IsNotFound(err) {
return nil, nil
}
return remaining, err
}
func prepareRemainingQuotaUsageQuery(ctx context.Context, db prepareDatabase) (sq.SelectBuilder, func(*sql.Row) (*uint64, error)) {
return sq.
Select(
"greatest(0, " + QuotaColumnAmount.identifier() + "-" + QuotaPeriodColumnUsage.identifier() + ")",
).
From(quotaPeriodsTable.identifier()).
Join(join(QuotaColumnUnit, QuotaPeriodColumnUnit) + db.Timetravel(call.Took(ctx))).
PlaceholderFormat(sq.Dollar), func(row *sql.Row) (*uint64, error) {
usage := new(uint64)
err := row.Scan(usage)
if err != nil {
if errors.Is(err, sql.ErrNoRows) {
return nil, zitadel_errors.ThrowNotFound(err, "QUERY-quiowi2", "Errors.Internal")
}
return nil, zitadel_errors.ThrowInternal(err, "QUERY-81j1jn2", "Errors.Internal")
}
return usage, nil
}
}

View File

@ -0,0 +1,95 @@
package query
import (
"database/sql"
"database/sql/driver"
"errors"
"fmt"
"regexp"
"testing"
errs "github.com/zitadel/zitadel/internal/errors"
)
var (
expectedRemainingQuotaUsageQuery = regexp.QuoteMeta(`SELECT greatest(0, projections.quotas.amount-projections.quotas_periods.usage)` +
` FROM projections.quotas_periods` +
` JOIN projections.quotas ON projections.quotas_periods.unit = projections.quotas.unit AND projections.quotas_periods.instance_id = projections.quotas.instance_id` +
` AS OF SYSTEM TIME '-1 ms'`)
remainingQuotaUsageCols = []string{
"usage",
}
)
func Test_prepareRemainingQuotaUsageQuery(t *testing.T) {
type want struct {
sqlExpectations sqlExpectation
err checkErr
}
tests := []struct {
name string
prepare interface{}
want want
object interface{}
}{
{
name: "prepareRemainingQuotaUsageQuery no result",
prepare: prepareRemainingQuotaUsageQuery,
want: want{
sqlExpectations: mockQueryScanErr(
expectedRemainingQuotaUsageQuery,
nil,
nil,
),
err: func(err error) (error, bool) {
if !errs.IsNotFound(err) {
return fmt.Errorf("err should be zitadel.NotFoundError got: %w", err), false
}
return nil, true
},
},
object: (*uint64)(nil),
},
{
name: "prepareRemainingQuotaUsageQuery",
prepare: prepareRemainingQuotaUsageQuery,
want: want{
sqlExpectations: mockQuery(
expectedRemainingQuotaUsageQuery,
remainingQuotaUsageCols,
[]driver.Value{
uint64(100),
},
),
},
object: uint64P(100),
},
{
name: "prepareRemainingQuotaUsageQuery sql err",
prepare: prepareRemainingQuotaUsageQuery,
want: want{
sqlExpectations: mockQueryErr(
expectedRemainingQuotaUsageQuery,
sql.ErrConnDone,
),
err: func(err error) (error, bool) {
if !errors.Is(err, sql.ErrConnDone) {
return fmt.Errorf("err should be sql.ErrConnDone got: %w", err), false
}
return nil, true
},
},
object: (*uint64)(nil),
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
assertPrepare(t, tt.prepare, tt.object, tt.want.sqlExpectations, tt.want.err, defaultPrepareArgs...)
})
}
}
func uint64P(i int) *uint64 {
u := uint64(i)
return &u
}

View File

@ -0,0 +1,127 @@
package query
import (
"database/sql"
"database/sql/driver"
"errors"
"fmt"
"regexp"
"testing"
"time"
"github.com/jackc/pgtype"
errs "github.com/zitadel/zitadel/internal/errors"
)
var (
expectedQuotaQuery = regexp.QuoteMeta(`SELECT projections.quotas.id,` +
` projections.quotas.from_anchor,` +
` projections.quotas.interval,` +
` projections.quotas.amount,` +
` projections.quotas.limit_usage,` +
` now()` +
` FROM projections.quotas`)
quotaCols = []string{
"id",
"from_anchor",
"interval",
"amount",
"limit_usage",
"now",
}
)
func dayNow() time.Time {
return time.Now().Truncate(24 * time.Hour)
}
func interval(t *testing.T, src time.Duration) pgtype.Interval {
interval := pgtype.Interval{}
err := interval.Set(src)
if err != nil {
t.Fatal(err)
}
return interval
}
func Test_QuotaPrepare(t *testing.T) {
type want struct {
sqlExpectations sqlExpectation
err checkErr
}
tests := []struct {
name string
prepare interface{}
want want
object interface{}
}{
{
name: "prepareQuotaQuery no result",
prepare: prepareQuotaQuery,
want: want{
sqlExpectations: mockQueriesScanErr(
expectedQuotaQuery,
nil,
nil,
),
err: func(err error) (error, bool) {
if !errs.IsNotFound(err) {
return fmt.Errorf("err should be zitadel.NotFoundError got: %w", err), false
}
return nil, true
},
},
object: (*Quota)(nil),
},
{
name: "prepareQuotaQuery",
prepare: prepareQuotaQuery,
want: want{
sqlExpectations: mockQuery(
expectedQuotaQuery,
quotaCols,
[]driver.Value{
"quota-id",
dayNow(),
interval(t, time.Hour*24),
uint64(1000),
true,
testNow,
},
),
},
object: &Quota{
ID: "quota-id",
From: dayNow(),
ResetInterval: time.Hour * 24,
CurrentPeriodStart: dayNow(),
Amount: 1000,
Limit: true,
},
},
{
name: "prepareQuotaQuery sql err",
prepare: prepareQuotaQuery,
want: want{
sqlExpectations: mockQueryErr(
expectedQuotaQuery,
sql.ErrConnDone,
),
err: func(err error) (error, bool) {
if !errors.Is(err, sql.ErrConnDone) {
return fmt.Errorf("err should be sql.ErrConnDone got: %w", err), false
}
return nil, true
},
},
object: (*Quota)(nil),
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
assertPrepare(t, tt.prepare, tt.object, tt.want.sqlExpectations, tt.want.err, defaultPrepareArgs...)
})
}
}

View File

@ -13,14 +13,14 @@ type Aggregate struct {
eventstore.Aggregate
}
func NewAggregate(id, instanceId, resourceOwner string) *Aggregate {
func NewAggregate(id, instanceId string) *Aggregate {
return &Aggregate{
Aggregate: eventstore.Aggregate{
Type: AggregateType,
Version: AggregateVersion,
ID: id,
InstanceID: instanceId,
ResourceOwner: resourceOwner,
ResourceOwner: instanceId,
},
}
}

View File

@ -15,7 +15,6 @@ type Unit uint
const (
UniqueQuotaNameType = "quota_units"
UniqueQuotaNotificationIDType = "quota_notification"
eventTypePrefix = eventstore.EventType("quota.")
AddedEventType = eventTypePrefix + "added"
NotifiedEventType = eventTypePrefix + "notified"