diff --git a/cmd/defaults.yaml b/cmd/defaults.yaml index debcd90518..c95170e2bd 100644 --- a/cmd/defaults.yaml +++ b/cmd/defaults.yaml @@ -452,54 +452,29 @@ Actions: LogStore: Access: - Database: - # If enabled, all access logs are stored in the database table logstore.access - Enabled: false # ZITADEL_LOGSTORE_ACCESS_DATABASE_ENABLED - # Logs that are older than the keep duration are cleaned up continuously - # 2160h are 90 days, 3 months - Keep: 2160h # ZITADEL_LOGSTORE_ACCESS_DATABASE_KEEP - # CleanupInterval defines the time between cleanup iterations - CleanupInterval: 4h # ZITADEL_LOGSTORE_ACCESS_DATABASE_CLEANUPINTERVAL - # Debouncing enables to asynchronously emit log entries, so the normal execution performance is not impaired - # Log entries are held in memory until one of the conditions MinFrequency or MaxBulkSize meets. - Debounce: - MinFrequency: 2m # ZITADEL_LOGSTORE_ACCESS_DATABASE_DEBOUNCE_MINFREQUENCY - MaxBulkSize: 100 # ZITADEL_LOGSTORE_ACCESS_DATABASE_DEBOUNCE_MAXBULKSIZE Stdout: # If enabled, all access logs are printed to the binary's standard output Enabled: false # ZITADEL_LOGSTORE_ACCESS_STDOUT_ENABLED - # Debouncing enables to asynchronously emit log entries, so the normal execution performance is not impaired - # Log entries are held in memory until one of the conditions MinFrequency or MaxBulkSize meets. - Debounce: - MinFrequency: 0s # ZITADEL_LOGSTORE_ACCESS_STDOUT_DEBOUNCE_MINFREQUENCY - MaxBulkSize: 0 # ZITADEL_LOGSTORE_ACCESS_STDOUT_DEBOUNCE_MAXBULKSIZE Execution: - Database: - # If enabled, all action execution logs are stored in the database table logstore.execution - Enabled: false # ZITADEL_LOGSTORE_EXECUTION_DATABASE_ENABLED - # Logs that are older than the keep duration are cleaned up continuously - # 2160h are 90 days, 3 months - Keep: 2160h # ZITADEL_LOGSTORE_EXECUTION_DATABASE_KEEP - # CleanupInterval defines the time between cleanup iterations - CleanupInterval: 4h # ZITADEL_LOGSTORE_EXECUTION_DATABASE_CLEANUPINTERVAL - # Debouncing enables to asynchronously emit log entries, so the normal execution performance is not impaired - # Log entries are held in memory until one of the conditions MinFrequency or MaxBulkSize meets. - Debounce: - MinFrequency: 0s # ZITADEL_LOGSTORE_EXECUTION_DATABASE_DEBOUNCE_MINFREQUENCY - MaxBulkSize: 0 # ZITADEL_LOGSTORE_EXECUTION_DATABASE_DEBOUNCE_MAXBULKSIZE Stdout: # If enabled, all execution logs are printed to the binary's standard output Enabled: true # ZITADEL_LOGSTORE_EXECUTION_STDOUT_ENABLED - # Debouncing enables to asynchronously emit log entries, so the normal execution performance is not impaired - # Log entries are held in memory until one of the conditions MinFrequency or MaxBulkSize meets. - Debounce: - MinFrequency: 0s # ZITADEL_LOGSTORE_EXECUTION_STDOUT_DEBOUNCE_MINFREQUENCY - MaxBulkSize: 0 # ZITADEL_LOGSTORE_EXECUTION_STDOUT_DEBOUNCE_MAXBULKSIZE Quotas: Access: + # If enabled, authenticated requests are counted and potentially limited depending on the configured quota of the instance + Enabled: false # ZITADEL_QUOTAS_ACCESS_ENABLED + Debounce: + MinFrequency: 0s # ZITADEL_QUOTAS_ACCESS_DEBOUNCE_MINFREQUENCY + MaxBulkSize: 0 # ZITADEL_QUOTAS_ACCESS_DEBOUNCE_MAXBULKSIZE ExhaustedCookieKey: "zitadel.quota.exhausted" # ZITADEL_QUOTAS_ACCESS_EXHAUSTEDCOOKIEKEY ExhaustedCookieMaxAge: "300s" # ZITADEL_QUOTAS_ACCESS_EXHAUSTEDCOOKIEMAXAGE + Execution: + # If enabled, all action executions are counted and potentially limited depending on the configured quota of the instance + Enabled: false # ZITADEL_QUOTAS_EXECUTION_DATABASE_ENABLED + Debounce: + MinFrequency: 0s # ZITADEL_QUOTAS_EXECUTION_DEBOUNCE_MINFREQUENCY + MaxBulkSize: 0 # ZITADEL_QUOTAS_EXECUTION_DEBOUNCE_MAXBULKSIZE Eventstore: PushTimeout: 15s # ZITADEL_EVENTSTORE_PUSHTIMEOUT diff --git a/cmd/start/config.go b/cmd/start/config.go index 0739664aaf..6280fdea40 100644 --- a/cmd/start/config.go +++ b/cmd/start/config.go @@ -72,7 +72,11 @@ type Config struct { } type QuotasConfig struct { - Access *middleware.AccessConfig + Access struct { + logstore.EmitterConfig `mapstructure:",squash"` + middleware.AccessConfig `mapstructure:",squash"` + } + Execution *logstore.EmitterConfig } func MustNewConfig(v *viper.Viper) *Config { diff --git a/cmd/start/start.go b/cmd/start/start.go index f8d69dea51..b7eb8e59c3 100644 --- a/cmd/start/start.go +++ b/cmd/start/start.go @@ -64,6 +64,8 @@ import ( "github.com/zitadel/zitadel/internal/logstore/emitters/access" "github.com/zitadel/zitadel/internal/logstore/emitters/execution" "github.com/zitadel/zitadel/internal/logstore/emitters/stdout" + "github.com/zitadel/zitadel/internal/logstore/record" + "github.com/zitadel/zitadel/internal/net" "github.com/zitadel/zitadel/internal/notification" "github.com/zitadel/zitadel/internal/query" "github.com/zitadel/zitadel/internal/static" @@ -108,7 +110,6 @@ type Server struct { AuthzRepo authz_repo.Repository Storage static.Storage Commands *command.Commands - LogStore *logstore.Service Router *mux.Router TLSConfig *tls.Config Shutdown chan<- os.Signal @@ -209,17 +210,16 @@ func startZitadel(config *Config, masterKey string, server chan<- *Server) error } clock := clockpkg.New() - actionsExecutionStdoutEmitter, err := logstore.NewEmitter(ctx, clock, config.LogStore.Execution.Stdout, stdout.NewStdoutEmitter()) + actionsExecutionStdoutEmitter, err := logstore.NewEmitter[*record.ExecutionLog](ctx, clock, &logstore.EmitterConfig{Enabled: config.LogStore.Execution.Stdout.Enabled}, stdout.NewStdoutEmitter[*record.ExecutionLog]()) if err != nil { return err } - actionsExecutionDBEmitter, err := logstore.NewEmitter(ctx, clock, config.LogStore.Execution.Database, execution.NewDatabaseLogStorage(dbClient)) + actionsExecutionDBEmitter, err := logstore.NewEmitter[*record.ExecutionLog](ctx, clock, config.Quotas.Execution, execution.NewDatabaseLogStorage(dbClient, commands, queries)) if err != nil { return err } - usageReporter := logstore.UsageReporterFunc(commands.ReportQuotaUsage) - actionsLogstoreSvc := logstore.New(queries, usageReporter, actionsExecutionDBEmitter, actionsExecutionStdoutEmitter) + actionsLogstoreSvc := logstore.New(queries, actionsExecutionDBEmitter, actionsExecutionStdoutEmitter) actions.SetLogstoreService(actionsLogstoreSvc) notification.Start( @@ -259,8 +259,6 @@ func startZitadel(config *Config, masterKey string, server chan<- *Server) error storage, authZRepo, keys, - queries, - usageReporter, permissionCheck, ) if err != nil { @@ -281,7 +279,6 @@ func startZitadel(config *Config, masterKey string, server chan<- *Server) error AuthzRepo: authZRepo, Storage: storage, Commands: commands, - LogStore: actionsLogstoreSvc, Router: router, TLSConfig: tlsConfig, Shutdown: shutdown, @@ -304,8 +301,6 @@ func startAPIs( store static.Storage, authZRepo authz_repo.Repository, keys *encryptionKeys, - quotaQuerier logstore.QuotaQuerier, - usageReporter logstore.UsageReporter, permissionCheck domain.PermissionCheck, ) error { repo := struct { @@ -321,22 +316,22 @@ func startAPIs( return err } - accessStdoutEmitter, err := logstore.NewEmitter(ctx, clock, config.LogStore.Access.Stdout, stdout.NewStdoutEmitter()) + accessStdoutEmitter, err := logstore.NewEmitter[*record.AccessLog](ctx, clock, &logstore.EmitterConfig{Enabled: config.LogStore.Access.Stdout.Enabled}, stdout.NewStdoutEmitter[*record.AccessLog]()) if err != nil { return err } - accessDBEmitter, err := logstore.NewEmitter(ctx, clock, config.LogStore.Access.Database, access.NewDatabaseLogStorage(dbClient)) + accessDBEmitter, err := logstore.NewEmitter[*record.AccessLog](ctx, clock, &config.Quotas.Access.EmitterConfig, access.NewDatabaseLogStorage(dbClient, commands, queries)) if err != nil { return err } - accessSvc := logstore.New(quotaQuerier, usageReporter, accessDBEmitter, accessStdoutEmitter) + accessSvc := logstore.New[*record.AccessLog](queries, accessDBEmitter, accessStdoutEmitter) exhaustedCookieHandler := http_util.NewCookieHandler( http_util.WithUnsecure(), http_util.WithNonHttpOnly(), http_util.WithMaxAge(int(math.Floor(config.Quotas.Access.ExhaustedCookieMaxAge.Seconds()))), ) - limitingAccessInterceptor := middleware.NewAccessInterceptor(accessSvc, exhaustedCookieHandler, config.Quotas.Access) + limitingAccessInterceptor := middleware.NewAccessInterceptor(accessSvc, exhaustedCookieHandler, &config.Quotas.Access.AccessConfig) apis, err := api.New(ctx, config.Port, router, queries, verifier, config.InternalAuthZ, tlsConfig, config.HTTP2HostHeader, config.HTTP1HostHeader, limitingAccessInterceptor) if err != nil { return fmt.Errorf("error creating api %w", err) @@ -399,13 +394,13 @@ func startAPIs( } apis.RegisterHandlerOnPrefix(openapi.HandlerPrefix, openAPIHandler) - oidcProvider, err := oidc.NewProvider(config.OIDC, login.DefaultLoggedOutPath, config.ExternalSecure, commands, queries, authRepo, keys.OIDC, keys.OIDCKey, eventstore, dbClient, userAgentInterceptor, instanceInterceptor.Handler, limitingAccessInterceptor.Handle) + oidcProvider, err := oidc.NewProvider(config.OIDC, login.DefaultLoggedOutPath, config.ExternalSecure, commands, queries, authRepo, keys.OIDC, keys.OIDCKey, eventstore, dbClient, userAgentInterceptor, instanceInterceptor.Handler, limitingAccessInterceptor) if err != nil { return fmt.Errorf("unable to start oidc provider: %w", err) } apis.RegisterHandlerPrefixes(oidcProvider.HttpHandler(), "/.well-known/openid-configuration", "/oidc/v1", "/oauth/v2") - samlProvider, err := saml.NewProvider(config.SAML, config.ExternalSecure, commands, queries, authRepo, keys.OIDC, keys.SAML, eventstore, dbClient, instanceInterceptor.Handler, userAgentInterceptor, limitingAccessInterceptor.Handle) + samlProvider, err := saml.NewProvider(config.SAML, config.ExternalSecure, commands, queries, authRepo, keys.OIDC, keys.SAML, eventstore, dbClient, instanceInterceptor.Handler, userAgentInterceptor, limitingAccessInterceptor) if err != nil { return fmt.Errorf("unable to start saml provider: %w", err) } @@ -417,7 +412,7 @@ func startAPIs( } apis.RegisterHandlerOnPrefix(console.HandlerPrefix, c) - l, err := login.CreateLogin(config.Login, commands, queries, authRepo, store, console.HandlerPrefix+"/", op.AuthCallbackURL(oidcProvider), provider.AuthCallbackURL(samlProvider), config.ExternalSecure, userAgentInterceptor, op.NewIssuerInterceptor(oidcProvider.IssuerFromRequest).Handler, provider.NewIssuerInterceptor(samlProvider.IssuerFromRequest).Handler, instanceInterceptor.Handler, assetsCache.Handler, limitingAccessInterceptor.Handle, keys.User, keys.IDPConfig, keys.CSRFCookieKey) + l, err := login.CreateLogin(config.Login, commands, queries, authRepo, store, console.HandlerPrefix+"/", op.AuthCallbackURL(oidcProvider), provider.AuthCallbackURL(samlProvider), config.ExternalSecure, userAgentInterceptor, op.NewIssuerInterceptor(oidcProvider.IssuerFromRequest).Handler, provider.NewIssuerInterceptor(samlProvider.IssuerFromRequest).Handler, instanceInterceptor.Handler, assetsCache.Handler, limitingAccessInterceptor.WithoutLimiting().Handle, keys.User, keys.IDPConfig, keys.CSRFCookieKey) if err != nil { return fmt.Errorf("unable to start login: %w", err) } @@ -438,7 +433,7 @@ func listen(ctx context.Context, router *mux.Router, port uint16, tlsConfig *tls http2Server := &http2.Server{} http1Server := &http.Server{Handler: h2c.NewHandler(router, http2Server), TLSConfig: tlsConfig} - lc := listenConfig() + lc := net.ListenConfig() lis, err := lc.Listen(ctx, "tcp", fmt.Sprintf(":%d", port)) if err != nil { return fmt.Errorf("tcp listener on %d failed: %w", port, err) diff --git a/e2e/config/localhost/zitadel.yaml b/e2e/config/localhost/zitadel.yaml index e53061218d..de758680fa 100644 --- a/e2e/config/localhost/zitadel.yaml +++ b/e2e/config/localhost/zitadel.yaml @@ -28,12 +28,7 @@ LogStore: Database: Enabled: true Stdout: - Enabled: false - -Quotas: - Access: - ExhaustedCookieKey: "zitadel.quota.limiting" - ExhaustedCookieMaxAge: "600s" + Enabled: true Console: InstanceManagementURL: "https://example.com/instances/{{.InstanceID}}" @@ -43,9 +38,25 @@ Projections: NotificationsQuotas: RequeueEvery: 1s +Quotas: + Access: + ExhaustedCookieKey: "zitadel.quota.limiting" + ExhaustedCookieMaxAge: "600s" + DefaultInstance: LoginPolicy: MfaInitSkipLifetime: "0" + Quotas: + Items: + - Unit: "actions.all.runs.seconds" + From: "2023-01-01T00:00:00Z" + ResetInterval: 5m + Amount: 20 + Limit: false + Notifications: + - Percent: 100 + Repeat: true + CallURL: "https://httpbin.org/post" SystemAPIUsers: - cypress: diff --git a/e2e/cypress/e2e/quotas/quotas.cy.ts b/e2e/cypress/e2e/quotas/quotas.cy.ts deleted file mode 100644 index a8a3046ce9..0000000000 --- a/e2e/cypress/e2e/quotas/quotas.cy.ts +++ /dev/null @@ -1,324 +0,0 @@ -import { addQuota, ensureQuotaIsAdded, ensureQuotaIsRemoved, removeQuota, Unit } from 'support/api/quota'; -import { createHumanUser, ensureUserDoesntExist } from 'support/api/users'; -import { Context } from 'support/commands'; -import { ZITADELWebhookEvent } from 'support/types'; -import { textChangeRangeIsUnchanged } from 'typescript'; - -beforeEach(() => { - cy.context().as('ctx'); -}); - -describe('quotas', () => { - describe('management', () => { - describe('add one quota', () => { - it('should add a quota only once per unit', () => { - cy.get('@ctx').then((ctx) => { - addQuota(ctx, Unit.AuthenticatedRequests, true, 1); - addQuota(ctx, Unit.AuthenticatedRequests, true, 1, undefined, undefined, undefined, false).then((res) => { - expect(res.status).to.equal(409); - }); - }); - }); - - describe('add two quotas', () => { - it('should add a quota for each unit', () => { - cy.get('@ctx').then((ctx) => { - addQuota(ctx, Unit.AuthenticatedRequests, true, 1); - addQuota(ctx, Unit.ExecutionSeconds, true, 1); - }); - }); - }); - }); - - describe('edit', () => { - describe('remove one quota', () => { - beforeEach(() => { - cy.get('@ctx').then((ctx) => { - ensureQuotaIsAdded(ctx, Unit.AuthenticatedRequests, true, 1); - }); - }); - it('should remove a quota only once per unit', () => { - cy.get('@ctx').then((ctx) => { - removeQuota(ctx, Unit.AuthenticatedRequests); - }); - cy.get('@ctx').then((ctx) => { - removeQuota(ctx, Unit.AuthenticatedRequests, false).then((res) => { - expect(res.status).to.equal(404); - }); - }); - }); - - describe('remove two quotas', () => { - beforeEach(() => { - cy.get('@ctx').then((ctx) => { - ensureQuotaIsAdded(ctx, Unit.AuthenticatedRequests, true, 1); - ensureQuotaIsAdded(ctx, Unit.ExecutionSeconds, true, 1); - }); - }); - it('should remove a quota for each unit', () => { - cy.get('@ctx').then((ctx) => { - removeQuota(ctx, Unit.AuthenticatedRequests); - removeQuota(ctx, Unit.ExecutionSeconds); - }); - }); - }); - }); - }); - }); - - describe('usage', () => { - beforeEach(() => { - cy.get('@ctx') - .then((ctx) => { - return [ - `${ctx.api.oidcBaseURL}/userinfo`, - `${ctx.api.authBaseURL}/users/me`, - `${ctx.api.mgmtBaseURL}/iam`, - `${ctx.api.adminBaseURL}/instances/me`, - `${ctx.api.oauthBaseURL}/keys`, - `${ctx.api.samlBaseURL}/certificate`, - ]; - }) - .as('authenticatedUrls'); - }); - - describe('authenticated requests', () => { - const testUserName = 'shouldNotBeCreated'; - beforeEach(() => { - cy.get>('@authenticatedUrls').then((urls) => { - cy.get('@ctx').then((ctx) => { - ensureUserDoesntExist(ctx.api, testUserName); - ensureQuotaIsAdded(ctx, Unit.AuthenticatedRequests, true, urls.length); - cy.task('runSQL', `TRUNCATE logstore.access;`); - }); - }); - }); - - it('only authenticated requests are limited', () => { - cy.get>('@authenticatedUrls').then((urls) => { - cy.get('@ctx').then((ctx) => { - const start = new Date(); - urls.forEach((url) => { - cy.request({ - url: url, - method: 'GET', - auth: { - bearer: ctx.api.token, - }, - }); - }); - expectCookieDoesntExist(); - const expiresMax = new Date(); - expiresMax.setMinutes(expiresMax.getMinutes() + 20); - cy.request({ - url: urls[1], - method: 'GET', - auth: { - bearer: ctx.api.token, - }, - failOnStatusCode: false, - }).then((res) => { - expect(res.status).to.equal(429); - }); - cy.getCookie('zitadel.quota.limiting').then((cookie) => { - expect(cookie.value).to.equal('true'); - const cookieExpiry = new Date(); - cookieExpiry.setTime(cookie.expiry * 1000); - expect(cookieExpiry).to.be.within(start, expiresMax); - }); - createHumanUser(ctx.api, testUserName, false).then((res) => { - expect(res.status).to.equal(429); - }); - // visit limited console - // cy.visit('/users/me'); - // cy.contains('#authenticated-requests-exhausted-dialog button', 'Continue').click(); - // const upgradeInstancePage = `https://example.com/instances/${ctx.instanceId}`; - // cy.origin(upgradeInstancePage, { args: { upgradeInstancePage } }, ({ upgradeInstancePage }) => { - // cy.location('href').should('equal', upgradeInstancePage); - // }); - // upgrade instance - ensureQuotaIsRemoved(ctx, Unit.AuthenticatedRequests); - // visit upgraded console again - cy.visit('/users/me'); - cy.get('[data-e2e="top-view-title"]'); - expectCookieDoesntExist(); - createHumanUser(ctx.api, testUserName); - expectCookieDoesntExist(); - }); - }); - }); - }); - - describe.skip('notifications', () => { - const callURL = `http://${Cypress.env('WEBHOOK_HANDLER_HOST')}:${Cypress.env('WEBHOOK_HANDLER_PORT')}/do_something`; - - beforeEach(() => cy.task('resetWebhookEvents')); - - const amount = 100; - const percent = 10; - const usage = 35; - - describe('without repetition', () => { - beforeEach(() => { - cy.get('@ctx').then((ctx) => { - ensureQuotaIsAdded(ctx, Unit.AuthenticatedRequests, false, amount, [ - { - callUrl: callURL, - percent: percent, - repeat: false, - }, - ]); - cy.task('runSQL', `TRUNCATE logstore.access;`); - }); - }); - - it('fires at least once with the expected payload', () => { - cy.get>('@authenticatedUrls').then((urls) => { - cy.get('@ctx').then((ctx) => { - for (let i = 0; i < usage; i++) { - cy.request({ - url: urls[0], - method: 'GET', - auth: { - bearer: ctx.api.token, - }, - }); - } - }); - cy.waitUntil( - () => - cy.task>('handledWebhookEvents').then((events) => { - if (events.length < 1) { - return false; - } - return Cypress._.matches({ - sentStatus: 200, - payload: { - callURL: callURL, - threshold: percent, - unit: 1, - usage: percent, - }, - })(events[0]); - }), - { timeout: 60_000 }, - ); - }); - }); - - it('fires until the webhook returns a successful message', () => { - cy.task('failWebhookEvents', 8); - cy.get>('@authenticatedUrls').then((urls) => { - cy.get('@ctx').then((ctx) => { - for (let i = 0; i < usage; i++) { - cy.request({ - url: urls[0], - method: 'GET', - auth: { - bearer: ctx.api.token, - }, - }); - } - }); - cy.waitUntil( - () => - cy.task>('handledWebhookEvents').then((events) => { - if (events.length != 9) { - return false; - } - return events.reduce((a, b, i) => { - return !a - ? a - : i < 8 - ? Cypress._.matches({ - sentStatus: 500, - payload: { - callURL: callURL, - threshold: percent, - unit: 1, - usage: percent, - }, - })(b) - : Cypress._.matches({ - sentStatus: 200, - payload: { - callURL: callURL, - threshold: percent, - unit: 1, - usage: percent, - }, - })(b); - }, true); - }), - { timeout: 60_000 }, - ); - }); - }); - }); - - describe('with repetition', () => { - beforeEach(() => { - cy.get>('@authenticatedUrls').then((urls) => { - cy.get('@ctx').then((ctx) => { - ensureQuotaIsAdded(ctx, Unit.AuthenticatedRequests, false, amount, [ - { - callUrl: callURL, - percent: percent, - repeat: true, - }, - ]); - cy.task('runSQL', `TRUNCATE logstore.access;`); - }); - }); - }); - - it('fires repeatedly with the expected payloads', () => { - cy.get>('@authenticatedUrls').then((urls) => { - cy.get('@ctx').then((ctx) => { - for (let i = 0; i < usage; i++) { - cy.request({ - url: urls[0], - method: 'GET', - auth: { - bearer: ctx.api.token, - }, - }); - } - }); - }); - cy.waitUntil( - () => - cy.task>('handledWebhookEvents').then((events) => { - let foundExpected = 0; - for (let i = 0; i < events.length; i++) { - for (let expect = 10; expect <= 30; expect += 10) { - if ( - Cypress._.matches({ - sentStatus: 200, - payload: { - callURL: callURL, - threshold: expect, - unit: 1, - usage: expect, - }, - })(events[i]) - ) { - foundExpected++; - } - } - } - return foundExpected >= 3; - }), - { timeout: 60_000 }, - ); - }); - }); - }); - }); -}); - -function expectCookieDoesntExist() { - cy.getCookie('zitadel.quota.limiting').then((cookie) => { - expect(cookie).to.be.null; - }); -} diff --git a/internal/actions/actions_test.go b/internal/actions/actions_test.go index 1287546a95..d33b3082ad 100644 --- a/internal/actions/actions_test.go +++ b/internal/actions/actions_test.go @@ -7,11 +7,13 @@ import ( "time" "github.com/dop251/goja" + "github.com/zitadel/zitadel/internal/logstore" + "github.com/zitadel/zitadel/internal/logstore/record" ) func TestRun(t *testing.T) { - SetLogstoreService(logstore.New(nil, nil, nil)) + SetLogstoreService(logstore.New[*record.ExecutionLog](nil, nil)) type args struct { timeout time.Duration api apiFields diff --git a/internal/actions/fields_test.go b/internal/actions/fields_test.go index 93862cfb3d..9f1f8e44d7 100644 --- a/internal/actions/fields_test.go +++ b/internal/actions/fields_test.go @@ -5,11 +5,13 @@ import ( "testing" "github.com/dop251/goja" + "github.com/zitadel/zitadel/internal/logstore" + "github.com/zitadel/zitadel/internal/logstore/record" ) func TestSetFields(t *testing.T) { - SetLogstoreService(logstore.New(nil, nil, nil)) + SetLogstoreService(logstore.New[*record.ExecutionLog](nil, nil)) primitveFn := func(a string) { fmt.Println(a) } complexFn := func(*FieldConfig) interface{} { return primitveFn diff --git a/internal/actions/http_module_test.go b/internal/actions/http_module_test.go index 176f789e2c..1b456d2196 100644 --- a/internal/actions/http_module_test.go +++ b/internal/actions/http_module_test.go @@ -10,12 +10,14 @@ import ( "testing" "github.com/dop251/goja" + "github.com/zitadel/zitadel/internal/errors" "github.com/zitadel/zitadel/internal/logstore" + "github.com/zitadel/zitadel/internal/logstore/record" ) func Test_isHostBlocked(t *testing.T) { - SetLogstoreService(logstore.New(nil, nil, nil)) + SetLogstoreService(logstore.New[*record.ExecutionLog](nil, nil)) var denyList = []AddressChecker{ mustNewIPChecker(t, "192.168.5.0/24"), mustNewIPChecker(t, "127.0.0.1"), diff --git a/internal/actions/log_module.go b/internal/actions/log_module.go index 1d2e28adf4..a7a57a80c8 100644 --- a/internal/actions/log_module.go +++ b/internal/actions/log_module.go @@ -10,15 +10,15 @@ import ( "github.com/zitadel/zitadel/internal/api/authz" "github.com/zitadel/zitadel/internal/logstore" - "github.com/zitadel/zitadel/internal/logstore/emitters/execution" + "github.com/zitadel/zitadel/internal/logstore/record" ) var ( - logstoreService *logstore.Service + logstoreService *logstore.Service[*record.ExecutionLog] _ console.Printer = (*logger)(nil) ) -func SetLogstoreService(svc *logstore.Service) { +func SetLogstoreService(svc *logstore.Service[*record.ExecutionLog]) { logstoreService = svc } @@ -55,19 +55,16 @@ func (l *logger) log(msg string, level logrus.Level, last bool) { if l.started.IsZero() { l.started = ts } - - record := &execution.Record{ + r := &record.ExecutionLog{ LogDate: ts, InstanceID: l.instanceID, Message: msg, LogLevel: level, } - if last { - record.Took = ts.Sub(l.started) + r.Took = ts.Sub(l.started) } - - logstoreService.Handle(l.ctx, record) + logstoreService.Handle(l.ctx, r) } func withLogger(ctx context.Context) Option { diff --git a/internal/api/grpc/server/middleware/access_interceptor.go b/internal/api/grpc/server/middleware/access_interceptor.go index 2b3cd43df5..719e03da78 100644 --- a/internal/api/grpc/server/middleware/access_interceptor.go +++ b/internal/api/grpc/server/middleware/access_interceptor.go @@ -10,11 +10,11 @@ import ( "github.com/zitadel/zitadel/internal/api/authz" "github.com/zitadel/zitadel/internal/logstore" - "github.com/zitadel/zitadel/internal/logstore/emitters/access" + "github.com/zitadel/zitadel/internal/logstore/record" "github.com/zitadel/zitadel/internal/telemetry/tracing" ) -func AccessStorageInterceptor(svc *logstore.Service) grpc.UnaryServerInterceptor { +func AccessStorageInterceptor(svc *logstore.Service[*record.AccessLog]) grpc.UnaryServerInterceptor { return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (_ interface{}, err error) { if !svc.Enabled() { return handler(ctx, req) @@ -36,9 +36,9 @@ func AccessStorageInterceptor(svc *logstore.Service) grpc.UnaryServerInterceptor resMd, _ := metadata.FromOutgoingContext(ctx) instance := authz.GetInstance(ctx) - record := &access.Record{ + r := &record.AccessLog{ LogDate: time.Now(), - Protocol: access.GRPC, + Protocol: record.GRPC, RequestURL: info.FullMethod, ResponseStatus: respStatus, RequestHeaders: reqMd, @@ -49,7 +49,7 @@ func AccessStorageInterceptor(svc *logstore.Service) grpc.UnaryServerInterceptor RequestedHost: instance.RequestedHost(), } - svc.Handle(interceptorCtx, record) + svc.Handle(interceptorCtx, r) return resp, handlerErr } } diff --git a/internal/api/grpc/server/middleware/quota_interceptor.go b/internal/api/grpc/server/middleware/quota_interceptor.go index 08ce38907e..cfcdcedb9f 100644 --- a/internal/api/grpc/server/middleware/quota_interceptor.go +++ b/internal/api/grpc/server/middleware/quota_interceptor.go @@ -9,19 +9,16 @@ import ( "github.com/zitadel/zitadel/internal/api/authz" "github.com/zitadel/zitadel/internal/errors" "github.com/zitadel/zitadel/internal/logstore" + "github.com/zitadel/zitadel/internal/logstore/record" "github.com/zitadel/zitadel/internal/telemetry/tracing" ) -func QuotaExhaustedInterceptor(svc *logstore.Service, ignoreService ...string) grpc.UnaryServerInterceptor { - - prunedIgnoredServices := make([]string, len(ignoreService)) +func QuotaExhaustedInterceptor(svc *logstore.Service[*record.AccessLog], ignoreService ...string) grpc.UnaryServerInterceptor { for idx, service := range ignoreService { if !strings.HasPrefix(service, "/") { - service = "/" + service + ignoreService[idx] = "/" + service } - prunedIgnoredServices[idx] = service } - return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (_ interface{}, err error) { if !svc.Enabled() { return handler(ctx, req) @@ -29,7 +26,13 @@ func QuotaExhaustedInterceptor(svc *logstore.Service, ignoreService ...string) g interceptorCtx, span := tracing.NewServerInterceptorSpan(ctx) defer func() { span.EndWithError(err) }() - for _, service := range prunedIgnoredServices { + // The auth interceptor will ensure that only authorized or public requests are allowed. + // So if there's no authorization context, we don't need to check for limitation + if authz.GetCtxData(ctx).IsZero() { + return handler(ctx, req) + } + + for _, service := range ignoreService { if strings.HasPrefix(info.FullMethod, service) { return handler(ctx, req) } diff --git a/internal/api/grpc/server/server.go b/internal/api/grpc/server/server.go index 9252778036..9d7deb28ea 100644 --- a/internal/api/grpc/server/server.go +++ b/internal/api/grpc/server/server.go @@ -12,6 +12,7 @@ import ( grpc_api "github.com/zitadel/zitadel/internal/api/grpc" "github.com/zitadel/zitadel/internal/api/grpc/server/middleware" "github.com/zitadel/zitadel/internal/logstore" + "github.com/zitadel/zitadel/internal/logstore/record" "github.com/zitadel/zitadel/internal/query" "github.com/zitadel/zitadel/internal/telemetry/metrics" system_pb "github.com/zitadel/zitadel/pkg/grpc/system" @@ -39,7 +40,7 @@ func CreateServer( queries *query.Queries, hostHeaderName string, tlsConfig *tls.Config, - accessSvc *logstore.Service, + accessSvc *logstore.Service[*record.AccessLog], ) *grpc.Server { metricTypes := []metrics.MetricType{metrics.MetricTypeTotalCount, metrics.MetricTypeRequestCount, metrics.MetricTypeStatusCode} serverOptions := []grpc.ServerOption{ @@ -49,14 +50,14 @@ func CreateServer( middleware.DefaultTracingServer(), middleware.MetricsHandler(metricTypes, grpc_api.Probes...), middleware.NoCacheInterceptor(), - middleware.ErrorHandler(), middleware.InstanceInterceptor(queries, hostHeaderName, system_pb.SystemService_ServiceDesc.ServiceName, healthpb.Health_ServiceDesc.ServiceName), middleware.AccessStorageInterceptor(accessSvc), + middleware.ErrorHandler(), middleware.AuthorizationInterceptor(verifier, authConfig), + middleware.QuotaExhaustedInterceptor(accessSvc, system_pb.SystemService_ServiceDesc.ServiceName), middleware.TranslationHandler(), middleware.ValidationHandler(), middleware.ServiceHandler(), - middleware.QuotaExhaustedInterceptor(accessSvc, system_pb.SystemService_ServiceDesc.ServiceName), ), ), } diff --git a/internal/api/grpc/system/quota_integration_test.go b/internal/api/grpc/system/quota_integration_test.go new file mode 100644 index 0000000000..a8c6840c87 --- /dev/null +++ b/internal/api/grpc/system/quota_integration_test.go @@ -0,0 +1,203 @@ +//go:build integration + +package system_test + +import ( + "bytes" + "encoding/json" + "strconv" + "testing" + "time" + + "github.com/stretchr/testify/require" + "google.golang.org/protobuf/types/known/durationpb" + "google.golang.org/protobuf/types/known/timestamppb" + + "github.com/zitadel/zitadel/internal/integration" + "github.com/zitadel/zitadel/internal/repository/quota" + "github.com/zitadel/zitadel/pkg/grpc/admin" + quota_pb "github.com/zitadel/zitadel/pkg/grpc/quota" + "github.com/zitadel/zitadel/pkg/grpc/system" +) + +var callURL = "http://localhost:" + integration.PortQuotaServer + +func TestServer_QuotaNotification_Limit(t *testing.T) { + _, instanceID, iamOwnerCtx := Tester.UseIsolatedInstance(CTX, SystemCTX) + amount := 10 + percent := 50 + percentAmount := amount * percent / 100 + + _, err := Tester.Client.System.AddQuota(SystemCTX, &system.AddQuotaRequest{ + InstanceId: instanceID, + Unit: quota_pb.Unit_UNIT_REQUESTS_ALL_AUTHENTICATED, + From: timestamppb.Now(), + ResetInterval: durationpb.New(time.Minute * 5), + Amount: uint64(amount), + Limit: true, + Notifications: []*quota_pb.Notification{ + { + Percent: uint32(percent), + Repeat: true, + CallUrl: callURL, + }, + { + Percent: 100, + Repeat: true, + CallUrl: callURL, + }, + }, + }) + require.NoError(t, err) + + for i := 0; i < percentAmount; i++ { + _, err := Tester.Client.Admin.GetDefaultOrg(iamOwnerCtx, &admin.GetDefaultOrgRequest{}) + require.NoErrorf(t, err, "error in %d call of %d", i, percentAmount) + } + awaitNotification(t, Tester.QuotaNotificationChan, quota.RequestsAllAuthenticated, percent) + + for i := 0; i < (amount - percentAmount); i++ { + _, err := Tester.Client.Admin.GetDefaultOrg(iamOwnerCtx, &admin.GetDefaultOrgRequest{}) + require.NoErrorf(t, err, "error in %d call of %d", i, percentAmount) + } + awaitNotification(t, Tester.QuotaNotificationChan, quota.RequestsAllAuthenticated, 100) + + _, limitErr := Tester.Client.Admin.GetDefaultOrg(iamOwnerCtx, &admin.GetDefaultOrgRequest{}) + require.Error(t, limitErr) +} + +func TestServer_QuotaNotification_NoLimit(t *testing.T) { + _, instanceID, iamOwnerCtx := Tester.UseIsolatedInstance(CTX, SystemCTX) + amount := 10 + percent := 50 + percentAmount := amount * percent / 100 + + _, err := Tester.Client.System.AddQuota(SystemCTX, &system.AddQuotaRequest{ + InstanceId: instanceID, + Unit: quota_pb.Unit_UNIT_REQUESTS_ALL_AUTHENTICATED, + From: timestamppb.Now(), + ResetInterval: durationpb.New(time.Minute * 5), + Amount: uint64(amount), + Limit: false, + Notifications: []*quota_pb.Notification{ + { + Percent: uint32(percent), + Repeat: false, + CallUrl: callURL, + }, + { + Percent: 100, + Repeat: true, + CallUrl: callURL, + }, + }, + }) + require.NoError(t, err) + + for i := 0; i < percentAmount; i++ { + _, err := Tester.Client.Admin.GetDefaultOrg(iamOwnerCtx, &admin.GetDefaultOrgRequest{}) + require.NoErrorf(t, err, "error in %d call of %d", i, percentAmount) + } + awaitNotification(t, Tester.QuotaNotificationChan, quota.RequestsAllAuthenticated, percent) + + for i := 0; i < (amount - percentAmount); i++ { + _, err := Tester.Client.Admin.GetDefaultOrg(iamOwnerCtx, &admin.GetDefaultOrgRequest{}) + require.NoErrorf(t, err, "error in %d call of %d", i, percentAmount) + } + awaitNotification(t, Tester.QuotaNotificationChan, quota.RequestsAllAuthenticated, 100) + + for i := 0; i < amount; i++ { + _, err := Tester.Client.Admin.GetDefaultOrg(iamOwnerCtx, &admin.GetDefaultOrgRequest{}) + require.NoErrorf(t, err, "error in %d call of %d", i, percentAmount) + } + awaitNotification(t, Tester.QuotaNotificationChan, quota.RequestsAllAuthenticated, 200) + + _, limitErr := Tester.Client.Admin.GetDefaultOrg(iamOwnerCtx, &admin.GetDefaultOrgRequest{}) + require.NoError(t, limitErr) +} + +func awaitNotification(t *testing.T, bodies chan []byte, unit quota.Unit, percent int) { + for { + select { + case body := <-bodies: + plain := new(bytes.Buffer) + if err := json.Indent(plain, body, "", " "); err != nil { + t.Fatal(err) + } + t.Log("received notificationDueEvent", plain.String()) + event := struct { + Unit quota.Unit `json:"unit"` + ID string `json:"id"` + CallURL string `json:"callURL"` + PeriodStart time.Time `json:"periodStart"` + Threshold uint16 `json:"threshold"` + Usage uint64 `json:"usage"` + }{} + if err := json.Unmarshal(body, &event); err != nil { + t.Error(err) + } + if event.ID == "" { + continue + } + if event.Unit == unit && event.Threshold == uint16(percent) { + return + } + case <-time.After(60 * time.Second): + t.Fatalf("timed out waiting for unit %s and percent %d", strconv.Itoa(int(unit)), percent) + } + } +} + +func TestServer_AddAndRemoveQuota(t *testing.T) { + _, instanceID, _ := Tester.UseIsolatedInstance(CTX, SystemCTX) + + got, err := Tester.Client.System.AddQuota(SystemCTX, &system.AddQuotaRequest{ + InstanceId: instanceID, + Unit: quota_pb.Unit_UNIT_REQUESTS_ALL_AUTHENTICATED, + From: timestamppb.Now(), + ResetInterval: durationpb.New(time.Minute), + Amount: 10, + Limit: true, + Notifications: []*quota_pb.Notification{ + { + Percent: 20, + Repeat: true, + CallUrl: callURL, + }, + }, + }) + require.NoError(t, err) + require.Equal(t, got.Details.ResourceOwner, instanceID) + + gotAlreadyExisting, errAlreadyExisting := Tester.Client.System.AddQuota(SystemCTX, &system.AddQuotaRequest{ + InstanceId: instanceID, + Unit: quota_pb.Unit_UNIT_REQUESTS_ALL_AUTHENTICATED, + From: timestamppb.Now(), + ResetInterval: durationpb.New(time.Minute), + Amount: 10, + Limit: true, + Notifications: []*quota_pb.Notification{ + { + Percent: 20, + Repeat: true, + CallUrl: callURL, + }, + }, + }) + require.Error(t, errAlreadyExisting) + require.Nil(t, gotAlreadyExisting) + + gotRemove, errRemove := Tester.Client.System.RemoveQuota(SystemCTX, &system.RemoveQuotaRequest{ + InstanceId: instanceID, + Unit: quota_pb.Unit_UNIT_REQUESTS_ALL_AUTHENTICATED, + }) + require.NoError(t, errRemove) + require.Equal(t, gotRemove.Details.ResourceOwner, instanceID) + + gotRemoveAlready, errRemoveAlready := Tester.Client.System.RemoveQuota(SystemCTX, &system.RemoveQuotaRequest{ + InstanceId: instanceID, + Unit: quota_pb.Unit_UNIT_REQUESTS_ALL_AUTHENTICATED, + }) + require.Error(t, errRemoveAlready) + require.Nil(t, gotRemoveAlready) +} diff --git a/internal/api/grpc/system/server_integration_test.go b/internal/api/grpc/system/server_integration_test.go new file mode 100644 index 0000000000..f36972993f --- /dev/null +++ b/internal/api/grpc/system/server_integration_test.go @@ -0,0 +1,32 @@ +//go:build integration + +package system_test + +import ( + "context" + "os" + "testing" + "time" + + "github.com/zitadel/zitadel/internal/integration" +) + +var ( + CTX context.Context + SystemCTX context.Context + Tester *integration.Tester +) + +func TestMain(m *testing.M) { + os.Exit(func() int { + ctx, _, cancel := integration.Contexts(5 * time.Minute) + defer cancel() + CTX = ctx + + Tester = integration.NewTester(ctx) + defer Tester.Done() + + SystemCTX = Tester.WithAuthorization(ctx, integration.SystemUser) + return m.Run() + }()) +} diff --git a/internal/api/http/middleware/access_interceptor.go b/internal/api/http/middleware/access_interceptor.go index 7d05619ecc..538ec226d9 100644 --- a/internal/api/http/middleware/access_interceptor.go +++ b/internal/api/http/middleware/access_interceptor.go @@ -14,12 +14,12 @@ import ( "github.com/zitadel/zitadel/internal/api/grpc/server/middleware" http_utils "github.com/zitadel/zitadel/internal/api/http" "github.com/zitadel/zitadel/internal/logstore" - "github.com/zitadel/zitadel/internal/logstore/emitters/access" + "github.com/zitadel/zitadel/internal/logstore/record" "github.com/zitadel/zitadel/internal/telemetry/tracing" ) type AccessInterceptor struct { - svc *logstore.Service + svc *logstore.Service[*record.AccessLog] cookieHandler *http_utils.CookieHandler limitConfig *AccessConfig storeOnly bool @@ -33,7 +33,7 @@ type AccessConfig struct { // NewAccessInterceptor intercepts all requests and stores them to the logstore. // If storeOnly is false, it also checks if requests are exhausted. // If requests are exhausted, it also returns http.StatusTooManyRequests and sets a cookie -func NewAccessInterceptor(svc *logstore.Service, cookieHandler *http_utils.CookieHandler, cookieConfig *AccessConfig) *AccessInterceptor { +func NewAccessInterceptor(svc *logstore.Service[*record.AccessLog], cookieHandler *http_utils.CookieHandler, cookieConfig *AccessConfig) *AccessInterceptor { return &AccessInterceptor{ svc: svc, cookieHandler: cookieHandler, @@ -50,7 +50,7 @@ func (a *AccessInterceptor) WithoutLimiting() *AccessInterceptor { } } -func (a *AccessInterceptor) AccessService() *logstore.Service { +func (a *AccessInterceptor) AccessService() *logstore.Service[*record.AccessLog] { return a.svc } @@ -81,46 +81,70 @@ func (a *AccessInterceptor) DeleteExhaustedCookie(writer http.ResponseWriter) { a.cookieHandler.DeleteCookie(writer, a.limitConfig.ExhaustedCookieKey) } +func (a *AccessInterceptor) HandleIgnorePathPrefixes(ignoredPathPrefixes []string) func(next http.Handler) http.Handler { + return a.handle(ignoredPathPrefixes...) +} + func (a *AccessInterceptor) Handle(next http.Handler) http.Handler { - if !a.svc.Enabled() { - return next - } - return http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) { - ctx := request.Context() - tracingCtx, checkSpan := tracing.NewNamedSpan(ctx, "checkAccess") - wrappedWriter := &statusRecorder{ResponseWriter: writer, status: 0} - limited := a.Limit(tracingCtx) - checkSpan.End() - if limited { - a.SetExhaustedCookie(wrappedWriter, request) - http.Error(wrappedWriter, "quota for authenticated requests is exhausted", http.StatusTooManyRequests) + return a.handle()(next) +} + +func (a *AccessInterceptor) handle(ignoredPathPrefixes ...string) func(http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + if !a.svc.Enabled() { + return next } - if !limited && !a.storeOnly { - a.DeleteExhaustedCookie(wrappedWriter) - } - if !limited { - next.ServeHTTP(wrappedWriter, request) - } - tracingCtx, writeSpan := tracing.NewNamedSpan(tracingCtx, "writeAccess") - defer writeSpan.End() - requestURL := request.RequestURI - unescapedURL, err := url.QueryUnescape(requestURL) - if err != nil { - logging.WithError(err).WithField("url", requestURL).Warning("failed to unescape request url") - } - instance := authz.GetInstance(tracingCtx) - a.svc.Handle(tracingCtx, &access.Record{ - LogDate: time.Now(), - Protocol: access.HTTP, - RequestURL: unescapedURL, - ResponseStatus: uint32(wrappedWriter.status), - RequestHeaders: request.Header, - ResponseHeaders: writer.Header(), - InstanceID: instance.InstanceID(), - ProjectID: instance.ProjectID(), - RequestedDomain: instance.RequestedDomain(), - RequestedHost: instance.RequestedHost(), + return http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) { + ctx := request.Context() + tracingCtx, checkSpan := tracing.NewNamedSpan(ctx, "checkAccessQuota") + wrappedWriter := &statusRecorder{ResponseWriter: writer, status: 0} + for _, ignoredPathPrefix := range ignoredPathPrefixes { + if !strings.HasPrefix(request.RequestURI, ignoredPathPrefix) { + continue + } + checkSpan.End() + next.ServeHTTP(wrappedWriter, request) + a.writeLog(tracingCtx, wrappedWriter, writer, request, true) + return + } + limited := a.Limit(tracingCtx) + checkSpan.End() + if limited { + a.SetExhaustedCookie(wrappedWriter, request) + http.Error(wrappedWriter, "quota for authenticated requests is exhausted", http.StatusTooManyRequests) + } + if !limited && !a.storeOnly { + a.DeleteExhaustedCookie(wrappedWriter) + } + if !limited { + next.ServeHTTP(wrappedWriter, request) + } + a.writeLog(tracingCtx, wrappedWriter, writer, request, a.storeOnly) }) + } +} + +func (a *AccessInterceptor) writeLog(ctx context.Context, wrappedWriter *statusRecorder, writer http.ResponseWriter, request *http.Request, notCountable bool) { + ctx, writeSpan := tracing.NewNamedSpan(ctx, "writeAccess") + defer writeSpan.End() + requestURL := request.RequestURI + unescapedURL, err := url.QueryUnescape(requestURL) + if err != nil { + logging.WithError(err).WithField("url", requestURL).Warning("failed to unescape request url") + } + instance := authz.GetInstance(ctx) + a.svc.Handle(ctx, &record.AccessLog{ + LogDate: time.Now(), + Protocol: record.HTTP, + RequestURL: unescapedURL, + ResponseStatus: uint32(wrappedWriter.status), + RequestHeaders: request.Header, + ResponseHeaders: writer.Header(), + InstanceID: instance.InstanceID(), + ProjectID: instance.ProjectID(), + RequestedDomain: instance.RequestedDomain(), + RequestedHost: instance.RequestedHost(), + NotCountable: notCountable, }) } diff --git a/internal/api/oidc/op.go b/internal/api/oidc/op.go index 62dad091ec..ea630482c7 100644 --- a/internal/api/oidc/op.go +++ b/internal/api/oidc/op.go @@ -7,6 +7,7 @@ import ( "time" "github.com/rakyll/statik/fs" + "github.com/zitadel/oidc/v2/pkg/oidc" "github.com/zitadel/oidc/v2/pkg/op" "golang.org/x/text/language" @@ -79,13 +80,32 @@ type OPStorage struct { assetAPIPrefix func(ctx context.Context) string } -func NewProvider(config Config, defaultLogoutRedirectURI string, externalSecure bool, command *command.Commands, query *query.Queries, repo repository.Repository, encryptionAlg crypto.EncryptionAlgorithm, cryptoKey []byte, es *eventstore.Eventstore, projections *database.DB, userAgentCookie, instanceHandler, accessHandler func(http.Handler) http.Handler) (op.OpenIDProvider, error) { +func NewProvider( + config Config, + defaultLogoutRedirectURI string, + externalSecure bool, + command *command.Commands, + query *query.Queries, + repo repository.Repository, + encryptionAlg crypto.EncryptionAlgorithm, + cryptoKey []byte, + es *eventstore.Eventstore, + projections *database.DB, + userAgentCookie, instanceHandler func(http.Handler) http.Handler, + accessHandler *middleware.AccessInterceptor, +) (op.OpenIDProvider, error) { opConfig, err := createOPConfig(config, defaultLogoutRedirectURI, cryptoKey) if err != nil { return nil, caos_errs.ThrowInternal(err, "OIDC-EGrqd", "cannot create op config: %w") } storage := newStorage(config, command, query, repo, encryptionAlg, es, projections, externalSecure) - options, err := createOptions(config, externalSecure, userAgentCookie, instanceHandler, accessHandler) + options, err := createOptions( + config, + externalSecure, + userAgentCookie, + instanceHandler, + accessHandler.HandleIgnorePathPrefixes(ignoredQuotaLimitEndpoint(config.CustomEndpoints)), + ) if err != nil { return nil, caos_errs.ThrowInternal(err, "OIDC-D3gq1", "cannot create options: %w") } @@ -101,6 +121,21 @@ func NewProvider(config Config, defaultLogoutRedirectURI string, externalSecure return provider, nil } +func ignoredQuotaLimitEndpoint(endpoints *EndpointConfig) []string { + authURL := op.DefaultEndpoints.Authorization.Relative() + keysURL := op.DefaultEndpoints.JwksURI.Relative() + if endpoints == nil { + return []string{oidc.DiscoveryEndpoint, authURL, keysURL} + } + if endpoints.Auth != nil && endpoints.Auth.Path != "" { + authURL = endpoints.Auth.Path + } + if endpoints.Keys != nil && endpoints.Keys.Path != "" { + keysURL = endpoints.Keys.Path + } + return []string{oidc.DiscoveryEndpoint, authURL, keysURL} +} + func createOPConfig(config Config, defaultLogoutRedirectURI string, cryptoKey []byte) (*op.Config, error) { supportedLanguages, err := getSupportedLanguages() if err != nil { diff --git a/internal/api/saml/provider.go b/internal/api/saml/provider.go index f46802f6aa..644fe2886a 100644 --- a/internal/api/saml/provider.go +++ b/internal/api/saml/provider.go @@ -38,8 +38,8 @@ func NewProvider( es *eventstore.Eventstore, projections *database.DB, instanceHandler, - userAgentCookie, - accessHandler func(http.Handler) http.Handler, + userAgentCookie func(http.Handler) http.Handler, + accessHandler *middleware.AccessInterceptor, ) (*provider.Provider, error) { metricTypes := []metrics.MetricType{metrics.MetricTypeRequestCount, metrics.MetricTypeStatusCode, metrics.MetricTypeTotalCount} @@ -63,7 +63,7 @@ func NewProvider( middleware.NoCacheInterceptor().Handler, instanceHandler, userAgentCookie, - accessHandler, + accessHandler.HandleIgnorePathPrefixes(ignoredQuotaLimitEndpoint(conf.ProviderConfig)), http_utils.CopyHeadersToContext, ), provider.WithCustomTimeFormat("2006-01-02T15:04:05.999Z"), @@ -100,3 +100,22 @@ func newStorage( defaultLoginURL: fmt.Sprintf("%s%s?%s=", login.HandlerPrefix, login.EndpointLogin, login.QueryAuthRequestID), }, nil } + +func ignoredQuotaLimitEndpoint(config *provider.Config) []string { + metadataEndpoint := HandlerPrefix + provider.DefaultMetadataEndpoint + certificateEndpoint := HandlerPrefix + provider.DefaultCertificateEndpoint + ssoEndpoint := HandlerPrefix + provider.DefaultSingleSignOnEndpoint + if config.MetadataConfig != nil && config.MetadataConfig.Path != "" { + metadataEndpoint = HandlerPrefix + config.MetadataConfig.Path + } + if config.IDPConfig == nil || config.IDPConfig.Endpoints == nil { + return []string{metadataEndpoint, certificateEndpoint, ssoEndpoint} + } + if config.IDPConfig.Endpoints.Certificate != nil && config.IDPConfig.Endpoints.Certificate.Relative() != "" { + certificateEndpoint = HandlerPrefix + config.IDPConfig.Endpoints.Certificate.Relative() + } + if config.IDPConfig.Endpoints.SingleSignOn != nil && config.IDPConfig.Endpoints.SingleSignOn.Relative() != "" { + ssoEndpoint = HandlerPrefix + config.IDPConfig.Endpoints.SingleSignOn.Relative() + } + return []string{metadataEndpoint, certificateEndpoint, ssoEndpoint} +} diff --git a/internal/command/instance.go b/internal/command/instance.go index e734988a5a..18c13b6d40 100644 --- a/internal/command/instance.go +++ b/internal/command/instance.go @@ -283,10 +283,7 @@ func (c *Commands) SetUpInstance(ctx context.Context, setup *InstanceSetup) (str if err != nil { return "", "", nil, nil, err } - - quotaAggregate := quota.NewAggregate(quotaId, instanceID, instanceID) - - validations = append(validations, c.AddQuotaCommand(quotaAggregate, q)) + validations = append(validations, c.AddQuotaCommand(quota.NewAggregate(quotaId, instanceID), q)) } } diff --git a/internal/command/main_test.go b/internal/command/main_test.go index d1a2b3de55..2d798a40e8 100644 --- a/internal/command/main_test.go +++ b/internal/command/main_test.go @@ -26,6 +26,7 @@ import ( "github.com/zitadel/zitadel/internal/repository/oidcsession" "github.com/zitadel/zitadel/internal/repository/org" proj_repo "github.com/zitadel/zitadel/internal/repository/project" + quota_repo "github.com/zitadel/zitadel/internal/repository/quota" "github.com/zitadel/zitadel/internal/repository/session" usr_repo "github.com/zitadel/zitadel/internal/repository/user" "github.com/zitadel/zitadel/internal/repository/usergrant" @@ -50,6 +51,7 @@ func eventstoreExpect(t *testing.T, expects ...expect) *eventstore.Eventstore { idpintent.RegisterEventMappers(es) authrequest.RegisterEventMappers(es) oidcsession.RegisterEventMappers(es) + quota_repo.RegisterEventMappers(es) return es } diff --git a/internal/command/quota.go b/internal/command/quota.go index 95e4d6fa80..b787929104 100644 --- a/internal/command/quota.go +++ b/internal/command/quota.go @@ -21,8 +21,8 @@ const ( QuotaActionsAllRunsSeconds QuotaUnit = "actions.all.runs.seconds" ) -func (q *QuotaUnit) Enum() quota.Unit { - switch *q { +func (q QuotaUnit) Enum() quota.Unit { + switch q { case QuotaRequestsAllAuthenticated: return quota.RequestsAllAuthenticated case QuotaActionsAllRunsSeconds: @@ -46,14 +46,11 @@ func (c *Commands) AddQuota( if wm.active { return nil, errors.ThrowAlreadyExists(nil, "COMMAND-WDfFf", "Errors.Quota.AlreadyExists") } - aggregateId, err := c.idGenerator.Next() if err != nil { return nil, err } - aggregate := quota.NewAggregate(aggregateId, instanceId, instanceId) - - cmds, err := preparation.PrepareCommands(ctx, c.eventstore.Filter, c.AddQuotaCommand(aggregate, q)) + cmds, err := preparation.PrepareCommands(ctx, c.eventstore.Filter, c.AddQuotaCommand(quota.NewAggregate(aggregateId, instanceId), q)) if err != nil { return nil, err } @@ -80,7 +77,7 @@ func (c *Commands) RemoveQuota(ctx context.Context, unit QuotaUnit) (*domain.Obj return nil, errors.ThrowNotFound(nil, "COMMAND-WDfFf", "Errors.Quota.NotFound") } - aggregate := quota.NewAggregate(wm.AggregateID, instanceId, instanceId) + aggregate := quota.NewAggregate(wm.AggregateID, instanceId) events := []eventstore.Command{ quota.NewRemovedEvent(ctx, &aggregate.Aggregate, unit.Enum()), @@ -109,6 +106,22 @@ type QuotaNotification struct { type QuotaNotifications []*QuotaNotification +func (q *QuotaNotification) validate() error { + u, err := url.Parse(q.CallURL) + if err != nil { + return errors.ThrowInvalidArgument(err, "QUOTA-bZ0Fj", "Errors.Quota.Invalid.CallURL") + } + + if !u.IsAbs() || u.Host == "" { + return errors.ThrowInvalidArgument(nil, "QUOTA-HAYmN", "Errors.Quota.Invalid.CallURL") + } + + if q.Percent < 1 { + return errors.ThrowInvalidArgument(nil, "QUOTA-pBfjq", "Errors.Quota.Invalid.Percent") + } + return nil +} + func (q *QuotaNotifications) toAddedEventNotifications(idGenerator id.Generator) ([]*quota.AddedEventNotification, error) { if q == nil { return nil, nil @@ -144,17 +157,8 @@ type AddQuota struct { func (q *AddQuota) validate() error { for _, notification := range q.Notifications { - u, err := url.Parse(notification.CallURL) - if err != nil { - return errors.ThrowInvalidArgument(err, "QUOTA-bZ0Fj", "Errors.Quota.Invalid.CallURL") - } - - if !u.IsAbs() || u.Host == "" { - return errors.ThrowInvalidArgument(nil, "QUOTA-HAYmN", "Errors.Quota.Invalid.CallURL") - } - - if notification.Percent < 1 { - return errors.ThrowInvalidArgument(nil, "QUOTA-pBfjq", "Errors.Quota.Invalid.Percent") + if err := notification.validate(); err != nil { + return err } } @@ -169,11 +173,6 @@ func (q *AddQuota) validate() error { if q.ResetInterval < time.Minute { return errors.ThrowInvalidArgument(nil, "QUOTA-R5otd", "Errors.Quota.Invalid.ResetInterval") } - - if !q.Limit && len(q.Notifications) == 0 { - return errors.ThrowInvalidArgument(nil, "QUOTA-4Nv68", "Errors.Quota.Invalid.Noop") - } - return nil } diff --git a/internal/command/quota_model.go b/internal/command/quota_model.go index 3da2598934..6e05b4df0b 100644 --- a/internal/command/quota_model.go +++ b/internal/command/quota_model.go @@ -41,6 +41,7 @@ func (wm *quotaWriteModel) Reduce() error { switch e := event.(type) { case *quota.AddedEvent: wm.AggregateID = e.Aggregate().ID + wm.ChangeDate = e.CreationDate() wm.active = true case *quota.RemovedEvent: wm.AggregateID = e.Aggregate().ID diff --git a/internal/command/quota_report.go b/internal/command/quota_report.go index 19855452ed..b3fe9afd3e 100644 --- a/internal/command/quota_report.go +++ b/internal/command/quota_report.go @@ -5,16 +5,47 @@ import ( "github.com/zitadel/zitadel/internal/eventstore" "github.com/zitadel/zitadel/internal/repository/quota" + "github.com/zitadel/zitadel/internal/telemetry/tracing" ) // ReportQuotaUsage writes a slice of *quota.NotificationDueEvent directly to the eventstore -func (c *Commands) ReportQuotaUsage(ctx context.Context, dueNotifications []*quota.NotificationDueEvent) error { - cmds := make([]eventstore.Command, len(dueNotifications)) - for idx, notification := range dueNotifications { - cmds[idx] = notification +func (c *Commands) ReportQuotaUsage(ctx context.Context, dueNotifications []*quota.NotificationDueEvent) (err error) { + ctx, span := tracing.NewSpan(ctx) + defer func() { span.EndWithError(err) }() + + cmds := make([]eventstore.Command, 0, len(dueNotifications)) + for _, notification := range dueNotifications { + ctxFilter, spanFilter := tracing.NewNamedSpan(ctx, "filterNotificationDueEvents") + events, errFilter := c.eventstore.Filter( + ctxFilter, + eventstore.NewSearchQueryBuilder(eventstore.ColumnsEvent). + InstanceID(notification.Aggregate().InstanceID). + AddQuery(). + AggregateTypes(quota.AggregateType). + AggregateIDs(notification.Aggregate().ID). + EventTypes(quota.NotificationDueEventType). + EventData(map[string]interface{}{ + "id": notification.ID, + "periodStart": notification.PeriodStart, + "threshold": notification.Threshold, + }).Builder(), + ) + spanFilter.EndWithError(errFilter) + if errFilter != nil { + return errFilter + } + if len(events) > 0 { + continue + } + cmds = append(cmds, notification) } - _, err := c.eventstore.Push(ctx, cmds...) - return err + if len(cmds) == 0 { + return nil + } + ctxPush, spanPush := tracing.NewNamedSpan(ctx, "pushNotificationDueEvents") + _, errPush := c.eventstore.Push(ctxPush, cmds...) + spanPush.EndWithError(errPush) + return errPush } func (c *Commands) UsageNotificationSent(ctx context.Context, dueEvent *quota.NotificationDueEvent) error { diff --git a/internal/command/quota_report_test.go b/internal/command/quota_report_test.go new file mode 100644 index 0000000000..214b67ac7c --- /dev/null +++ b/internal/command/quota_report_test.go @@ -0,0 +1,307 @@ +package command + +import ( + "context" + "testing" + "time" + + "github.com/stretchr/testify/assert" + + "github.com/zitadel/zitadel/internal/api/authz" + "github.com/zitadel/zitadel/internal/eventstore" + "github.com/zitadel/zitadel/internal/eventstore/repository" + "github.com/zitadel/zitadel/internal/id" + id_mock "github.com/zitadel/zitadel/internal/id/mock" + "github.com/zitadel/zitadel/internal/repository/quota" +) + +func TestQuotaReport_ReportQuotaUsage(t *testing.T) { + type fields struct { + eventstore *eventstore.Eventstore + } + type args struct { + ctx context.Context + dueNotifications []*quota.NotificationDueEvent + } + type res struct { + err func(error) bool + } + tests := []struct { + name string + fields fields + args args + res res + }{ + { + name: "no due events", + fields: fields{ + eventstore: eventstoreExpect( + t, + ), + }, + args: args{ + ctx: authz.WithInstanceID(context.Background(), "INSTANCE"), + }, + res: res{}, + }, + { + name: "due event already reported", + fields: fields{ + eventstore: eventstoreExpect( + t, + expectFilter( + eventFromEventPusherWithInstanceID( + "INSTANCE", + quota.NewNotificationDueEvent(context.Background(), + "a.NewAggregate("quota1", "INSTANCE").Aggregate, + QuotaRequestsAllAuthenticated.Enum(), + "id", + "url", + time.Date(2023, 9, 1, 0, 0, 0, 0, time.UTC), + 1000, + 200, + ), + ), + ), + ), + }, + args: args{ + ctx: authz.WithInstanceID(context.Background(), "INSTANCE"), + dueNotifications: []*quota.NotificationDueEvent{ + { + Unit: QuotaRequestsAllAuthenticated.Enum(), + ID: "id", + CallURL: "url", + PeriodStart: time.Date(2023, 9, 1, 0, 0, 0, 0, time.UTC), + Threshold: 1000, + Usage: 250, + }, + }, + }, + res: res{}, + }, + { + name: "due event not reported", + fields: fields{ + eventstore: eventstoreExpect( + t, + expectFilter(), + expectPush( + []*repository.Event{ + eventFromEventPusherWithInstanceID( + "INSTANCE", + quota.NewNotificationDueEvent(context.Background(), + "a.NewAggregate("quota1", "INSTANCE").Aggregate, + QuotaRequestsAllAuthenticated.Enum(), + "id", + "url", + time.Date(2023, 9, 1, 0, 0, 0, 0, time.UTC), + 1000, + 250, + ), + ), + }, + ), + ), + }, + args: args{ + ctx: authz.WithInstanceID(context.Background(), "INSTANCE"), + dueNotifications: []*quota.NotificationDueEvent{ + quota.NewNotificationDueEvent( + context.Background(), + "a.NewAggregate("quota1", "INSTANCE").Aggregate, + QuotaRequestsAllAuthenticated.Enum(), + "id", + "url", + time.Date(2023, 9, 1, 0, 0, 0, 0, time.UTC), + 1000, + 250, + ), + }, + }, + res: res{}, + }, + { + name: "due events", + fields: fields{ + eventstore: eventstoreExpect( + t, + expectFilter(), + expectFilter( + eventFromEventPusherWithInstanceID( + "INSTANCE", + quota.NewNotificationDueEvent(context.Background(), + "a.NewAggregate("quota1", "INSTANCE").Aggregate, + QuotaRequestsAllAuthenticated.Enum(), + "id2", + "url", + time.Date(2023, 9, 1, 0, 0, 0, 0, time.UTC), + 1000, + 250, + ), + ), + ), + expectFilter(), + expectPush( + []*repository.Event{ + eventFromEventPusherWithInstanceID( + "INSTANCE", + quota.NewNotificationDueEvent(context.Background(), + "a.NewAggregate("quota1", "INSTANCE").Aggregate, + QuotaRequestsAllAuthenticated.Enum(), + "id1", + "url", + time.Date(2023, 9, 1, 0, 0, 0, 0, time.UTC), + 1000, + 250, + ), + ), + eventFromEventPusherWithInstanceID( + "INSTANCE", + quota.NewNotificationDueEvent(context.Background(), + "a.NewAggregate("quota1", "INSTANCE").Aggregate, + QuotaRequestsAllAuthenticated.Enum(), + "id3", + "url", + time.Date(2023, 9, 1, 0, 0, 0, 0, time.UTC), + 1000, + 250, + ), + ), + }, + ), + ), + }, + args: args{ + ctx: authz.WithInstanceID(context.Background(), "INSTANCE"), + dueNotifications: []*quota.NotificationDueEvent{ + quota.NewNotificationDueEvent( + context.Background(), + "a.NewAggregate("quota1", "INSTANCE").Aggregate, + QuotaRequestsAllAuthenticated.Enum(), + "id1", + "url", + time.Date(2023, 9, 1, 0, 0, 0, 0, time.UTC), + 1000, + 250, + ), + quota.NewNotificationDueEvent( + context.Background(), + "a.NewAggregate("quota1", "INSTANCE").Aggregate, + QuotaRequestsAllAuthenticated.Enum(), + "id2", + "url", + time.Date(2023, 9, 1, 0, 0, 0, 0, time.UTC), + 1000, + 250, + ), + quota.NewNotificationDueEvent( + context.Background(), + "a.NewAggregate("quota1", "INSTANCE").Aggregate, + QuotaRequestsAllAuthenticated.Enum(), + "id3", + "url", + time.Date(2023, 9, 1, 0, 0, 0, 0, time.UTC), + 1000, + 250, + ), + }, + }, + res: res{}, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + r := &Commands{ + eventstore: tt.fields.eventstore, + } + err := r.ReportQuotaUsage(tt.args.ctx, tt.args.dueNotifications) + if tt.res.err == nil { + assert.NoError(t, err) + } + if tt.res.err != nil && !tt.res.err(err) { + t.Errorf("got wrong err: %v ", err) + } + }) + } +} + +func TestQuotaReport_UsageNotificationSent(t *testing.T) { + type fields struct { + eventstore *eventstore.Eventstore + idGenerator id.Generator + } + type args struct { + ctx context.Context + dueNotification *quota.NotificationDueEvent + } + type res struct { + err func(error) bool + } + tests := []struct { + name string + fields fields + args args + res res + }{ + { + name: "usage notification sent, ok", + fields: fields{ + eventstore: eventstoreExpect( + t, + expectPush( + []*repository.Event{ + eventFromEventPusherWithInstanceID( + "INSTANCE", + quota.NewNotifiedEvent( + context.Background(), + "quota1", + quota.NewNotificationDueEvent( + context.Background(), + "a.NewAggregate("quota1", "INSTANCE").Aggregate, + QuotaRequestsAllAuthenticated.Enum(), + "id1", + "url", + time.Date(2023, 9, 1, 0, 0, 0, 0, time.UTC), + 1000, + 250, + ), + ), + ), + }, + ), + ), + idGenerator: id_mock.NewIDGeneratorExpectIDs(t, "quota1"), + }, + args: args{ + ctx: authz.WithInstanceID(context.Background(), "INSTANCE"), + dueNotification: quota.NewNotificationDueEvent( + context.Background(), + "a.NewAggregate("quota1", "INSTANCE").Aggregate, + QuotaRequestsAllAuthenticated.Enum(), + "id1", + "url", + time.Date(2023, 9, 1, 0, 0, 0, 0, time.UTC), + 1000, + 250, + ), + }, + res: res{}, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + r := &Commands{ + eventstore: tt.fields.eventstore, + idGenerator: tt.fields.idGenerator, + } + err := r.UsageNotificationSent(tt.args.ctx, tt.args.dueNotification) + if tt.res.err == nil { + assert.NoError(t, err) + } + if tt.res.err != nil && !tt.res.err(err) { + t.Errorf("got wrong err: %v ", err) + } + }) + } +} diff --git a/internal/command/quota_test.go b/internal/command/quota_test.go new file mode 100644 index 0000000000..a3f42b5917 --- /dev/null +++ b/internal/command/quota_test.go @@ -0,0 +1,638 @@ +package command + +import ( + "context" + "errors" + "testing" + "time" + + "github.com/stretchr/testify/assert" + + "github.com/zitadel/zitadel/internal/api/authz" + "github.com/zitadel/zitadel/internal/domain" + caos_errors "github.com/zitadel/zitadel/internal/errors" + "github.com/zitadel/zitadel/internal/eventstore" + "github.com/zitadel/zitadel/internal/eventstore/repository" + "github.com/zitadel/zitadel/internal/id" + id_mock "github.com/zitadel/zitadel/internal/id/mock" + "github.com/zitadel/zitadel/internal/repository/quota" +) + +func TestQuota_AddQuota(t *testing.T) { + type fields struct { + eventstore *eventstore.Eventstore + idGenerator id.Generator + } + type args struct { + ctx context.Context + addQuota *AddQuota + } + type res struct { + want *domain.ObjectDetails + err func(error) bool + } + tests := []struct { + name string + fields fields + args args + res res + }{ + { + name: "already existing", + fields: fields{ + eventstore: eventstoreExpect( + t, + expectFilter( + eventFromEventPusher( + quota.NewAddedEvent(context.Background(), + "a.NewAggregate("quota1", "INSTANCE").Aggregate, + QuotaRequestsAllAuthenticated.Enum(), + time.Now(), + 30*24*time.Hour, + 1000, + false, + nil, + ), + ), + ), + ), + }, + args: args{ + ctx: authz.WithInstanceID(context.Background(), "INSTANCE"), + addQuota: &AddQuota{ + Unit: QuotaRequestsAllAuthenticated, + From: time.Time{}, + ResetInterval: 0, + Amount: 0, + Limit: false, + Notifications: nil, + }, + }, + res: res{ + err: caos_errors.IsErrorAlreadyExists, + }, + }, + { + name: "create quota, validation fail", + fields: fields{ + eventstore: eventstoreExpect( + t, + expectFilter(), + ), + idGenerator: id_mock.NewIDGeneratorExpectIDs(t, "quota1"), + }, + args: args{ + ctx: authz.WithInstanceID(context.Background(), "INSTANCE"), + addQuota: &AddQuota{ + Unit: "unimplemented", + From: time.Time{}, + ResetInterval: 0, + Amount: 0, + Limit: false, + Notifications: nil, + }, + }, + res: res{ + err: func(err error) bool { + return errors.Is(err, caos_errors.ThrowInvalidArgument(nil, "QUOTA-OTeSh", "")) + }, + }, + }, + { + name: "create quota, ok", + fields: fields{ + eventstore: eventstoreExpect( + t, + expectFilter(), + expectPush( + []*repository.Event{ + eventFromEventPusherWithInstanceID( + "INSTANCE", + quota.NewAddedEvent(context.Background(), + "a.NewAggregate("quota1", "INSTANCE").Aggregate, + QuotaRequestsAllAuthenticated.Enum(), + time.Date(2023, 9, 1, 0, 0, 0, 0, time.UTC), + 30*24*time.Hour, + 1000, + true, + nil, + ), + ), + }, + uniqueConstraintsFromEventConstraintWithInstanceID("INSTANCE", quota.NewAddQuotaUnitUniqueConstraint(quota.RequestsAllAuthenticated)), + ), + ), + idGenerator: id_mock.NewIDGeneratorExpectIDs(t, "quota1"), + }, + args: args{ + ctx: authz.WithInstanceID(context.Background(), "INSTANCE"), + addQuota: &AddQuota{ + Unit: QuotaRequestsAllAuthenticated, + From: time.Date(2023, 9, 1, 0, 0, 0, 0, time.UTC), + ResetInterval: 30 * 24 * time.Hour, + Amount: 1000, + Limit: true, + Notifications: nil, + }, + }, + res: res{ + want: &domain.ObjectDetails{ + ResourceOwner: "INSTANCE", + }, + }, + }, + { + name: "removed, ok", + fields: fields{ + eventstore: eventstoreExpect( + t, + expectFilter( + eventFromEventPusherWithInstanceID( + "INSTANCE", + quota.NewAddedEvent(context.Background(), + "a.NewAggregate("quota1", "INSTANCE").Aggregate, + QuotaRequestsAllAuthenticated.Enum(), + time.Now(), + 30*24*time.Hour, + 1000, + true, + nil, + ), + ), + eventFromEventPusherWithInstanceID( + "INSTANCE", + quota.NewRemovedEvent(context.Background(), + "a.NewAggregate("quota1", "INSTANCE").Aggregate, + QuotaRequestsAllAuthenticated.Enum(), + ), + ), + ), + expectPush( + []*repository.Event{ + eventFromEventPusherWithInstanceID( + "INSTANCE", + quota.NewAddedEvent(context.Background(), + "a.NewAggregate("quota1", "INSTANCE").Aggregate, + QuotaRequestsAllAuthenticated.Enum(), + time.Date(2023, 9, 1, 0, 0, 0, 0, time.UTC), + 30*24*time.Hour, + 1000, + true, + nil, + ), + ), + }, + uniqueConstraintsFromEventConstraintWithInstanceID("INSTANCE", quota.NewAddQuotaUnitUniqueConstraint(quota.RequestsAllAuthenticated)), + ), + ), + idGenerator: id_mock.NewIDGeneratorExpectIDs(t, "quota1"), + }, + args: args{ + ctx: authz.WithInstanceID(context.Background(), "INSTANCE"), + addQuota: &AddQuota{ + Unit: QuotaRequestsAllAuthenticated, + From: time.Date(2023, 9, 1, 0, 0, 0, 0, time.UTC), + ResetInterval: 30 * 24 * time.Hour, + Amount: 1000, + Limit: true, + Notifications: nil, + }, + }, + res: res{ + want: &domain.ObjectDetails{ + ResourceOwner: "INSTANCE", + }, + }, + }, + { + name: "create quota with notifications, ok", + fields: fields{ + eventstore: eventstoreExpect( + t, + expectFilter(), + expectPush( + []*repository.Event{ + eventFromEventPusherWithInstanceID( + "INSTANCE", + quota.NewAddedEvent(context.Background(), + "a.NewAggregate("quota1", "INSTANCE").Aggregate, + QuotaRequestsAllAuthenticated.Enum(), + time.Date(2023, 9, 1, 0, 0, 0, 0, time.UTC), + 30*24*time.Hour, + 1000, + true, + []*quota.AddedEventNotification{ + { + ID: "notification1", + Percent: 20, + Repeat: false, + CallURL: "https://url.com", + }, + }, + ), + ), + }, + uniqueConstraintsFromEventConstraintWithInstanceID("INSTANCE", quota.NewAddQuotaUnitUniqueConstraint(quota.RequestsAllAuthenticated)), + ), + ), + idGenerator: id_mock.NewIDGeneratorExpectIDs(t, "quota1", "notification1"), + }, + args: args{ + ctx: authz.WithInstanceID(context.Background(), "INSTANCE"), + addQuota: &AddQuota{ + Unit: QuotaRequestsAllAuthenticated, + From: time.Date(2023, 9, 1, 0, 0, 0, 0, time.UTC), + ResetInterval: 30 * 24 * time.Hour, + Amount: 1000, + Limit: true, + Notifications: QuotaNotifications{ + { + Percent: 20, + Repeat: false, + CallURL: "https://url.com", + }, + }, + }, + }, + res: res{ + want: &domain.ObjectDetails{ + ResourceOwner: "INSTANCE", + }, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + r := &Commands{ + eventstore: tt.fields.eventstore, + idGenerator: tt.fields.idGenerator, + } + got, err := r.AddQuota(tt.args.ctx, tt.args.addQuota) + if tt.res.err == nil { + assert.NoError(t, err) + } + if tt.res.err != nil && !tt.res.err(err) { + t.Errorf("got wrong err: %v ", err) + } + if tt.res.err == nil { + assert.Equal(t, tt.res.want, got) + } + }) + } +} + +func TestQuota_RemoveQuota(t *testing.T) { + type fields struct { + eventstore *eventstore.Eventstore + } + type args struct { + ctx context.Context + unit QuotaUnit + } + type res struct { + want *domain.ObjectDetails + err func(error) bool + } + tests := []struct { + name string + fields fields + args args + res res + }{ + { + name: "not found", + fields: fields{ + eventstore: eventstoreExpect( + t, + expectFilter(), + ), + }, + args: args{ + ctx: authz.WithInstanceID(context.Background(), "INSTANCE"), + unit: QuotaRequestsAllAuthenticated, + }, + res: res{ + err: func(err error) bool { + return errors.Is(err, caos_errors.ThrowNotFound(nil, "COMMAND-WDfFf", "")) + }, + }, + }, + { + name: "already removed", + fields: fields{ + eventstore: eventstoreExpect( + t, + expectFilter( + eventFromEventPusherWithInstanceID( + "INSTANCE", + quota.NewAddedEvent(context.Background(), + "a.NewAggregate("quota1", "INSTANCE").Aggregate, + QuotaRequestsAllAuthenticated.Enum(), + time.Now(), + 30*24*time.Hour, + 1000, + true, + nil, + ), + ), + eventFromEventPusherWithInstanceID( + "INSTANCE", + quota.NewRemovedEvent(context.Background(), + "a.NewAggregate("quota1", "INSTANCE").Aggregate, + QuotaRequestsAllAuthenticated.Enum(), + ), + ), + ), + ), + }, + args: args{ + ctx: authz.WithInstanceID(context.Background(), "INSTANCE"), + unit: QuotaRequestsAllAuthenticated, + }, + res: res{ + err: func(err error) bool { + return errors.Is(err, caos_errors.ThrowNotFound(nil, "COMMAND-WDfFf", "")) + }, + }, + }, + { + name: "remove quota, ok", + fields: fields{ + eventstore: eventstoreExpect( + t, + expectFilter( + eventFromEventPusherWithInstanceID( + "INSTANCE", + quota.NewAddedEvent(context.Background(), + "a.NewAggregate("quota1", "INSTANCE").Aggregate, + QuotaRequestsAllAuthenticated.Enum(), + time.Now(), + 30*24*time.Hour, + 1000, + false, + nil, + ), + ), + ), + expectPush( + []*repository.Event{ + eventFromEventPusherWithInstanceID( + "INSTANCE", + quota.NewRemovedEvent(context.Background(), + "a.NewAggregate("quota1", "INSTANCE").Aggregate, + QuotaRequestsAllAuthenticated.Enum(), + ), + ), + }, + uniqueConstraintsFromEventConstraintWithInstanceID("INSTANCE", quota.NewRemoveQuotaNameUniqueConstraint(quota.RequestsAllAuthenticated)), + ), + ), + }, + args: args{ + ctx: authz.WithInstanceID(context.Background(), "INSTANCE"), + unit: QuotaRequestsAllAuthenticated, + }, + res: res{ + want: &domain.ObjectDetails{ + ResourceOwner: "INSTANCE", + }, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + r := &Commands{ + eventstore: tt.fields.eventstore, + } + got, err := r.RemoveQuota(tt.args.ctx, tt.args.unit) + if tt.res.err == nil { + assert.NoError(t, err) + } + if tt.res.err != nil && !tt.res.err(err) { + t.Errorf("got wrong err: %v ", err) + } + if tt.res.err == nil { + assert.Equal(t, tt.res.want, got) + } + }) + } +} + +func TestQuota_QuotaNotification_validate(t *testing.T) { + type args struct { + quotaNotification *QuotaNotification + } + type res struct { + err func(error) bool + } + tests := []struct { + name string + args args + res res + }{ + { + name: "notification url parse failed", + args: args{ + quotaNotification: &QuotaNotification{ + Percent: 20, + Repeat: false, + CallURL: "%", + }, + }, + res: res{ + err: func(err error) bool { + return errors.Is(err, caos_errors.ThrowInvalidArgument(nil, "QUOTA-bZ0Fj", "")) + }, + }, + }, + { + name: "notification url parse empty schema", + args: args{ + quotaNotification: &QuotaNotification{ + Percent: 20, + Repeat: false, + CallURL: "localhost:8080", + }, + }, + res: res{ + err: func(err error) bool { + return errors.Is(err, caos_errors.ThrowInvalidArgument(nil, "QUOTA-HAYmN", "")) + }, + }, + }, + { + name: "notification url parse empty host", + args: args{ + quotaNotification: &QuotaNotification{ + Percent: 20, + Repeat: false, + CallURL: "https://", + }, + }, + res: res{ + err: func(err error) bool { + return errors.Is(err, caos_errors.ThrowInvalidArgument(nil, "QUOTA-HAYmN", "")) + }, + }, + }, + { + name: "notification url parse percent 0", + args: args{ + quotaNotification: &QuotaNotification{ + Percent: 0, + Repeat: false, + CallURL: "https://localhost:8080", + }, + }, + res: res{ + err: func(err error) bool { + return errors.Is(err, caos_errors.ThrowInvalidArgument(nil, "QUOTA-pBfjq", "")) + }, + }, + }, + { + name: "notification, ok", + args: args{ + quotaNotification: &QuotaNotification{ + Percent: 20, + Repeat: false, + CallURL: "https://localhost:8080", + }, + }, + res: res{ + err: nil, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := tt.args.quotaNotification.validate() + if tt.res.err == nil { + assert.NoError(t, err) + } + if tt.res.err != nil && !tt.res.err(err) { + t.Errorf("got wrong err: %v ", err) + } + }) + } +} + +func TestQuota_AddQuota_validate(t *testing.T) { + type args struct { + addQuota *AddQuota + } + type res struct { + err func(error) bool + } + tests := []struct { + name string + args args + res res + }{ + { + name: "notification url parse failed", + args: args{ + addQuota: &AddQuota{ + Unit: QuotaRequestsAllAuthenticated, + From: time.Now(), + ResetInterval: time.Minute * 10, + Amount: 100, + Limit: true, + Notifications: QuotaNotifications{ + { + Percent: 20, + Repeat: false, + CallURL: "%", + }, + }, + }, + }, + res: res{ + err: func(err error) bool { + return errors.Is(err, caos_errors.ThrowInvalidArgument(nil, "QUOTA-bZ0Fj", "")) + }, + }, + }, + { + name: "unit unimplemented", + args: args{ + addQuota: &AddQuota{ + Unit: "unimplemented", + From: time.Now(), + ResetInterval: time.Minute * 10, + Amount: 100, + Limit: true, + Notifications: nil, + }, + }, + res: res{ + err: func(err error) bool { + return errors.Is(err, caos_errors.ThrowInvalidArgument(nil, "QUOTA-OTeSh", "")) + }, + }, + }, + { + name: "amount 0", + args: args{ + addQuota: &AddQuota{ + Unit: QuotaRequestsAllAuthenticated, + From: time.Now(), + ResetInterval: time.Minute * 10, + Amount: 0, + Limit: true, + Notifications: nil, + }, + }, + res: res{ + err: func(err error) bool { + return errors.Is(err, caos_errors.ThrowInvalidArgument(nil, "QUOTA-hOKSJ", "")) + }, + }, + }, + { + name: "reset interval under 1 min", + args: args{ + addQuota: &AddQuota{ + Unit: QuotaRequestsAllAuthenticated, + From: time.Now(), + ResetInterval: time.Second * 10, + Amount: 100, + Limit: true, + Notifications: nil, + }, + }, + res: res{ + err: func(err error) bool { + return errors.Is(err, caos_errors.ThrowInvalidArgument(nil, "QUOTA-R5otd", "")) + }, + }, + }, + { + name: "validate, ok", + args: args{ + addQuota: &AddQuota{ + Unit: QuotaRequestsAllAuthenticated, + From: time.Now(), + ResetInterval: time.Minute * 10, + Amount: 100, + Limit: false, + Notifications: nil, + }, + }, + res: res{ + err: nil, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := tt.args.addQuota.validate() + if tt.res.err == nil { + assert.NoError(t, err) + } + if tt.res.err != nil && !tt.res.err(err) { + t.Errorf("got wrong err: %v ", err) + } + }) + } +} diff --git a/internal/database/type.go b/internal/database/type.go index 8360bc1f13..4b863c1010 100644 --- a/internal/database/type.go +++ b/internal/database/type.go @@ -3,6 +3,7 @@ package database import ( "database/sql/driver" "encoding/json" + "time" "github.com/jackc/pgtype" ) @@ -93,3 +94,15 @@ func (m Map[V]) Value() (driver.Value, error) { } return json.Marshal(m) } + +type Duration time.Duration + +// Scan implements the [database/sql.Scanner] interface. +func (d *Duration) Scan(src any) error { + interval := new(pgtype.Interval) + if err := interval.Scan(src); err != nil { + return err + } + *d = Duration(time.Duration(interval.Microseconds*1000) + time.Duration(interval.Days)*24*time.Hour + time.Duration(interval.Months)*30*24*time.Hour) + return nil +} diff --git a/internal/integration/assert.go b/internal/integration/assert.go index 5c19eea986..3638c831a3 100644 --- a/internal/integration/assert.go +++ b/internal/integration/assert.go @@ -24,8 +24,8 @@ type DetailsMsg interface { // // The resource owner is compared with expected and is // therefore the only value that has to be set. -func AssertDetails[D DetailsMsg](t testing.TB, exptected, actual D) { - wantDetails, gotDetails := exptected.GetDetails(), actual.GetDetails() +func AssertDetails[D DetailsMsg](t testing.TB, expected, actual D) { + wantDetails, gotDetails := expected.GetDetails(), actual.GetDetails() if wantDetails == nil { assert.Nil(t, gotDetails) return diff --git a/internal/integration/config/zitadel.yaml b/internal/integration/config/zitadel.yaml index 14222e263d..22d77a033c 100644 --- a/internal/integration/config/zitadel.yaml +++ b/internal/integration/config/zitadel.yaml @@ -22,22 +22,17 @@ FirstInstance: PasswordChangeRequired: false LogStore: - Access: - Database: - Enabled: true - Debounce: - MinFrequency: 0s - MaxBulkSize: 0 Execution: - Database: - Enabled: true Stdout: Enabled: true Quotas: Access: + Enabled: true ExhaustedCookieKey: "zitadel.quota.limiting" ExhaustedCookieMaxAge: "60s" + Execution: + Enabled: true Projections: Customizations: diff --git a/internal/integration/integration.go b/internal/integration/integration.go index d652311190..3e270b1d72 100644 --- a/internal/integration/integration.go +++ b/internal/integration/integration.go @@ -8,7 +8,11 @@ import ( _ "embed" "errors" "fmt" + "io" + "net/http" + "net/http/httptest" "os" + "reflect" "strings" "sync" "time" @@ -30,6 +34,7 @@ import ( "github.com/zitadel/zitadel/internal/domain" caos_errs "github.com/zitadel/zitadel/internal/errors" "github.com/zitadel/zitadel/internal/eventstore/v1/models" + "github.com/zitadel/zitadel/internal/net" "github.com/zitadel/zitadel/internal/query" "github.com/zitadel/zitadel/internal/webauthn" "github.com/zitadel/zitadel/pkg/grpc/admin" @@ -71,6 +76,11 @@ const ( UserPassword = "VeryS3cret!" ) +const ( + PortMilestoneServer = "8081" + PortQuotaServer = "8082" +) + // User information with a Personal Access Token. type User struct { *query.User @@ -101,6 +111,11 @@ type Tester struct { Organisation *query.Org Users InstanceUserMap + MilestoneChan chan []byte + milestoneServer *httptest.Server + QuotaNotificationChan chan []byte + quotaNotificationServer *httptest.Server + Client Client WebAuthN *webauthn.Client wg sync.WaitGroup // used for shutdown @@ -271,6 +286,8 @@ func (s *Tester) Done() { s.Shutdown <- os.Interrupt s.wg.Wait() + s.milestoneServer.Close() + s.quotaNotificationServer.Close() } // NewTester start a new Zitadel server by passing the default commandline. @@ -279,13 +296,13 @@ func (s *Tester) Done() { // INTEGRATION_DB_FLAVOR environment variable and can have the values "cockroach" // or "postgres". Defaults to "cockroach". // -// The deault Instance and Organisation are read from the DB and system +// The default Instance and Organisation are read from the DB and system // users are created as needed. // // After the server is started, a [grpc.ClientConn] will be created and // the server is polled for it's health status. // -// Note: the database must already be setup and intialized before +// Note: the database must already be setup and initialized before // using NewTester. See the CONTRIBUTING.md document for details. func NewTester(ctx context.Context) *Tester { args := strings.Split(commandLine, " ") @@ -311,6 +328,13 @@ func NewTester(ctx context.Context) *Tester { tester := Tester{ Users: make(InstanceUserMap), } + tester.MilestoneChan = make(chan []byte, 100) + tester.milestoneServer, err = runMilestoneServer(ctx, tester.MilestoneChan) + logging.OnError(err).Fatal() + tester.QuotaNotificationChan = make(chan []byte, 100) + tester.quotaNotificationServer, err = runQuotaServer(ctx, tester.QuotaNotificationChan) + logging.OnError(err).Fatal() + tester.wg.Add(1) go func(wg *sync.WaitGroup) { logging.OnError(cmd.Execute()).Fatal() @@ -328,7 +352,6 @@ func NewTester(ctx context.Context) *Tester { tester.createMachineUserOrgOwner(ctx) tester.createMachineUserInstanceOwner(ctx) tester.WebAuthN = webauthn.NewClient(tester.Config.WebAuthNName, tester.Config.ExternalDomain, "https://"+tester.Host()) - return &tester } @@ -338,3 +361,51 @@ func Contexts(timeout time.Duration) (ctx, errCtx context.Context, cancel contex ctx, cancel = context.WithTimeout(context.Background(), timeout) return ctx, errCtx, cancel } + +func runMilestoneServer(ctx context.Context, bodies chan []byte) (*httptest.Server, error) { + mockServer := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + body, err := io.ReadAll(r.Body) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + if r.Header.Get("single-value") != "single-value" { + http.Error(w, "single-value header not set", http.StatusInternalServerError) + return + } + if reflect.DeepEqual(r.Header.Get("multi-value"), "multi-value-1,multi-value-2") { + http.Error(w, "single-value header not set", http.StatusInternalServerError) + return + } + bodies <- body + w.WriteHeader(http.StatusOK) + })) + config := net.ListenConfig() + listener, err := config.Listen(ctx, "tcp", ":"+PortMilestoneServer) + if err != nil { + return nil, err + } + mockServer.Listener = listener + mockServer.Start() + return mockServer, nil +} + +func runQuotaServer(ctx context.Context, bodies chan []byte) (*httptest.Server, error) { + mockServer := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + body, err := io.ReadAll(r.Body) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + bodies <- body + w.WriteHeader(http.StatusOK) + })) + config := net.ListenConfig() + listener, err := config.Listen(ctx, "tcp", ":"+PortQuotaServer) + if err != nil { + return nil, err + } + mockServer.Listener = listener + mockServer.Start() + return mockServer, nil +} diff --git a/internal/logstore/config.go b/internal/logstore/config.go index 3e8a8d327a..4f1419d532 100644 --- a/internal/logstore/config.go +++ b/internal/logstore/config.go @@ -6,6 +6,9 @@ type Configs struct { } type Config struct { - Database *EmitterConfig - Stdout *EmitterConfig + Stdout *StdConfig +} + +type StdConfig struct { + Enabled bool } diff --git a/internal/logstore/debouncer.go b/internal/logstore/debouncer.go index e299c08a8f..9c8fefb0d8 100644 --- a/internal/logstore/debouncer.go +++ b/internal/logstore/debouncer.go @@ -9,19 +9,17 @@ import ( "github.com/zitadel/logging" ) -type bulkSink interface { - sendBulk(ctx context.Context, bulk []LogRecord) error +type bulkSink[T LogRecord[T]] interface { + SendBulk(ctx context.Context, bulk []T) error } -var _ bulkSink = bulkSinkFunc(nil) +type bulkSinkFunc[T LogRecord[T]] func(ctx context.Context, bulk []T) error -type bulkSinkFunc func(ctx context.Context, items []LogRecord) error - -func (s bulkSinkFunc) sendBulk(ctx context.Context, items []LogRecord) error { - return s(ctx, items) +func (s bulkSinkFunc[T]) SendBulk(ctx context.Context, bulk []T) error { + return s(ctx, bulk) } -type debouncer struct { +type debouncer[T LogRecord[T]] struct { // Storing context.Context in a struct is generally bad practice // https://go.dev/blog/context-and-structs // However, debouncer starts a go routine that triggers side effects itself. @@ -33,8 +31,8 @@ type debouncer struct { ticker *clock.Ticker mux sync.Mutex cfg DebouncerConfig - storage bulkSink - cache []LogRecord + storage bulkSink[T] + cache []T cacheLen uint } @@ -43,8 +41,8 @@ type DebouncerConfig struct { MaxBulkSize uint } -func newDebouncer(binarySignaledCtx context.Context, cfg DebouncerConfig, clock clock.Clock, ship bulkSink) *debouncer { - a := &debouncer{ +func newDebouncer[T LogRecord[T]](binarySignaledCtx context.Context, cfg DebouncerConfig, clock clock.Clock, ship bulkSink[T]) *debouncer[T] { + a := &debouncer[T]{ binarySignaledCtx: binarySignaledCtx, clock: clock, cfg: cfg, @@ -58,7 +56,7 @@ func newDebouncer(binarySignaledCtx context.Context, cfg DebouncerConfig, clock return a } -func (d *debouncer) add(item LogRecord) { +func (d *debouncer[T]) add(item T) { d.mux.Lock() defer d.mux.Unlock() d.cache = append(d.cache, item) @@ -69,13 +67,13 @@ func (d *debouncer) add(item LogRecord) { } } -func (d *debouncer) ship() { +func (d *debouncer[T]) ship() { if d.cacheLen == 0 { return } d.mux.Lock() defer d.mux.Unlock() - if err := d.storage.sendBulk(d.binarySignaledCtx, d.cache); err != nil { + if err := d.storage.SendBulk(d.binarySignaledCtx, d.cache); err != nil { logging.WithError(err).WithField("size", len(d.cache)).Error("storing bulk failed") } d.cache = nil @@ -85,7 +83,7 @@ func (d *debouncer) ship() { } } -func (d *debouncer) shipOnTicks() { +func (d *debouncer[T]) shipOnTicks() { for range d.ticker.C { d.ship() } diff --git a/internal/logstore/emitter.go b/internal/logstore/emitter.go index a7a5151afb..f35efc61a3 100644 --- a/internal/logstore/emitter.go +++ b/internal/logstore/emitter.go @@ -2,56 +2,52 @@ package logstore import ( "context" - "fmt" "time" "github.com/benbjohnson/clock" - "github.com/zitadel/logging" ) type EmitterConfig struct { - Enabled bool - Keep time.Duration - CleanupInterval time.Duration - Debounce *DebouncerConfig + Enabled bool + Debounce *DebouncerConfig } -type emitter struct { +type emitter[T LogRecord[T]] struct { enabled bool ctx context.Context - debouncer *debouncer - emitter LogEmitter + debouncer *debouncer[T] + emitter LogEmitter[T] clock clock.Clock } -type LogRecord interface { - Normalize() LogRecord +type LogRecord[T any] interface { + Normalize() T } -type LogRecordFunc func() LogRecord +type LogRecordFunc[T any] func() T -func (r LogRecordFunc) Normalize() LogRecord { +func (r LogRecordFunc[T]) Normalize() T { return r() } -type LogEmitter interface { - Emit(ctx context.Context, bulk []LogRecord) error +type LogEmitter[T LogRecord[T]] interface { + Emit(ctx context.Context, bulk []T) error } -type LogEmitterFunc func(ctx context.Context, bulk []LogRecord) error +type LogEmitterFunc[T LogRecord[T]] func(ctx context.Context, bulk []T) error -func (l LogEmitterFunc) Emit(ctx context.Context, bulk []LogRecord) error { +func (l LogEmitterFunc[T]) Emit(ctx context.Context, bulk []T) error { return l(ctx, bulk) } -type LogCleanupper interface { - LogEmitter +type LogCleanupper[T LogRecord[T]] interface { Cleanup(ctx context.Context, keep time.Duration) error + LogEmitter[T] } // NewEmitter accepts Clock from github.com/benbjohnson/clock so we can control timers and tickers in the unit tests -func NewEmitter(ctx context.Context, clock clock.Clock, cfg *EmitterConfig, logger LogEmitter) (*emitter, error) { - svc := &emitter{ +func NewEmitter[T LogRecord[T]](ctx context.Context, clock clock.Clock, cfg *EmitterConfig, logger LogEmitter[T]) (*emitter[T], error) { + svc := &emitter[T]{ enabled: cfg != nil && cfg.Enabled, ctx: ctx, emitter: logger, @@ -63,36 +59,12 @@ func NewEmitter(ctx context.Context, clock clock.Clock, cfg *EmitterConfig, logg } if cfg.Debounce != nil && (cfg.Debounce.MinFrequency > 0 || cfg.Debounce.MaxBulkSize > 0) { - svc.debouncer = newDebouncer(ctx, *cfg.Debounce, clock, newStorageBulkSink(svc.emitter)) - } - - cleanupper, ok := logger.(LogCleanupper) - if !ok { - if cfg.Keep != 0 { - return nil, fmt.Errorf("cleaning up for this storage type is not supported, so keep duration must be 0, but is %d", cfg.Keep) - } - if cfg.CleanupInterval != 0 { - return nil, fmt.Errorf("cleaning up for this storage type is not supported, so cleanup interval duration must be 0, but is %d", cfg.Keep) - } - - return svc, nil - } - - if cfg.Keep != 0 && cfg.CleanupInterval != 0 { - go svc.startCleanupping(cleanupper, cfg.CleanupInterval, cfg.Keep) + svc.debouncer = newDebouncer[T](ctx, *cfg.Debounce, clock, newStorageBulkSink(svc.emitter)) } return svc, nil } -func (s *emitter) startCleanupping(cleanupper LogCleanupper, cleanupInterval, keep time.Duration) { - for range s.clock.Tick(cleanupInterval) { - if err := cleanupper.Cleanup(s.ctx, keep); err != nil { - logging.WithError(err).Error("cleaning up logs failed") - } - } -} - -func (s *emitter) Emit(ctx context.Context, record LogRecord) (err error) { +func (s *emitter[T]) Emit(ctx context.Context, record T) (err error) { if !s.enabled { return nil } @@ -102,11 +74,11 @@ func (s *emitter) Emit(ctx context.Context, record LogRecord) (err error) { return nil } - return s.emitter.Emit(ctx, []LogRecord{record}) + return s.emitter.Emit(ctx, []T{record}) } -func newStorageBulkSink(emitter LogEmitter) bulkSinkFunc { - return func(ctx context.Context, bulk []LogRecord) error { +func newStorageBulkSink[T LogRecord[T]](emitter LogEmitter[T]) bulkSinkFunc[T] { + return func(ctx context.Context, bulk []T) error { return emitter.Emit(ctx, bulk) } } diff --git a/internal/logstore/emitters/access/database.go b/internal/logstore/emitters/access/database.go index de3afe5aee..dc60f49224 100644 --- a/internal/logstore/emitters/access/database.go +++ b/internal/logstore/emitters/access/database.go @@ -3,167 +3,91 @@ package access import ( "context" "database/sql" - "fmt" - "net/http" - "strings" + "errors" "time" - "github.com/Masterminds/squirrel" - "github.com/zitadel/logging" - "google.golang.org/grpc/codes" - - "github.com/zitadel/zitadel/internal/api/call" - zitadel_http "github.com/zitadel/zitadel/internal/api/http" + "github.com/zitadel/zitadel/internal/api/authz" + "github.com/zitadel/zitadel/internal/command" "github.com/zitadel/zitadel/internal/database" - caos_errors "github.com/zitadel/zitadel/internal/errors" "github.com/zitadel/zitadel/internal/logstore" + "github.com/zitadel/zitadel/internal/logstore/record" + "github.com/zitadel/zitadel/internal/query" + "github.com/zitadel/zitadel/internal/query/projection" "github.com/zitadel/zitadel/internal/repository/quota" + "github.com/zitadel/zitadel/internal/telemetry/tracing" ) -const ( - accessLogsTable = "logstore.access" - accessTimestampCol = "log_date" - accessProtocolCol = "protocol" - accessRequestURLCol = "request_url" - accessResponseStatusCol = "response_status" - accessRequestHeadersCol = "request_headers" - accessResponseHeadersCol = "response_headers" - accessInstanceIdCol = "instance_id" - accessProjectIdCol = "project_id" - accessRequestedDomainCol = "requested_domain" - accessRequestedHostCol = "requested_host" -) - -var _ logstore.UsageQuerier = (*databaseLogStorage)(nil) -var _ logstore.LogCleanupper = (*databaseLogStorage)(nil) +var _ logstore.UsageStorer[*record.AccessLog] = (*databaseLogStorage)(nil) type databaseLogStorage struct { dbClient *database.DB + commands *command.Commands + queries *query.Queries } -func NewDatabaseLogStorage(dbClient *database.DB) *databaseLogStorage { - return &databaseLogStorage{dbClient: dbClient} +func NewDatabaseLogStorage(dbClient *database.DB, commands *command.Commands, queries *query.Queries) *databaseLogStorage { + return &databaseLogStorage{dbClient: dbClient, commands: commands, queries: queries} } func (l *databaseLogStorage) QuotaUnit() quota.Unit { return quota.RequestsAllAuthenticated } -func (l *databaseLogStorage) Emit(ctx context.Context, bulk []logstore.LogRecord) error { +func (l *databaseLogStorage) Emit(ctx context.Context, bulk []*record.AccessLog) error { if len(bulk) == 0 { return nil } - builder := squirrel.Insert(accessLogsTable). - Columns( - accessTimestampCol, - accessProtocolCol, - accessRequestURLCol, - accessResponseStatusCol, - accessRequestHeadersCol, - accessResponseHeadersCol, - accessInstanceIdCol, - accessProjectIdCol, - accessRequestedDomainCol, - accessRequestedHostCol, - ). - PlaceholderFormat(squirrel.Dollar) - - for idx := range bulk { - item := bulk[idx].(*Record) - builder = builder.Values( - item.LogDate, - item.Protocol, - item.RequestURL, - item.ResponseStatus, - item.RequestHeaders, - item.ResponseHeaders, - item.InstanceID, - item.ProjectID, - item.RequestedDomain, - item.RequestedHost, - ) - } - - stmt, args, err := builder.ToSql() - if err != nil { - return caos_errors.ThrowInternal(err, "ACCESS-KOS7I", "Errors.Internal") - } - - result, err := l.dbClient.ExecContext(ctx, stmt, args...) - if err != nil { - return caos_errors.ThrowInternal(err, "ACCESS-alnT9", "Errors.Access.StorageFailed") - } - - rows, err := result.RowsAffected() - if err != nil { - return caos_errors.ThrowInternal(err, "ACCESS-7KIpL", "Errors.Internal") - } - - logging.WithFields("rows", rows).Debug("successfully stored access logs") - return nil + return l.incrementUsage(ctx, bulk) } -func (l *databaseLogStorage) QueryUsage(ctx context.Context, instanceId string, start time.Time) (uint64, error) { - stmt, args, err := squirrel.Select( - fmt.Sprintf("count(%s)", accessInstanceIdCol), - ). - From(accessLogsTable + l.dbClient.Timetravel(call.Took(ctx))). - Where(squirrel.And{ - squirrel.Eq{accessInstanceIdCol: instanceId}, - squirrel.GtOrEq{accessTimestampCol: start}, - squirrel.Expr(fmt.Sprintf(`%s #>> '{%s,0}' = '[REDACTED]'`, accessRequestHeadersCol, strings.ToLower(zitadel_http.Authorization))), - squirrel.NotLike{accessRequestURLCol: "%/zitadel.system.v1.SystemService/%"}, - squirrel.NotLike{accessRequestURLCol: "%/system/v1/%"}, - squirrel.Or{ - squirrel.And{ - squirrel.Eq{accessProtocolCol: HTTP}, - squirrel.NotEq{accessResponseStatusCol: http.StatusForbidden}, - squirrel.NotEq{accessResponseStatusCol: http.StatusInternalServerError}, - squirrel.NotEq{accessResponseStatusCol: http.StatusTooManyRequests}, - }, - squirrel.And{ - squirrel.Eq{accessProtocolCol: GRPC}, - squirrel.NotEq{accessResponseStatusCol: codes.PermissionDenied}, - squirrel.NotEq{accessResponseStatusCol: codes.Internal}, - squirrel.NotEq{accessResponseStatusCol: codes.ResourceExhausted}, - }, - }, - }). - PlaceholderFormat(squirrel.Dollar). - ToSql() +func (l *databaseLogStorage) incrementUsage(ctx context.Context, bulk []*record.AccessLog) (err error) { + ctx, span := tracing.NewSpan(ctx) + defer func() { span.EndWithError(err) }() - if err != nil { - return 0, caos_errors.ThrowInternal(err, "ACCESS-V9Sde", "Errors.Internal") + byInstance := make(map[string][]*record.AccessLog) + for _, r := range bulk { + if r.InstanceID != "" { + byInstance[r.InstanceID] = append(byInstance[r.InstanceID], r) + } } - - var count uint64 - err = l.dbClient. - QueryRowContext(ctx, - func(row *sql.Row) error { - return row.Scan(&count) - }, - stmt, args..., - ) - - if err != nil { - return 0, caos_errors.ThrowInternal(err, "ACCESS-pBPrM", "Errors.Logstore.Access.ScanFailed") + for instanceID, instanceBulk := range byInstance { + q, getQuotaErr := l.queries.GetQuota(ctx, instanceID, quota.RequestsAllAuthenticated) + if errors.Is(getQuotaErr, sql.ErrNoRows) { + continue + } + err = errors.Join(err, getQuotaErr) + if getQuotaErr != nil { + continue + } + sum, incrementErr := l.incrementUsageFromAccessLogs(ctx, instanceID, q.CurrentPeriodStart, instanceBulk) + err = errors.Join(err, incrementErr) + if incrementErr != nil { + continue + } + notifications, getNotificationErr := l.queries.GetDueQuotaNotifications(ctx, instanceID, quota.RequestsAllAuthenticated, q, q.CurrentPeriodStart, sum) + err = errors.Join(err, getNotificationErr) + if getNotificationErr != nil || len(notifications) == 0 { + continue + } + ctx = authz.WithInstanceID(ctx, instanceID) + reportErr := l.commands.ReportQuotaUsage(ctx, notifications) + err = errors.Join(err, reportErr) + if reportErr != nil { + continue + } } - - return count, nil -} - -func (l *databaseLogStorage) Cleanup(ctx context.Context, keep time.Duration) error { - stmt, args, err := squirrel.Delete(accessLogsTable). - Where(squirrel.LtOrEq{accessTimestampCol: time.Now().Add(-keep)}). - PlaceholderFormat(squirrel.Dollar). - ToSql() - - if err != nil { - return caos_errors.ThrowInternal(err, "ACCESS-2oTh6", "Errors.Internal") - } - - execCtx, cancel := context.WithTimeout(ctx, 10*time.Second) - defer cancel() - _, err = l.dbClient.ExecContext(execCtx, stmt, args...) return err } + +func (l *databaseLogStorage) incrementUsageFromAccessLogs(ctx context.Context, instanceID string, periodStart time.Time, records []*record.AccessLog) (sum uint64, err error) { + ctx, span := tracing.NewSpan(ctx) + defer func() { span.EndWithError(err) }() + + var count uint64 + for _, r := range records { + if r.IsAuthenticated() { + count++ + } + } + return projection.QuotaProjection.IncrementUsage(ctx, quota.RequestsAllAuthenticated, instanceID, periodStart, count) +} diff --git a/internal/logstore/emitters/access/record.go b/internal/logstore/emitters/access/record.go deleted file mode 100644 index 6b802015da..0000000000 --- a/internal/logstore/emitters/access/record.go +++ /dev/null @@ -1,97 +0,0 @@ -package access - -import ( - "strings" - "time" - - zitadel_http "github.com/zitadel/zitadel/internal/api/http" - "github.com/zitadel/zitadel/internal/logstore" -) - -var _ logstore.LogRecord = (*Record)(nil) - -type Record struct { - LogDate time.Time `json:"logDate"` - Protocol Protocol `json:"protocol"` - RequestURL string `json:"requestUrl"` - ResponseStatus uint32 `json:"responseStatus"` - // RequestHeaders are plain maps so varying implementation - // between HTTP and gRPC don't interfere with each other - RequestHeaders map[string][]string `json:"requestHeaders"` - // ResponseHeaders are plain maps so varying implementation - // between HTTP and gRPC don't interfere with each other - ResponseHeaders map[string][]string `json:"responseHeaders"` - InstanceID string `json:"instanceId"` - ProjectID string `json:"projectId"` - RequestedDomain string `json:"requestedDomain"` - RequestedHost string `json:"requestedHost"` -} - -type Protocol uint8 - -const ( - GRPC Protocol = iota - HTTP - - redacted = "[REDACTED]" -) - -func (a Record) Normalize() logstore.LogRecord { - a.RequestedDomain = cutString(a.RequestedDomain, 200) - a.RequestURL = cutString(a.RequestURL, 200) - a.RequestHeaders = normalizeHeaders(a.RequestHeaders, strings.ToLower(zitadel_http.Authorization), "grpcgateway-authorization", "cookie", "grpcgateway-cookie") - a.ResponseHeaders = normalizeHeaders(a.ResponseHeaders, "set-cookie") - return &a -} - -// normalizeHeaders lowers all header keys and redacts secrets -func normalizeHeaders(header map[string][]string, redactKeysLower ...string) map[string][]string { - return pruneKeys(redactKeys(lowerKeys(header), redactKeysLower...)) -} - -func lowerKeys(header map[string][]string) map[string][]string { - lower := make(map[string][]string, len(header)) - for k, v := range header { - lower[strings.ToLower(k)] = v - } - return lower -} - -func redactKeys(header map[string][]string, redactKeysLower ...string) map[string][]string { - redactedKeys := make(map[string][]string, len(header)) - for k, v := range header { - redactedKeys[k] = v - } - for _, redactKey := range redactKeysLower { - if _, ok := redactedKeys[redactKey]; ok { - redactedKeys[redactKey] = []string{redacted} - } - } - return redactedKeys -} - -const maxValuesPerKey = 10 - -func pruneKeys(header map[string][]string) map[string][]string { - prunedKeys := make(map[string][]string, len(header)) - for key, value := range header { - valueItems := make([]string, 0, maxValuesPerKey) - for i, valueItem := range value { - // Max 10 header values per key - if i > maxValuesPerKey { - break - } - // Max 200 value length - valueItems = append(valueItems, cutString(valueItem, 200)) - } - prunedKeys[key] = valueItems - } - return prunedKeys -} - -func cutString(str string, pos int) string { - if len(str) <= pos { - return str - } - return str[:pos-1] -} diff --git a/internal/logstore/emitters/execution/database.go b/internal/logstore/emitters/execution/database.go index 396f3f2ea1..8f7985de1d 100644 --- a/internal/logstore/emitters/execution/database.go +++ b/internal/logstore/emitters/execution/database.go @@ -3,142 +3,84 @@ package execution import ( "context" "database/sql" - "fmt" + "errors" + "math" "time" - "github.com/Masterminds/squirrel" - "github.com/zitadel/logging" - - "github.com/zitadel/zitadel/internal/api/call" + "github.com/zitadel/zitadel/internal/api/authz" + "github.com/zitadel/zitadel/internal/command" "github.com/zitadel/zitadel/internal/database" - caos_errors "github.com/zitadel/zitadel/internal/errors" "github.com/zitadel/zitadel/internal/logstore" + "github.com/zitadel/zitadel/internal/logstore/record" + "github.com/zitadel/zitadel/internal/query" + "github.com/zitadel/zitadel/internal/query/projection" "github.com/zitadel/zitadel/internal/repository/quota" ) -const ( - executionLogsTable = "logstore.execution" - executionTimestampCol = "log_date" - executionTookCol = "took" - executionMessageCol = "message" - executionLogLevelCol = "loglevel" - executionInstanceIdCol = "instance_id" - executionActionIdCol = "action_id" - executionMetadataCol = "metadata" -) - -var _ logstore.UsageQuerier = (*databaseLogStorage)(nil) -var _ logstore.LogCleanupper = (*databaseLogStorage)(nil) +var _ logstore.UsageStorer[*record.ExecutionLog] = (*databaseLogStorage)(nil) type databaseLogStorage struct { dbClient *database.DB + commands *command.Commands + queries *query.Queries } -func NewDatabaseLogStorage(dbClient *database.DB) *databaseLogStorage { - return &databaseLogStorage{dbClient: dbClient} +func NewDatabaseLogStorage(dbClient *database.DB, commands *command.Commands, queries *query.Queries) *databaseLogStorage { + return &databaseLogStorage{dbClient: dbClient, commands: commands, queries: queries} } func (l *databaseLogStorage) QuotaUnit() quota.Unit { return quota.ActionsAllRunsSeconds } -func (l *databaseLogStorage) Emit(ctx context.Context, bulk []logstore.LogRecord) error { +func (l *databaseLogStorage) Emit(ctx context.Context, bulk []*record.ExecutionLog) error { if len(bulk) == 0 { return nil } - builder := squirrel.Insert(executionLogsTable). - Columns( - executionTimestampCol, - executionTookCol, - executionMessageCol, - executionLogLevelCol, - executionInstanceIdCol, - executionActionIdCol, - executionMetadataCol, - ). - PlaceholderFormat(squirrel.Dollar) + return l.incrementUsage(ctx, bulk) +} - for idx := range bulk { - item := bulk[idx].(*Record) - - var took interface{} - if item.Took > 0 { - took = item.Took +func (l *databaseLogStorage) incrementUsage(ctx context.Context, bulk []*record.ExecutionLog) (err error) { + byInstance := make(map[string][]*record.ExecutionLog) + for _, r := range bulk { + if r.InstanceID != "" { + byInstance[r.InstanceID] = append(byInstance[r.InstanceID], r) + } + } + for instanceID, instanceBulk := range byInstance { + q, getQuotaErr := l.queries.GetQuota(ctx, instanceID, quota.ActionsAllRunsSeconds) + if errors.Is(getQuotaErr, sql.ErrNoRows) { + continue + } + err = errors.Join(err, getQuotaErr) + if getQuotaErr != nil { + continue + } + sum, incrementErr := l.incrementUsageFromExecutionLogs(ctx, instanceID, q.CurrentPeriodStart, instanceBulk) + err = errors.Join(err, incrementErr) + if incrementErr != nil { + continue } - builder = builder.Values( - item.LogDate, - took, - item.Message, - item.LogLevel, - item.InstanceID, - item.ActionID, - item.Metadata, - ) + notifications, getNotificationErr := l.queries.GetDueQuotaNotifications(ctx, instanceID, quota.ActionsAllRunsSeconds, q, q.CurrentPeriodStart, sum) + err = errors.Join(err, getNotificationErr) + if getNotificationErr != nil || len(notifications) == 0 { + continue + } + ctx = authz.WithInstanceID(ctx, instanceID) + reportErr := l.commands.ReportQuotaUsage(ctx, notifications) + err = errors.Join(err, reportErr) + if reportErr != nil { + continue + } } - - stmt, args, err := builder.ToSql() - if err != nil { - return caos_errors.ThrowInternal(err, "EXEC-KOS7I", "Errors.Internal") - } - - result, err := l.dbClient.ExecContext(ctx, stmt, args...) - if err != nil { - return caos_errors.ThrowInternal(err, "EXEC-0j6i5", "Errors.Access.StorageFailed") - } - - rows, err := result.RowsAffected() - if err != nil { - return caos_errors.ThrowInternal(err, "EXEC-MGchJ", "Errors.Internal") - } - - logging.WithFields("rows", rows).Debug("successfully stored execution logs") - return nil -} - -func (l *databaseLogStorage) QueryUsage(ctx context.Context, instanceId string, start time.Time) (uint64, error) { - stmt, args, err := squirrel.Select( - fmt.Sprintf("COALESCE(SUM(%s)::INT,0)", executionTookCol), - ). - From(executionLogsTable + l.dbClient.Timetravel(call.Took(ctx))). - Where(squirrel.And{ - squirrel.Eq{executionInstanceIdCol: instanceId}, - squirrel.GtOrEq{executionTimestampCol: start}, - squirrel.NotEq{executionTookCol: nil}, - }). - PlaceholderFormat(squirrel.Dollar). - ToSql() - - if err != nil { - return 0, caos_errors.ThrowInternal(err, "EXEC-DXtzg", "Errors.Internal") - } - - var durationSeconds uint64 - err = l.dbClient. - QueryRowContext(ctx, - func(row *sql.Row) error { - return row.Scan(&durationSeconds) - }, - stmt, args..., - ) - if err != nil { - return 0, caos_errors.ThrowInternal(err, "EXEC-Ad8nP", "Errors.Logstore.Execution.ScanFailed") - } - return durationSeconds, nil -} - -func (l *databaseLogStorage) Cleanup(ctx context.Context, keep time.Duration) error { - stmt, args, err := squirrel.Delete(executionLogsTable). - Where(squirrel.LtOrEq{executionTimestampCol: time.Now().Add(-keep)}). - PlaceholderFormat(squirrel.Dollar). - ToSql() - - if err != nil { - return caos_errors.ThrowInternal(err, "EXEC-Bja8V", "Errors.Internal") - } - - execCtx, cancel := context.WithTimeout(ctx, 10*time.Second) - defer cancel() - _, err = l.dbClient.ExecContext(execCtx, stmt, args...) return err } + +func (l *databaseLogStorage) incrementUsageFromExecutionLogs(ctx context.Context, instanceID string, periodStart time.Time, records []*record.ExecutionLog) (sum uint64, err error) { + var total time.Duration + for _, r := range records { + total += r.Took + } + return projection.QuotaProjection.IncrementUsage(ctx, quota.ActionsAllRunsSeconds, instanceID, periodStart, uint64(math.Floor(total.Seconds()))) +} diff --git a/internal/logstore/emitters/mock/inmem.go b/internal/logstore/emitters/mock/inmem.go deleted file mode 100644 index 1d30432624..0000000000 --- a/internal/logstore/emitters/mock/inmem.go +++ /dev/null @@ -1,89 +0,0 @@ -package mock - -import ( - "context" - "sync" - "time" - - "github.com/benbjohnson/clock" - - "github.com/zitadel/zitadel/internal/logstore" - "github.com/zitadel/zitadel/internal/repository/quota" -) - -var _ logstore.UsageQuerier = (*InmemLogStorage)(nil) -var _ logstore.LogCleanupper = (*InmemLogStorage)(nil) - -type InmemLogStorage struct { - mux sync.Mutex - clock clock.Clock - emitted []*record - bulks []int -} - -func NewInMemoryStorage(clock clock.Clock) *InmemLogStorage { - return &InmemLogStorage{ - clock: clock, - emitted: make([]*record, 0), - bulks: make([]int, 0), - } -} - -func (l *InmemLogStorage) QuotaUnit() quota.Unit { - return quota.Unimplemented -} - -func (l *InmemLogStorage) Emit(_ context.Context, bulk []logstore.LogRecord) error { - if len(bulk) == 0 { - return nil - } - l.mux.Lock() - defer l.mux.Unlock() - for idx := range bulk { - l.emitted = append(l.emitted, bulk[idx].(*record)) - } - l.bulks = append(l.bulks, len(bulk)) - return nil -} - -func (l *InmemLogStorage) QueryUsage(_ context.Context, _ string, start time.Time) (uint64, error) { - l.mux.Lock() - defer l.mux.Unlock() - - var count uint64 - for _, r := range l.emitted { - if r.ts.After(start) { - count++ - } - } - return count, nil -} - -func (l *InmemLogStorage) Cleanup(_ context.Context, keep time.Duration) error { - l.mux.Lock() - defer l.mux.Unlock() - - clean := make([]*record, 0) - from := l.clock.Now().Add(-(keep + 1)) - for _, r := range l.emitted { - if r.ts.After(from) { - clean = append(clean, r) - } - } - l.emitted = clean - return nil -} - -func (l *InmemLogStorage) Bulks() []int { - l.mux.Lock() - defer l.mux.Unlock() - - return l.bulks -} - -func (l *InmemLogStorage) Len() int { - l.mux.Lock() - defer l.mux.Unlock() - - return len(l.emitted) -} diff --git a/internal/logstore/emitters/stdout/stdout.go b/internal/logstore/emitters/stdout/stdout.go index 7a2cd53e69..818cc4ca12 100644 --- a/internal/logstore/emitters/stdout/stdout.go +++ b/internal/logstore/emitters/stdout/stdout.go @@ -9,8 +9,8 @@ import ( "github.com/zitadel/zitadel/internal/logstore" ) -func NewStdoutEmitter() logstore.LogEmitter { - return logstore.LogEmitterFunc(func(ctx context.Context, bulk []logstore.LogRecord) error { +func NewStdoutEmitter[T logstore.LogRecord[T]]() logstore.LogEmitter[T] { + return logstore.LogEmitterFunc[T](func(ctx context.Context, bulk []T) error { for idx := range bulk { bytes, err := json.Marshal(bulk[idx]) if err != nil { diff --git a/internal/logstore/helpers_test.go b/internal/logstore/helpers_test.go index 90348db779..6037f7058d 100644 --- a/internal/logstore/helpers_test.go +++ b/internal/logstore/helpers_test.go @@ -4,16 +4,14 @@ import ( "time" "github.com/zitadel/zitadel/internal/logstore" - "github.com/zitadel/zitadel/internal/repository/quota" + "github.com/zitadel/zitadel/internal/query" ) type emitterOption func(config *logstore.EmitterConfig) func emitterConfig(options ...emitterOption) *logstore.EmitterConfig { cfg := &logstore.EmitterConfig{ - Enabled: true, - Keep: time.Hour, - CleanupInterval: time.Hour, + Enabled: true, Debounce: &logstore.DebouncerConfig{ MinFrequency: 0, MaxBulkSize: 0, @@ -37,17 +35,10 @@ func withDisabled() emitterOption { } } -func withCleanupping(keep, interval time.Duration) emitterOption { - return func(c *logstore.EmitterConfig) { - c.Keep = keep - c.CleanupInterval = interval - } -} +type quotaOption func(config *query.Quota) -type quotaOption func(config *quota.AddedEvent) - -func quotaConfig(quotaOptions ...quotaOption) quota.AddedEvent { - q := "a.AddedEvent{ +func quotaConfig(quotaOptions ...quotaOption) *query.Quota { + q := &query.Quota{ Amount: 90, Limit: false, ResetInterval: 90 * time.Second, @@ -56,18 +47,18 @@ func quotaConfig(quotaOptions ...quotaOption) quota.AddedEvent { for _, opt := range quotaOptions { opt(q) } - return *q + return q } func withAmountAndInterval(n uint64) quotaOption { - return func(c *quota.AddedEvent) { + return func(c *query.Quota) { c.Amount = n c.ResetInterval = time.Duration(n) * time.Second } } func withLimiting() quotaOption { - return func(c *quota.AddedEvent) { + return func(c *query.Quota) { c.Limit = true } } diff --git a/internal/logstore/mock/inmem.go b/internal/logstore/mock/inmem.go new file mode 100644 index 0000000000..15e8d12536 --- /dev/null +++ b/internal/logstore/mock/inmem.go @@ -0,0 +1,120 @@ +package mock + +import ( + "context" + "sync" + "time" + + "github.com/benbjohnson/clock" + + "github.com/zitadel/zitadel/internal/logstore" + "github.com/zitadel/zitadel/internal/query" + "github.com/zitadel/zitadel/internal/repository/quota" +) + +var _ logstore.UsageStorer[*Record] = (*InmemLogStorage)(nil) +var _ logstore.LogCleanupper[*Record] = (*InmemLogStorage)(nil) +var _ logstore.Queries = (*InmemLogStorage)(nil) + +type InmemLogStorage struct { + mux sync.Mutex + clock clock.Clock + emitted []*Record + bulks []int + quota *query.Quota +} + +func NewInMemoryStorage(clock clock.Clock, quota *query.Quota) *InmemLogStorage { + return &InmemLogStorage{ + clock: clock, + emitted: make([]*Record, 0), + bulks: make([]int, 0), + quota: quota, + } +} + +func (l *InmemLogStorage) QuotaUnit() quota.Unit { + return quota.Unimplemented +} + +func (l *InmemLogStorage) Emit(_ context.Context, bulk []*Record) error { + if len(bulk) == 0 { + return nil + } + l.mux.Lock() + defer l.mux.Unlock() + l.emitted = append(l.emitted, bulk...) + l.bulks = append(l.bulks, len(bulk)) + return nil +} + +func (l *InmemLogStorage) QueryUsage(_ context.Context, _ string, start time.Time) (uint64, error) { + l.mux.Lock() + defer l.mux.Unlock() + + var count uint64 + for _, r := range l.emitted { + if r.ts.After(start) { + count++ + } + } + return count, nil +} + +func (l *InmemLogStorage) Cleanup(_ context.Context, keep time.Duration) error { + l.mux.Lock() + defer l.mux.Unlock() + + clean := make([]*Record, 0) + from := l.clock.Now().Add(-(keep + 1)) + for _, r := range l.emitted { + if r.ts.After(from) { + clean = append(clean, r) + } + } + l.emitted = clean + return nil +} + +func (l *InmemLogStorage) Bulks() []int { + l.mux.Lock() + defer l.mux.Unlock() + + return l.bulks +} + +func (l *InmemLogStorage) Len() int { + l.mux.Lock() + defer l.mux.Unlock() + + return len(l.emitted) +} + +func (l *InmemLogStorage) GetQuota(ctx context.Context, instanceID string, unit quota.Unit) (qu *query.Quota, err error) { + return l.quota, nil +} + +func (l *InmemLogStorage) GetQuotaUsage(ctx context.Context, instanceID string, unit quota.Unit, periodStart time.Time) (usage uint64, err error) { + return uint64(l.Len()), nil +} + +func (l *InmemLogStorage) GetRemainingQuotaUsage(ctx context.Context, instanceID string, unit quota.Unit) (remaining *uint64, err error) { + if !l.quota.Limit { + return nil, nil + } + var r uint64 + used := uint64(l.Len()) + if used > l.quota.Amount { + return &r, nil + } + r = l.quota.Amount - used + return &r, nil +} + +func (l *InmemLogStorage) GetDueQuotaNotifications(ctx context.Context, instanceID string, unit quota.Unit, qu *query.Quota, periodStart time.Time, usedAbs uint64) (dueNotifications []*quota.NotificationDueEvent, err error) { + return nil, nil +} + +func (l *InmemLogStorage) ReportQuotaUsage(ctx context.Context, dueNotifications []*quota.NotificationDueEvent) error { + return nil +} diff --git a/internal/logstore/emitters/mock/record.go b/internal/logstore/mock/record.go similarity index 50% rename from internal/logstore/emitters/mock/record.go rename to internal/logstore/mock/record.go index be9fc6a1f2..2c7a919675 100644 --- a/internal/logstore/emitters/mock/record.go +++ b/internal/logstore/mock/record.go @@ -8,18 +8,18 @@ import ( "github.com/zitadel/zitadel/internal/logstore" ) -var _ logstore.LogRecord = (*record)(nil) +var _ logstore.LogRecord[*Record] = (*Record)(nil) -func NewRecord(clock clock.Clock) *record { - return &record{ts: clock.Now()} +func NewRecord(clock clock.Clock) *Record { + return &Record{ts: clock.Now()} } -type record struct { +type Record struct { ts time.Time redacted bool } -func (r record) Normalize() logstore.LogRecord { +func (r Record) Normalize() *Record { r.redacted = true return &r } diff --git a/internal/logstore/quotaqueriers/mock/noop.go b/internal/logstore/quotaqueriers/mock/noop.go deleted file mode 100644 index 4f4335077b..0000000000 --- a/internal/logstore/quotaqueriers/mock/noop.go +++ /dev/null @@ -1,28 +0,0 @@ -package mock - -import ( - "context" - "time" - - "github.com/zitadel/zitadel/internal/logstore" - "github.com/zitadel/zitadel/internal/repository/quota" -) - -var _ logstore.QuotaQuerier = (*inmemReporter)(nil) - -type inmemReporter struct { - config *quota.AddedEvent - startPeriod time.Time -} - -func NewNoopQuerier(quota *quota.AddedEvent, startPeriod time.Time) *inmemReporter { - return &inmemReporter{config: quota, startPeriod: startPeriod} -} - -func (i *inmemReporter) GetCurrentQuotaPeriod(context.Context, string, quota.Unit) (*quota.AddedEvent, time.Time, error) { - return i.config, i.startPeriod, nil -} - -func (*inmemReporter) GetDueQuotaNotifications(context.Context, *quota.AddedEvent, time.Time, uint64) ([]*quota.NotificationDueEvent, error) { - return nil, nil -} diff --git a/internal/logstore/record/access.go b/internal/logstore/record/access.go new file mode 100644 index 0000000000..8f9b17bb89 --- /dev/null +++ b/internal/logstore/record/access.go @@ -0,0 +1,135 @@ +package record + +import ( + "net/http" + "strings" + "time" + + "google.golang.org/grpc/codes" + + zitadel_http "github.com/zitadel/zitadel/internal/api/http" +) + +type AccessLog struct { + LogDate time.Time `json:"logDate"` + Protocol AccessProtocol `json:"protocol"` + RequestURL string `json:"requestUrl"` + ResponseStatus uint32 `json:"responseStatus"` + // RequestHeaders and ResponseHeaders are plain maps so varying implementations + // between HTTP and gRPC don't interfere with each other + RequestHeaders map[string][]string `json:"requestHeaders"` + ResponseHeaders map[string][]string `json:"responseHeaders"` + InstanceID string `json:"instanceId"` + ProjectID string `json:"projectId"` + RequestedDomain string `json:"requestedDomain"` + RequestedHost string `json:"requestedHost"` + // NotCountable can be used by the logging service to explicitly stating, + // that the request must not increase the amount of countable (authenticated) requests + NotCountable bool `json:"-"` + normalized bool `json:"-"` +} + +type AccessProtocol uint8 + +const ( + GRPC AccessProtocol = iota + HTTP + + redacted = "[REDACTED]" +) + +var ( + unaccountableEndpoints = []string{ + "/zitadel.system.v1.SystemService/", + "/zitadel.admin.v1.AdminService/Healthz", + "/zitadel.management.v1.ManagementService/Healthz", + "/zitadel.management.v1.ManagementService/GetOIDCInformation", + "/zitadel.auth.v1.AuthService/Healthz", + } +) + +func (a AccessLog) IsAuthenticated() bool { + if a.NotCountable { + return false + } + if !a.normalized { + panic("access log not normalized, Normalize() must be called before IsAuthenticated()") + } + _, hasHTTPAuthHeader := a.RequestHeaders[strings.ToLower(zitadel_http.Authorization)] + // ignore requests, which were unauthorized or do not require an authorization (even if one was sent) + // also ignore if the limit was already reached or if the server returned an internal error + // not that endpoints paths are only checked with the gRPC representation as HTTP (gateway) will not log them + return hasHTTPAuthHeader && + (a.Protocol == HTTP && + a.ResponseStatus != http.StatusInternalServerError && + a.ResponseStatus != http.StatusTooManyRequests && + a.ResponseStatus != http.StatusUnauthorized) || + (a.Protocol == GRPC && + a.ResponseStatus != uint32(codes.Internal) && + a.ResponseStatus != uint32(codes.ResourceExhausted) && + a.ResponseStatus != uint32(codes.Unauthenticated) && + !a.isUnaccountableEndpoint()) +} + +func (a AccessLog) isUnaccountableEndpoint() bool { + for _, endpoint := range unaccountableEndpoints { + if strings.HasPrefix(a.RequestURL, endpoint) { + return true + } + } + return false +} + +func (a AccessLog) Normalize() *AccessLog { + a.RequestedDomain = cutString(a.RequestedDomain, 200) + a.RequestURL = cutString(a.RequestURL, 200) + a.RequestHeaders = normalizeHeaders(a.RequestHeaders, strings.ToLower(zitadel_http.Authorization), "grpcgateway-authorization", "cookie", "grpcgateway-cookie") + a.ResponseHeaders = normalizeHeaders(a.ResponseHeaders, "set-cookie") + a.normalized = true + return &a +} + +// normalizeHeaders lowers all header keys and redacts secrets +func normalizeHeaders(header map[string][]string, redactKeysLower ...string) map[string][]string { + return pruneKeys(redactKeys(lowerKeys(header), redactKeysLower...)) +} + +func lowerKeys(header map[string][]string) map[string][]string { + lower := make(map[string][]string, len(header)) + for k, v := range header { + lower[strings.ToLower(k)] = v + } + return lower +} + +func redactKeys(header map[string][]string, redactKeysLower ...string) map[string][]string { + redactedKeys := make(map[string][]string, len(header)) + for k, v := range header { + redactedKeys[k] = v + } + for _, redactKey := range redactKeysLower { + if _, ok := redactedKeys[redactKey]; ok { + redactedKeys[redactKey] = []string{redacted} + } + } + return redactedKeys +} + +const maxValuesPerKey = 10 + +func pruneKeys(header map[string][]string) map[string][]string { + prunedKeys := make(map[string][]string, len(header)) + for key, value := range header { + valueItems := make([]string, 0, maxValuesPerKey) + for i, valueItem := range value { + // Max 10 header values per key + if i > maxValuesPerKey { + break + } + // Max 200 value length + valueItems = append(valueItems, cutString(valueItem, 200)) + } + prunedKeys[key] = valueItems + } + return prunedKeys +} diff --git a/internal/logstore/emitters/access/record_test.go b/internal/logstore/record/access_test.go similarity index 85% rename from internal/logstore/emitters/access/record_test.go rename to internal/logstore/record/access_test.go index 40cc835beb..23ffff5155 100644 --- a/internal/logstore/emitters/access/record_test.go +++ b/internal/logstore/record/access_test.go @@ -1,20 +1,18 @@ -package access_test +package record import ( "reflect" "testing" - - "github.com/zitadel/zitadel/internal/logstore/emitters/access" ) func TestRecord_Normalize(t *testing.T) { tests := []struct { name string - record access.Record - want *access.Record + record AccessLog + want *AccessLog }{{ name: "headers with certain keys should be redacted", - record: access.Record{ + record: AccessLog{ RequestHeaders: map[string][]string{ "authorization": {"AValue"}, "grpcgateway-authorization": {"AValue"}, @@ -24,7 +22,7 @@ func TestRecord_Normalize(t *testing.T) { "set-cookie": {"AValue"}, }, }, - want: &access.Record{ + want: &AccessLog{ RequestHeaders: map[string][]string{ "authorization": {"[REDACTED]"}, "grpcgateway-authorization": {"[REDACTED]"}, @@ -36,22 +34,22 @@ func TestRecord_Normalize(t *testing.T) { }, }, { name: "header keys should be lower cased", - record: access.Record{ + record: AccessLog{ RequestHeaders: map[string][]string{"AKey": {"AValue"}}, ResponseHeaders: map[string][]string{"AKey": {"AValue"}}}, - want: &access.Record{ + want: &AccessLog{ RequestHeaders: map[string][]string{"akey": {"AValue"}}, ResponseHeaders: map[string][]string{"akey": {"AValue"}}}, }, { name: "an already prune record should stay unchanged", - record: access.Record{ + record: AccessLog{ RequestURL: "https://my.zitadel.cloud/", RequestHeaders: map[string][]string{ "authorization": {"[REDACTED]"}, }, ResponseHeaders: map[string][]string{}, }, - want: &access.Record{ + want: &AccessLog{ RequestURL: "https://my.zitadel.cloud/", RequestHeaders: map[string][]string{ "authorization": {"[REDACTED]"}, @@ -60,17 +58,18 @@ func TestRecord_Normalize(t *testing.T) { }, }, { name: "empty record should stay empty", - record: access.Record{ + record: AccessLog{ RequestHeaders: map[string][]string{}, ResponseHeaders: map[string][]string{}, }, - want: &access.Record{ + want: &AccessLog{ RequestHeaders: map[string][]string{}, ResponseHeaders: map[string][]string{}, }, }} for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { + tt.want.normalized = true if got := tt.record.Normalize(); !reflect.DeepEqual(got, tt.want) { t.Errorf("Normalize() = %v, want %v", got, tt.want) } diff --git a/internal/logstore/emitters/execution/record.go b/internal/logstore/record/execution.go similarity index 63% rename from internal/logstore/emitters/execution/record.go rename to internal/logstore/record/execution.go index c45790622c..ee53cbeab5 100644 --- a/internal/logstore/emitters/execution/record.go +++ b/internal/logstore/record/execution.go @@ -1,16 +1,12 @@ -package execution +package record import ( "time" "github.com/sirupsen/logrus" - - "github.com/zitadel/zitadel/internal/logstore" ) -var _ logstore.LogRecord = (*Record)(nil) - -type Record struct { +type ExecutionLog struct { LogDate time.Time `json:"logDate"` Took time.Duration `json:"took"` Message string `json:"message"` @@ -20,14 +16,7 @@ type Record struct { Metadata map[string]interface{} `json:"metadata,omitempty"` } -func (e Record) Normalize() logstore.LogRecord { +func (e ExecutionLog) Normalize() *ExecutionLog { e.Message = cutString(e.Message, 2000) return &e } - -func cutString(str string, pos int) string { - if len(str) <= pos { - return str - } - return str[:pos] -} diff --git a/internal/logstore/record/prune.go b/internal/logstore/record/prune.go new file mode 100644 index 0000000000..47e08343bb --- /dev/null +++ b/internal/logstore/record/prune.go @@ -0,0 +1,8 @@ +package record + +func cutString(str string, pos int) string { + if len(str) <= pos { + return str + } + return str[:pos-1] +} diff --git a/internal/logstore/service.go b/internal/logstore/service.go index 542b50dd54..03620336e4 100644 --- a/internal/logstore/service.go +++ b/internal/logstore/service.go @@ -2,124 +2,70 @@ package logstore import ( "context" - "math" - "time" "github.com/zitadel/logging" - "github.com/zitadel/zitadel/internal/api/authz" "github.com/zitadel/zitadel/internal/repository/quota" ) -const handleThresholdTimeout = time.Minute - -type QuotaQuerier interface { - GetCurrentQuotaPeriod(ctx context.Context, instanceID string, unit quota.Unit) (config *quota.AddedEvent, periodStart time.Time, err error) - GetDueQuotaNotifications(ctx context.Context, config *quota.AddedEvent, periodStart time.Time, used uint64) ([]*quota.NotificationDueEvent, error) -} - -type UsageQuerier interface { - LogEmitter +type UsageStorer[T LogRecord[T]] interface { + LogEmitter[T] QuotaUnit() quota.Unit - QueryUsage(ctx context.Context, instanceId string, start time.Time) (uint64, error) } -type UsageReporter interface { - Report(ctx context.Context, notifications []*quota.NotificationDueEvent) (err error) -} - -type UsageReporterFunc func(context.Context, []*quota.NotificationDueEvent) (err error) - -func (u UsageReporterFunc) Report(ctx context.Context, notifications []*quota.NotificationDueEvent) (err error) { - return u(ctx, notifications) -} - -type Service struct { - usageQuerier UsageQuerier - quotaQuerier QuotaQuerier - usageReporter UsageReporter - enabledSinks []*emitter +type Service[T LogRecord[T]] struct { + queries Queries + usageStorer UsageStorer[T] + enabledSinks []*emitter[T] sinkEnabled bool reportingEnabled bool } -func New(quotaQuerier QuotaQuerier, usageReporter UsageReporter, usageQuerierSink *emitter, additionalSink ...*emitter) *Service { - var usageQuerier UsageQuerier +type Queries interface { + GetRemainingQuotaUsage(ctx context.Context, instanceID string, unit quota.Unit) (remaining *uint64, err error) +} + +func New[T LogRecord[T]](queries Queries, usageQuerierSink *emitter[T], additionalSink ...*emitter[T]) *Service[T] { + var usageStorer UsageStorer[T] if usageQuerierSink != nil { - usageQuerier = usageQuerierSink.emitter.(UsageQuerier) + usageStorer = usageQuerierSink.emitter.(UsageStorer[T]) } - - svc := &Service{ + svc := &Service[T]{ + queries: queries, reportingEnabled: usageQuerierSink != nil && usageQuerierSink.enabled, - usageQuerier: usageQuerier, - quotaQuerier: quotaQuerier, - usageReporter: usageReporter, + usageStorer: usageStorer, } - - for _, s := range append([]*emitter{usageQuerierSink}, additionalSink...) { + for _, s := range append([]*emitter[T]{usageQuerierSink}, additionalSink...) { if s != nil && s.enabled { svc.enabledSinks = append(svc.enabledSinks, s) } } - svc.sinkEnabled = len(svc.enabledSinks) > 0 - return svc } -func (s *Service) Enabled() bool { +func (s *Service[T]) Enabled() bool { return s.sinkEnabled } -func (s *Service) Handle(ctx context.Context, record LogRecord) { +func (s *Service[T]) Handle(ctx context.Context, record T) { for _, sink := range s.enabledSinks { logging.OnError(sink.Emit(ctx, record.Normalize())).WithField("record", record).Warn("failed to emit log record") } } -func (s *Service) Limit(ctx context.Context, instanceID string) *uint64 { +func (s *Service[T]) Limit(ctx context.Context, instanceID string) *uint64 { var err error defer func() { - logging.OnError(err).Warn("failed to check is usage should be limited") + logging.OnError(err).Warn("failed to check if usage should be limited") }() - if !s.reportingEnabled || instanceID == "" { return nil } - - quota, periodStart, err := s.quotaQuerier.GetCurrentQuotaPeriod(ctx, instanceID, s.usageQuerier.QuotaUnit()) - if err != nil || quota == nil { - return nil - } - - usage, err := s.usageQuerier.QueryUsage(ctx, instanceID, periodStart) + remaining, err := s.queries.GetRemainingQuotaUsage(ctx, instanceID, s.usageStorer.QuotaUnit()) if err != nil { + // TODO: shouldn't we just limit then or return the error and decide there? return nil } - - go s.handleThresholds(ctx, quota, periodStart, usage) - - var remaining *uint64 - if quota.Limit { - r := uint64(math.Max(0, float64(quota.Amount)-float64(usage))) - remaining = &r - } return remaining } - -func (s *Service) handleThresholds(ctx context.Context, quota *quota.AddedEvent, periodStart time.Time, usage uint64) { - var err error - defer func() { - logging.OnError(err).Warn("handling quota thresholds failed") - }() - - detatchedCtx, cancel := context.WithTimeout(authz.Detach(ctx), handleThresholdTimeout) - defer cancel() - - notifications, err := s.quotaQuerier.GetDueQuotaNotifications(detatchedCtx, quota, periodStart, usage) - if err != nil || len(notifications) == 0 { - return - } - - err = s.usageReporter.Report(detatchedCtx, notifications) -} diff --git a/internal/logstore/service_test.go b/internal/logstore/service_test.go index ef6fcd128e..d99e50161d 100644 --- a/internal/logstore/service_test.go +++ b/internal/logstore/service_test.go @@ -14,9 +14,8 @@ import ( "github.com/benbjohnson/clock" "github.com/zitadel/zitadel/internal/logstore" - emittermock "github.com/zitadel/zitadel/internal/logstore/emitters/mock" - quotaqueriermock "github.com/zitadel/zitadel/internal/logstore/quotaqueriers/mock" - "github.com/zitadel/zitadel/internal/repository/quota" + emittermock "github.com/zitadel/zitadel/internal/logstore/mock" + "github.com/zitadel/zitadel/internal/query" ) const ( @@ -27,7 +26,7 @@ const ( type args struct { mainSink *logstore.EmitterConfig secondarySink *logstore.EmitterConfig - config quota.AddedEvent + config *query.Quota } type want struct { @@ -137,28 +136,6 @@ func TestService(t *testing.T) { len: 0, }, }, - }, { - name: "cleanupping works", - args: args{ - mainSink: emitterConfig(withCleanupping(17*time.Second, 28*time.Second)), - secondarySink: emitterConfig(withDebouncerConfig(&logstore.DebouncerConfig{ - MinFrequency: 0, - MaxBulkSize: 15, - }), withCleanupping(5*time.Second, 47*time.Second)), - config: quotaConfig(), - }, - want: want{ - enabled: true, - remaining: nil, - mainSink: wantSink{ - bulks: repeat(1, 60), - len: 21, - }, - secondarySink: wantSink{ - bulks: repeat(15, 4), - len: 18, - }, - }, }, { name: "when quota has a limit of 90, 30 are remaining", args: args{ @@ -232,27 +209,24 @@ func runTest(t *testing.T, name string, args args, want want) bool { }) } -func given(t *testing.T, args args, want want) (context.Context, *clock.Mock, *emittermock.InmemLogStorage, *emittermock.InmemLogStorage, *logstore.Service) { +func given(t *testing.T, args args, want want) (context.Context, *clock.Mock, *emittermock.InmemLogStorage, *emittermock.InmemLogStorage, *logstore.Service[*emittermock.Record]) { ctx := context.Background() clock := clock.NewMock() - periodStart := time.Time{} clock.Set(args.config.From) - mainStorage := emittermock.NewInMemoryStorage(clock) - mainEmitter, err := logstore.NewEmitter(ctx, clock, args.mainSink, mainStorage) + mainStorage := emittermock.NewInMemoryStorage(clock, args.config) + mainEmitter, err := logstore.NewEmitter[*emittermock.Record](ctx, clock, args.mainSink, mainStorage) if err != nil { t.Errorf("expected no error but got %v", err) } - secondaryStorage := emittermock.NewInMemoryStorage(clock) - secondaryEmitter, err := logstore.NewEmitter(ctx, clock, args.secondarySink, secondaryStorage) + secondaryStorage := emittermock.NewInMemoryStorage(clock, args.config) + secondaryEmitter, err := logstore.NewEmitter[*emittermock.Record](ctx, clock, args.secondarySink, secondaryStorage) if err != nil { t.Errorf("expected no error but got %v", err) } - - svc := logstore.New( - quotaqueriermock.NewNoopQuerier(&args.config, periodStart), - logstore.UsageReporterFunc(func(context.Context, []*quota.NotificationDueEvent) error { return nil }), + svc := logstore.New[*emittermock.Record]( + mainStorage, mainEmitter, secondaryEmitter) @@ -262,7 +236,7 @@ func given(t *testing.T, args args, want want) (context.Context, *clock.Mock, *e return ctx, clock, mainStorage, secondaryStorage, svc } -func when(svc *logstore.Service, ctx context.Context, clock *clock.Mock) *uint64 { +func when(svc *logstore.Service[*emittermock.Record], ctx context.Context, clock *clock.Mock) *uint64 { var remaining *uint64 for i := 0; i < ticks; i++ { svc.Handle(ctx, emittermock.NewRecord(clock)) diff --git a/cmd/start/start_port.go b/internal/net/start_port.go similarity index 58% rename from cmd/start/start_port.go rename to internal/net/start_port.go index bb60fea250..df52540e5b 100644 --- a/cmd/start/start_port.go +++ b/internal/net/start_port.go @@ -1,11 +1,11 @@ //go:build !integration -package start +package net import ( "net" ) -func listenConfig() *net.ListenConfig { +func ListenConfig() *net.ListenConfig { return &net.ListenConfig{} } diff --git a/cmd/start/start_port_integration.go b/internal/net/start_port_integration.go similarity index 87% rename from cmd/start/start_port_integration.go rename to internal/net/start_port_integration.go index 1c5763d6e9..d2e7e5802c 100644 --- a/cmd/start/start_port_integration.go +++ b/internal/net/start_port_integration.go @@ -1,6 +1,6 @@ //go:build integration -package start +package net import ( "net" @@ -9,7 +9,7 @@ import ( "golang.org/x/sys/unix" ) -func listenConfig() *net.ListenConfig { +func ListenConfig() *net.ListenConfig { return &net.ListenConfig{ Control: reusePort, } diff --git a/internal/notification/handlers/handlers_integration_test.go b/internal/notification/handlers/handlers_integration_test.go index 37f6a344c9..609f01b3f4 100644 --- a/internal/notification/handlers/handlers_integration_test.go +++ b/internal/notification/handlers/handlers_integration_test.go @@ -20,11 +20,13 @@ var ( func TestMain(m *testing.M) { os.Exit(func() int { ctx, _, cancel := integration.Contexts(5 * time.Minute) - CTX = ctx defer cancel() + CTX = ctx + Tester = integration.NewTester(ctx) - SystemCTX = Tester.WithAuthorization(ctx, integration.SystemUser) defer Tester.Done() + + SystemCTX = Tester.WithAuthorization(ctx, integration.SystemUser) return m.Run() }()) } diff --git a/internal/notification/handlers/telemetry_pusher_integration_test.go b/internal/notification/handlers/telemetry_pusher_integration_test.go index e3709db171..f0d46b3613 100644 --- a/internal/notification/handlers/telemetry_pusher_integration_test.go +++ b/internal/notification/handlers/telemetry_pusher_integration_test.go @@ -5,11 +5,6 @@ package handlers_test import ( "bytes" "encoding/json" - "io" - "net" - "net/http" - "net/http/httptest" - "reflect" "testing" "time" @@ -18,49 +13,27 @@ import ( ) func TestServer_TelemetryPushMilestones(t *testing.T) { - bodies := make(chan []byte, 0) - mockServer := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - body, err := io.ReadAll(r.Body) - if err != nil { - t.Error(err) - } - if r.Header.Get("single-value") != "single-value" { - t.Error("single-value header not set") - } - if reflect.DeepEqual(r.Header.Get("multi-value"), "multi-value-1,multi-value-2") { - t.Error("single-value header not set") - } - bodies <- body - w.WriteHeader(http.StatusOK) - })) - listener, err := net.Listen("tcp", "localhost:8081") - if err != nil { - t.Fatal(err) - } - mockServer.Listener = listener - mockServer.Start() - t.Cleanup(mockServer.Close) primaryDomain, instanceID, iamOwnerCtx := Tester.UseIsolatedInstance(CTX, SystemCTX) t.Log("testing against instance with primary domain", primaryDomain) - awaitMilestone(t, bodies, primaryDomain, "InstanceCreated") + awaitMilestone(t, Tester.MilestoneChan, primaryDomain, "InstanceCreated") project, err := Tester.Client.Mgmt.AddProject(iamOwnerCtx, &management.AddProjectRequest{Name: "integration"}) if err != nil { t.Fatal(err) } - awaitMilestone(t, bodies, primaryDomain, "ProjectCreated") + awaitMilestone(t, Tester.MilestoneChan, primaryDomain, "ProjectCreated") if _, err = Tester.Client.Mgmt.AddOIDCApp(iamOwnerCtx, &management.AddOIDCAppRequest{ ProjectId: project.GetId(), Name: "integration", }); err != nil { t.Fatal(err) } - awaitMilestone(t, bodies, primaryDomain, "ApplicationCreated") + awaitMilestone(t, Tester.MilestoneChan, primaryDomain, "ApplicationCreated") // TODO: trigger and await milestone AuthenticationSucceededOnInstance // TODO: trigger and await milestone AuthenticationSucceededOnApplication if _, err = Tester.Client.System.RemoveInstance(SystemCTX, &system.RemoveInstanceRequest{InstanceId: instanceID}); err != nil { t.Fatal(err) } - awaitMilestone(t, bodies, primaryDomain, "InstanceDeleted") + awaitMilestone(t, Tester.MilestoneChan, primaryDomain, "InstanceDeleted") } func awaitMilestone(t *testing.T, bodies chan []byte, primaryDomain, expectMilestoneType string) { diff --git a/internal/query/projection/main_test.go b/internal/query/projection/main_test.go index 1c09cf3ffa..2cf61fe540 100644 --- a/internal/query/projection/main_test.go +++ b/internal/query/projection/main_test.go @@ -13,6 +13,7 @@ import ( key_repo "github.com/zitadel/zitadel/internal/repository/keypair" "github.com/zitadel/zitadel/internal/repository/org" proj_repo "github.com/zitadel/zitadel/internal/repository/project" + quota_repo "github.com/zitadel/zitadel/internal/repository/quota" usr_repo "github.com/zitadel/zitadel/internal/repository/user" "github.com/zitadel/zitadel/internal/repository/usergrant" ) @@ -29,6 +30,7 @@ func eventstoreExpect(t *testing.T, expects ...expect) *eventstore.Eventstore { org.RegisterEventMappers(es) usr_repo.RegisterEventMappers(es) proj_repo.RegisterEventMappers(es) + quota_repo.RegisterEventMappers(es) usergrant.RegisterEventMappers(es) key_repo.RegisterEventMappers(es) action_repo.RegisterEventMappers(es) diff --git a/internal/query/projection/projection.go b/internal/query/projection/projection.go index e8f2bd1563..19a0c1cc0f 100644 --- a/internal/query/projection/projection.go +++ b/internal/query/projection/projection.go @@ -69,6 +69,7 @@ var ( SessionProjection *sessionProjection AuthRequestProjection *authRequestProjection MilestoneProjection *milestoneProjection + QuotaProjection *quotaProjection ) type projection interface { @@ -148,6 +149,7 @@ func Create(ctx context.Context, sqlClient *database.DB, es *eventstore.Eventsto SessionProjection = newSessionProjection(ctx, applyCustomConfig(projectionConfig, config.Customizations["sessions"])) AuthRequestProjection = newAuthRequestProjection(ctx, applyCustomConfig(projectionConfig, config.Customizations["auth_requests"])) MilestoneProjection = newMilestoneProjection(ctx, applyCustomConfig(projectionConfig, config.Customizations["milestones"])) + QuotaProjection = newQuotaProjection(ctx, applyCustomConfig(projectionConfig, config.Customizations["quotas"])) newProjectionsList() return nil } @@ -247,5 +249,6 @@ func newProjectionsList() { SessionProjection, AuthRequestProjection, MilestoneProjection, + QuotaProjection, } } diff --git a/internal/query/projection/quota.go b/internal/query/projection/quota.go new file mode 100644 index 0000000000..fd6e12266f --- /dev/null +++ b/internal/query/projection/quota.go @@ -0,0 +1,285 @@ +package projection + +import ( + "context" + "time" + + "github.com/zitadel/zitadel/internal/database" + "github.com/zitadel/zitadel/internal/errors" + "github.com/zitadel/zitadel/internal/eventstore" + "github.com/zitadel/zitadel/internal/eventstore/handler" + "github.com/zitadel/zitadel/internal/eventstore/handler/crdb" + "github.com/zitadel/zitadel/internal/repository/instance" + "github.com/zitadel/zitadel/internal/repository/quota" +) + +const ( + QuotasProjectionTable = "projections.quotas" + QuotaPeriodsProjectionTable = QuotasProjectionTable + "_" + quotaPeriodsTableSuffix + QuotaNotificationsTable = QuotasProjectionTable + "_" + quotaNotificationsTableSuffix + + QuotaColumnID = "id" + QuotaColumnInstanceID = "instance_id" + QuotaColumnUnit = "unit" + QuotaColumnAmount = "amount" + QuotaColumnFrom = "from_anchor" + QuotaColumnInterval = "interval" + QuotaColumnLimit = "limit_usage" + + quotaPeriodsTableSuffix = "periods" + QuotaPeriodColumnInstanceID = "instance_id" + QuotaPeriodColumnUnit = "unit" + QuotaPeriodColumnStart = "start" + QuotaPeriodColumnUsage = "usage" + + quotaNotificationsTableSuffix = "notifications" + QuotaNotificationColumnInstanceID = "instance_id" + QuotaNotificationColumnUnit = "unit" + QuotaNotificationColumnID = "id" + QuotaNotificationColumnCallURL = "call_url" + QuotaNotificationColumnPercent = "percent" + QuotaNotificationColumnRepeat = "repeat" + QuotaNotificationColumnLatestDuePeriodStart = "latest_due_period_start" + QuotaNotificationColumnNextDueThreshold = "next_due_threshold" +) + +const ( + incrementQuotaStatement = `INSERT INTO projections.quotas_periods` + + ` (instance_id, unit, start, usage)` + + ` VALUES ($1, $2, $3, $4) ON CONFLICT (instance_id, unit, start)` + + ` DO UPDATE SET usage = projections.quotas_periods.usage + excluded.usage RETURNING usage` +) + +type quotaProjection struct { + crdb.StatementHandler + client *database.DB +} + +func newQuotaProjection(ctx context.Context, config crdb.StatementHandlerConfig) *quotaProjection { + p := new(quotaProjection) + config.ProjectionName = QuotasProjectionTable + config.Reducers = p.reducers() + config.InitCheck = crdb.NewMultiTableCheck( + crdb.NewTable( + []*crdb.Column{ + crdb.NewColumn(QuotaColumnID, crdb.ColumnTypeText), + crdb.NewColumn(QuotaColumnInstanceID, crdb.ColumnTypeText), + crdb.NewColumn(QuotaColumnUnit, crdb.ColumnTypeEnum), + crdb.NewColumn(QuotaColumnAmount, crdb.ColumnTypeInt64), + crdb.NewColumn(QuotaColumnFrom, crdb.ColumnTypeTimestamp), + crdb.NewColumn(QuotaColumnInterval, crdb.ColumnTypeInterval), + crdb.NewColumn(QuotaColumnLimit, crdb.ColumnTypeBool), + }, + crdb.NewPrimaryKey(QuotaColumnInstanceID, QuotaColumnUnit), + ), + crdb.NewSuffixedTable( + []*crdb.Column{ + crdb.NewColumn(QuotaPeriodColumnInstanceID, crdb.ColumnTypeText), + crdb.NewColumn(QuotaPeriodColumnUnit, crdb.ColumnTypeEnum), + crdb.NewColumn(QuotaPeriodColumnStart, crdb.ColumnTypeTimestamp), + crdb.NewColumn(QuotaPeriodColumnUsage, crdb.ColumnTypeInt64), + }, + crdb.NewPrimaryKey(QuotaPeriodColumnInstanceID, QuotaPeriodColumnUnit, QuotaPeriodColumnStart), + quotaPeriodsTableSuffix, + ), + crdb.NewSuffixedTable( + []*crdb.Column{ + crdb.NewColumn(QuotaNotificationColumnInstanceID, crdb.ColumnTypeText), + crdb.NewColumn(QuotaNotificationColumnUnit, crdb.ColumnTypeEnum), + crdb.NewColumn(QuotaNotificationColumnID, crdb.ColumnTypeText), + crdb.NewColumn(QuotaNotificationColumnCallURL, crdb.ColumnTypeText), + crdb.NewColumn(QuotaNotificationColumnPercent, crdb.ColumnTypeInt64), + crdb.NewColumn(QuotaNotificationColumnRepeat, crdb.ColumnTypeBool), + crdb.NewColumn(QuotaNotificationColumnLatestDuePeriodStart, crdb.ColumnTypeTimestamp, crdb.Nullable()), + crdb.NewColumn(QuotaNotificationColumnNextDueThreshold, crdb.ColumnTypeInt64, crdb.Nullable()), + }, + crdb.NewPrimaryKey(QuotaNotificationColumnInstanceID, QuotaNotificationColumnUnit, QuotaNotificationColumnID), + quotaNotificationsTableSuffix, + ), + ) + p.StatementHandler = crdb.NewStatementHandler(ctx, config) + p.client = config.Client + return p +} + +func (q *quotaProjection) reducers() []handler.AggregateReducer { + return []handler.AggregateReducer{ + { + Aggregate: instance.AggregateType, + EventRedusers: []handler.EventReducer{ + { + Event: instance.InstanceRemovedEventType, + Reduce: q.reduceInstanceRemoved, + }, + }, + }, + { + Aggregate: quota.AggregateType, + EventRedusers: []handler.EventReducer{ + { + Event: quota.AddedEventType, + Reduce: q.reduceQuotaAdded, + }, + }, + }, + { + Aggregate: quota.AggregateType, + EventRedusers: []handler.EventReducer{ + { + Event: quota.RemovedEventType, + Reduce: q.reduceQuotaRemoved, + }, + }, + }, + { + Aggregate: quota.AggregateType, + EventRedusers: []handler.EventReducer{ + { + Event: quota.NotificationDueEventType, + Reduce: q.reduceQuotaNotificationDue, + }, + }, + }, + { + Aggregate: quota.AggregateType, + EventRedusers: []handler.EventReducer{ + { + Event: quota.NotifiedEventType, + Reduce: q.reduceQuotaNotified, + }, + }, + }, + } +} + +func (q *quotaProjection) reduceQuotaNotified(event eventstore.Event) (*handler.Statement, error) { + return crdb.NewNoOpStatement(event), nil +} + +func (q *quotaProjection) reduceQuotaAdded(event eventstore.Event) (*handler.Statement, error) { + e, err := assertEvent[*quota.AddedEvent](event) + if err != nil { + return nil, err + } + + createStatements := make([]func(e eventstore.Event) crdb.Exec, len(e.Notifications)+1) + createStatements[0] = crdb.AddCreateStatement( + []handler.Column{ + handler.NewCol(QuotaColumnID, e.Aggregate().ID), + handler.NewCol(QuotaColumnInstanceID, e.Aggregate().InstanceID), + handler.NewCol(QuotaColumnUnit, e.Unit), + handler.NewCol(QuotaColumnAmount, e.Amount), + handler.NewCol(QuotaColumnFrom, e.From), + handler.NewCol(QuotaColumnInterval, e.ResetInterval), + handler.NewCol(QuotaColumnLimit, e.Limit), + }) + for i := range e.Notifications { + notification := e.Notifications[i] + createStatements[i+1] = crdb.AddCreateStatement( + []handler.Column{ + handler.NewCol(QuotaNotificationColumnInstanceID, e.Aggregate().InstanceID), + handler.NewCol(QuotaNotificationColumnUnit, e.Unit), + handler.NewCol(QuotaNotificationColumnID, notification.ID), + handler.NewCol(QuotaNotificationColumnCallURL, notification.CallURL), + handler.NewCol(QuotaNotificationColumnPercent, notification.Percent), + handler.NewCol(QuotaNotificationColumnRepeat, notification.Repeat), + }, + crdb.WithTableSuffix(quotaNotificationsTableSuffix), + ) + } + + return crdb.NewMultiStatement(e, createStatements...), nil +} + +func (q *quotaProjection) reduceQuotaNotificationDue(event eventstore.Event) (*handler.Statement, error) { + e, err := assertEvent[*quota.NotificationDueEvent](event) + if err != nil { + return nil, err + } + return crdb.NewUpdateStatement(e, + []handler.Column{ + handler.NewCol(QuotaNotificationColumnLatestDuePeriodStart, e.PeriodStart), + handler.NewCol(QuotaNotificationColumnNextDueThreshold, e.Threshold+100), // next due_threshold is always the reached + 100 => percent (e.g. 90) in the next bucket (e.g. 190) + }, + []handler.Condition{ + handler.NewCond(QuotaNotificationColumnInstanceID, e.Aggregate().InstanceID), + handler.NewCond(QuotaNotificationColumnUnit, e.Unit), + handler.NewCond(QuotaNotificationColumnID, e.ID), + }, + crdb.WithTableSuffix(quotaNotificationsTableSuffix), + ), nil +} + +func (q *quotaProjection) reduceQuotaRemoved(event eventstore.Event) (*handler.Statement, error) { + e, err := assertEvent[*quota.RemovedEvent](event) + if err != nil { + return nil, err + } + return crdb.NewMultiStatement( + e, + crdb.AddDeleteStatement( + []handler.Condition{ + handler.NewCond(QuotaPeriodColumnInstanceID, e.Aggregate().InstanceID), + handler.NewCond(QuotaPeriodColumnUnit, e.Unit), + }, + crdb.WithTableSuffix(quotaPeriodsTableSuffix), + ), + crdb.AddDeleteStatement( + []handler.Condition{ + handler.NewCond(QuotaNotificationColumnInstanceID, e.Aggregate().InstanceID), + handler.NewCond(QuotaNotificationColumnUnit, e.Unit), + }, + crdb.WithTableSuffix(quotaNotificationsTableSuffix), + ), + crdb.AddDeleteStatement( + []handler.Condition{ + handler.NewCond(QuotaColumnInstanceID, e.Aggregate().InstanceID), + handler.NewCond(QuotaColumnUnit, e.Unit), + }, + ), + ), nil +} + +func (q *quotaProjection) reduceInstanceRemoved(event eventstore.Event) (*handler.Statement, error) { + // we only assert the event to make sure it is the correct type + e, err := assertEvent[*instance.InstanceRemovedEvent](event) + if err != nil { + return nil, err + } + return crdb.NewMultiStatement( + e, + crdb.AddDeleteStatement( + []handler.Condition{ + handler.NewCond(QuotaPeriodColumnInstanceID, e.Aggregate().InstanceID), + }, + crdb.WithTableSuffix(quotaPeriodsTableSuffix), + ), + crdb.AddDeleteStatement( + []handler.Condition{ + handler.NewCond(QuotaNotificationColumnInstanceID, e.Aggregate().InstanceID), + }, + crdb.WithTableSuffix(quotaNotificationsTableSuffix), + ), + crdb.AddDeleteStatement( + []handler.Condition{ + handler.NewCond(QuotaColumnInstanceID, e.Aggregate().InstanceID), + }, + ), + ), nil +} + +func (q *quotaProjection) IncrementUsage(ctx context.Context, unit quota.Unit, instanceID string, periodStart time.Time, count uint64) (sum uint64, err error) { + if count == 0 { + return 0, nil + } + + err = q.client.DB.QueryRowContext( + ctx, + incrementQuotaStatement, + instanceID, unit, periodStart, count, + ).Scan(&sum) + if err != nil { + return 0, errors.ThrowInternalf(err, "PROJ-SJL3h", "incrementing usage for unit %d failed for at least one quota period", unit) + } + return sum, err +} diff --git a/internal/query/projection/quota_test.go b/internal/query/projection/quota_test.go new file mode 100644 index 0000000000..dc9184da07 --- /dev/null +++ b/internal/query/projection/quota_test.go @@ -0,0 +1,321 @@ +package projection + +import ( + "context" + "regexp" + "testing" + "time" + + "github.com/DATA-DOG/go-sqlmock" + "github.com/stretchr/testify/assert" + + "github.com/zitadel/zitadel/internal/database" + "github.com/zitadel/zitadel/internal/errors" + "github.com/zitadel/zitadel/internal/eventstore" + "github.com/zitadel/zitadel/internal/eventstore/handler" + "github.com/zitadel/zitadel/internal/eventstore/repository" + "github.com/zitadel/zitadel/internal/repository/instance" + "github.com/zitadel/zitadel/internal/repository/quota" +) + +func TestQuotasProjection_reduces(t *testing.T) { + type args struct { + event func(t *testing.T) eventstore.Event + } + tests := []struct { + name string + args args + reduce func(event eventstore.Event) (*handler.Statement, error) + want wantReduce + }{ + { + name: "reduceQuotaAdded", + args: args{ + event: getEvent(testEvent( + repository.EventType(quota.AddedEventType), + quota.AggregateType, + []byte(`{ + "unit": 1, + "amount": 10, + "limit": true, + "from": "2023-01-01T00:00:00Z", + "interval": 300000000000 + }`), + ), quota.AddedEventMapper), + }, + reduce: ("aProjection{}).reduceQuotaAdded, + want: wantReduce{ + aggregateType: eventstore.AggregateType("quota"), + sequence: 15, + previousSequence: 10, + executer: &testExecuter{ + executions: []execution{ + { + expectedStmt: "INSERT INTO projections.quotas (id, instance_id, unit, amount, from_anchor, interval, limit_usage) VALUES ($1, $2, $3, $4, $5, $6, $7)", + expectedArgs: []interface{}{ + "agg-id", + "instance-id", + quota.RequestsAllAuthenticated, + uint64(10), + time.Date(2023, 1, 1, 0, 0, 0, 0, time.UTC), + time.Minute * 5, + true, + }, + }, + }, + }, + }, + }, + { + name: "reduceQuotaAdded with notification", + args: args{ + event: getEvent(testEvent( + repository.EventType(quota.AddedEventType), + quota.AggregateType, + []byte(`{ + "unit": 1, + "amount": 10, + "limit": true, + "from": "2023-01-01T00:00:00Z", + "interval": 300000000000, + "notifications": [ + { + "id": "id", + "percent": 100, + "repeat": true, + "callURL": "url" + } + ] + }`), + ), quota.AddedEventMapper), + }, + reduce: ("aProjection{}).reduceQuotaAdded, + want: wantReduce{ + aggregateType: eventstore.AggregateType("quota"), + sequence: 15, + previousSequence: 10, + executer: &testExecuter{ + executions: []execution{ + { + expectedStmt: "INSERT INTO projections.quotas (id, instance_id, unit, amount, from_anchor, interval, limit_usage) VALUES ($1, $2, $3, $4, $5, $6, $7)", + expectedArgs: []interface{}{ + "agg-id", + "instance-id", + quota.RequestsAllAuthenticated, + uint64(10), + time.Date(2023, 1, 1, 0, 0, 0, 0, time.UTC), + time.Minute * 5, + true, + }, + }, + { + expectedStmt: "INSERT INTO projections.quotas_notifications (instance_id, unit, id, call_url, percent, repeat) VALUES ($1, $2, $3, $4, $5, $6)", + expectedArgs: []interface{}{ + "instance-id", + quota.RequestsAllAuthenticated, + "id", + "url", + uint16(100), + true, + }, + }, + }, + }, + }, + }, + { + name: "reduceQuotaNotificationDue", + args: args{ + event: getEvent(testEvent( + repository.EventType(quota.NotificationDueEventType), + quota.AggregateType, + []byte(`{ + "id": "id", + "unit": 1, + "callURL": "url", + "periodStart": "2023-01-01T00:00:00Z", + "threshold": 200, + "usage": 100 + }`), + ), quota.NotificationDueEventMapper), + }, + reduce: ("aProjection{}).reduceQuotaNotificationDue, + want: wantReduce{ + aggregateType: eventstore.AggregateType("quota"), + sequence: 15, + previousSequence: 10, + executer: &testExecuter{ + executions: []execution{ + { + expectedStmt: "UPDATE projections.quotas_notifications SET (latest_due_period_start, next_due_threshold) = ($1, $2) WHERE (instance_id = $3) AND (unit = $4) AND (id = $5)", + expectedArgs: []interface{}{ + time.Date(2023, 1, 1, 0, 0, 0, 0, time.UTC), + uint16(300), + "instance-id", + quota.RequestsAllAuthenticated, + "id", + }, + }, + }, + }, + }, + }, + { + name: "reduceQuotaRemoved", + args: args{ + event: getEvent(testEvent( + repository.EventType(quota.RemovedEventType), + quota.AggregateType, + []byte(`{ + "unit": 1 + }`), + ), quota.RemovedEventMapper), + }, + reduce: ("aProjection{}).reduceQuotaRemoved, + want: wantReduce{ + aggregateType: eventstore.AggregateType("quota"), + sequence: 15, + previousSequence: 10, + executer: &testExecuter{ + executions: []execution{ + { + expectedStmt: "DELETE FROM projections.quotas_periods WHERE (instance_id = $1) AND (unit = $2)", + expectedArgs: []interface{}{ + "instance-id", + quota.RequestsAllAuthenticated, + }, + }, + { + expectedStmt: "DELETE FROM projections.quotas_notifications WHERE (instance_id = $1) AND (unit = $2)", + expectedArgs: []interface{}{ + "instance-id", + quota.RequestsAllAuthenticated, + }, + }, + { + expectedStmt: "DELETE FROM projections.quotas WHERE (instance_id = $1) AND (unit = $2)", + expectedArgs: []interface{}{ + "instance-id", + quota.RequestsAllAuthenticated, + }, + }, + }, + }, + }, + }, { + name: "reduceInstanceRemoved", + args: args{ + event: getEvent(testEvent( + repository.EventType(instance.InstanceRemovedEventType), + instance.AggregateType, + []byte(`{ + "name": "name" + }`), + ), instance.InstanceRemovedEventMapper), + }, + reduce: ("aProjection{}).reduceInstanceRemoved, + want: wantReduce{ + aggregateType: eventstore.AggregateType("instance"), + sequence: 15, + previousSequence: 10, + executer: &testExecuter{ + executions: []execution{ + { + expectedStmt: "DELETE FROM projections.quotas_periods WHERE (instance_id = $1)", + expectedArgs: []interface{}{ + "instance-id", + }, + }, + { + expectedStmt: "DELETE FROM projections.quotas_notifications WHERE (instance_id = $1)", + expectedArgs: []interface{}{ + "instance-id", + }, + }, + { + expectedStmt: "DELETE FROM projections.quotas WHERE (instance_id = $1)", + expectedArgs: []interface{}{ + "instance-id", + }, + }, + }, + }, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + event := baseEvent(t) + got, err := tt.reduce(event) + if !errors.IsErrorInvalidArgument(err) { + t.Errorf("no wrong event mapping: %v, got: %v", err, got) + } + event = tt.args.event(t) + got, err = tt.reduce(event) + assertReduce(t, got, err, QuotasProjectionTable, tt.want) + }) + } +} + +func Test_quotaProjection_IncrementUsage(t *testing.T) { + testNow := time.Now() + type fields struct { + client *database.DB + } + type args struct { + ctx context.Context + unit quota.Unit + instanceID string + periodStart time.Time + count uint64 + } + type res struct { + sum uint64 + err error + } + tests := []struct { + name string + fields fields + args args + res res + }{ + { + name: "", + fields: fields{ + client: func() *database.DB { + db, mock, _ := sqlmock.New() + mock.ExpectQuery(regexp.QuoteMeta(incrementQuotaStatement)). + WithArgs( + "instance_id", + 1, + testNow, + 2, + ). + WillReturnRows(sqlmock.NewRows([]string{"key"}). + AddRow(3)) + return &database.DB{DB: db} + }(), + }, + args: args{ + ctx: context.Background(), + unit: quota.RequestsAllAuthenticated, + instanceID: "instance_id", + periodStart: testNow, + count: 2, + }, + res: res{ + sum: 3, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + q := "aProjection{ + client: tt.fields.client, + } + gotSum, err := q.IncrementUsage(tt.args.ctx, tt.args.unit, tt.args.instanceID, tt.args.periodStart, tt.args.count) + assert.Equal(t, tt.res.sum, gotSum) + assert.ErrorIs(t, err, tt.res.err) + }) + } +} diff --git a/internal/query/query.go b/internal/query/query.go index 9355265e1c..216426c430 100644 --- a/internal/query/query.go +++ b/internal/query/query.go @@ -26,6 +26,7 @@ import ( "github.com/zitadel/zitadel/internal/repository/oidcsession" "github.com/zitadel/zitadel/internal/repository/org" "github.com/zitadel/zitadel/internal/repository/project" + "github.com/zitadel/zitadel/internal/repository/quota" "github.com/zitadel/zitadel/internal/repository/session" usr_repo "github.com/zitadel/zitadel/internal/repository/user" "github.com/zitadel/zitadel/internal/repository/usergrant" @@ -93,6 +94,7 @@ func StartQueries( idpintent.RegisterEventMappers(repo.eventstore) authrequest.RegisterEventMappers(repo.eventstore) oidcsession.RegisterEventMappers(repo.eventstore) + quota.RegisterEventMappers(repo.eventstore) repo.idpConfigEncryption = idpConfigEncryption repo.multifactors = domain.MultifactorConfigs{ diff --git a/internal/query/quota.go b/internal/query/quota.go new file mode 100644 index 0000000000..1919902a27 --- /dev/null +++ b/internal/query/quota.go @@ -0,0 +1,121 @@ +package query + +import ( + "context" + "database/sql" + errs "errors" + "time" + + sq "github.com/Masterminds/squirrel" + + "github.com/zitadel/zitadel/internal/database" + "github.com/zitadel/zitadel/internal/errors" + "github.com/zitadel/zitadel/internal/query/projection" + "github.com/zitadel/zitadel/internal/repository/quota" + "github.com/zitadel/zitadel/internal/telemetry/tracing" +) + +var ( + quotasTable = table{ + name: projection.QuotasProjectionTable, + instanceIDCol: projection.QuotaColumnInstanceID, + } + QuotaColumnID = Column{ + name: projection.QuotaColumnID, + table: quotasTable, + } + QuotaColumnInstanceID = Column{ + name: projection.QuotaColumnInstanceID, + table: quotasTable, + } + QuotaColumnUnit = Column{ + name: projection.QuotaColumnUnit, + table: quotasTable, + } + QuotaColumnAmount = Column{ + name: projection.QuotaColumnAmount, + table: quotasTable, + } + QuotaColumnLimit = Column{ + name: projection.QuotaColumnLimit, + table: quotasTable, + } + QuotaColumnInterval = Column{ + name: projection.QuotaColumnInterval, + table: quotasTable, + } + QuotaColumnFrom = Column{ + name: projection.QuotaColumnFrom, + table: quotasTable, + } +) + +type Quota struct { + ID string + From time.Time + ResetInterval time.Duration + Amount uint64 + Limit bool + CurrentPeriodStart time.Time +} + +func (q *Queries) GetQuota(ctx context.Context, instanceID string, unit quota.Unit) (qu *Quota, err error) { + ctx, span := tracing.NewSpan(ctx) + defer func() { span.EndWithError(err) }() + query, scan := prepareQuotaQuery(ctx, q.client) + stmt, args, err := query.Where( + sq.Eq{ + QuotaColumnInstanceID.identifier(): instanceID, + QuotaColumnUnit.identifier(): unit, + }, + ).ToSql() + if err != nil { + return nil, errors.ThrowInternal(err, "QUERY-XmYn9", "Errors.Query.SQLStatement") + } + err = q.client.QueryRowContext(ctx, func(row *sql.Row) error { + qu, err = scan(row) + return err + }, stmt, args...) + return qu, err +} + +func prepareQuotaQuery(ctx context.Context, db prepareDatabase) (sq.SelectBuilder, func(*sql.Row) (*Quota, error)) { + return sq. + Select( + QuotaColumnID.identifier(), + QuotaColumnFrom.identifier(), + QuotaColumnInterval.identifier(), + QuotaColumnAmount.identifier(), + QuotaColumnLimit.identifier(), + "now()", + ). + From(quotasTable.identifier()). + PlaceholderFormat(sq.Dollar), func(row *sql.Row) (*Quota, error) { + q := new(Quota) + var interval database.Duration + var now time.Time + err := row.Scan(&q.ID, &q.From, &interval, &q.Amount, &q.Limit, &now) + if err != nil { + if errs.Is(err, sql.ErrNoRows) { + return nil, errors.ThrowNotFound(err, "QUERY-rDTM6", "Errors.Quota.NotExisting") + } + return nil, errors.ThrowInternal(err, "QUERY-LqySK", "Errors.Internal") + } + q.ResetInterval = time.Duration(interval) + q.CurrentPeriodStart = pushPeriodStart(q.From, q.ResetInterval, now) + return q, nil + } +} + +func pushPeriodStart(from time.Time, interval time.Duration, now time.Time) time.Time { + if now.IsZero() { + now = time.Now() + } + for { + next := from.Add(interval) + if next.After(now) { + return from + } + from = next + } +} diff --git a/internal/query/quota_model.go b/internal/query/quota_model.go deleted file mode 100644 index 322ea67aac..0000000000 --- a/internal/query/quota_model.go +++ /dev/null @@ -1,55 +0,0 @@ -package query - -import ( - "github.com/zitadel/zitadel/internal/eventstore" - "github.com/zitadel/zitadel/internal/repository/quota" -) - -type quotaReadModel struct { - eventstore.ReadModel - unit quota.Unit - active bool - config *quota.AddedEvent -} - -// newQuotaReadModel aggregateId is filled by reducing unit matching events -func newQuotaReadModel(instanceId, resourceOwner string, unit quota.Unit) *quotaReadModel { - return "aReadModel{ - ReadModel: eventstore.ReadModel{ - InstanceID: instanceId, - ResourceOwner: resourceOwner, - }, - unit: unit, - } -} - -func (rm *quotaReadModel) Query() *eventstore.SearchQueryBuilder { - query := eventstore.NewSearchQueryBuilder(eventstore.ColumnsEvent). - ResourceOwner(rm.ResourceOwner). - AllowTimeTravel(). - AddQuery(). - InstanceID(rm.InstanceID). - AggregateTypes(quota.AggregateType). - EventTypes( - quota.AddedEventType, - quota.RemovedEventType, - ).EventData(map[string]interface{}{"unit": rm.unit}) - - return query.Builder() -} - -func (rm *quotaReadModel) Reduce() error { - for _, event := range rm.Events { - switch e := event.(type) { - case *quota.AddedEvent: - rm.AggregateID = e.Aggregate().ID - rm.active = true - rm.config = e - case *quota.RemovedEvent: - rm.AggregateID = e.Aggregate().ID - rm.active = false - rm.config = nil - } - } - return rm.ReadModel.Reduce() -} diff --git a/internal/query/quota_notifications.go b/internal/query/quota_notifications.go index 63930dbe1b..7fc0748f63 100644 --- a/internal/query/quota_notifications.go +++ b/internal/query/quota_notifications.go @@ -2,58 +2,180 @@ package query import ( "context" + "database/sql" + errs "errors" "math" "time" - "github.com/zitadel/zitadel/internal/eventstore" + sq "github.com/Masterminds/squirrel" + "github.com/pkg/errors" + + "github.com/zitadel/zitadel/internal/api/call" + zitadel_errors "github.com/zitadel/zitadel/internal/errors" + "github.com/zitadel/zitadel/internal/query/projection" "github.com/zitadel/zitadel/internal/repository/quota" + "github.com/zitadel/zitadel/internal/telemetry/tracing" ) -func (q *Queries) GetDueQuotaNotifications(ctx context.Context, config *quota.AddedEvent, periodStart time.Time, usedAbs uint64) ([]*quota.NotificationDueEvent, error) { - if len(config.Notifications) == 0 { +var ( + quotaNotificationsTable = table{ + name: projection.QuotaNotificationsTable, + instanceIDCol: projection.QuotaNotificationColumnInstanceID, + } + QuotaNotificationColumnInstanceID = Column{ + name: projection.QuotaNotificationColumnInstanceID, + table: quotaNotificationsTable, + } + QuotaNotificationColumnUnit = Column{ + name: projection.QuotaNotificationColumnUnit, + table: quotaNotificationsTable, + } + QuotaNotificationColumnID = Column{ + name: projection.QuotaNotificationColumnID, + table: quotaNotificationsTable, + } + QuotaNotificationColumnCallURL = Column{ + name: projection.QuotaNotificationColumnCallURL, + table: quotaNotificationsTable, + } + QuotaNotificationColumnPercent = Column{ + name: projection.QuotaNotificationColumnPercent, + table: quotaNotificationsTable, + } + QuotaNotificationColumnRepeat = Column{ + name: projection.QuotaNotificationColumnRepeat, + table: quotaNotificationsTable, + } + QuotaNotificationColumnLatestDuePeriodStart = Column{ + name: projection.QuotaNotificationColumnLatestDuePeriodStart, + table: quotaNotificationsTable, + } + QuotaNotificationColumnNextDueThreshold = Column{ + name: projection.QuotaNotificationColumnNextDueThreshold, + table: quotaNotificationsTable, + } +) + +func (q *Queries) GetDueQuotaNotifications(ctx context.Context, instanceID string, unit quota.Unit, qu *Quota, periodStart time.Time, usedAbs uint64) (dueNotifications []*quota.NotificationDueEvent, err error) { + ctx, span := tracing.NewSpan(ctx) + defer func() { span.EndWithError(err) }() + usedRel := uint16(math.Floor(float64(usedAbs*100) / float64(qu.Amount))) + query, scan := prepareQuotaNotificationsQuery(ctx, q.client) + stmt, args, err := query.Where( + sq.And{ + sq.Eq{ + QuotaNotificationColumnInstanceID.identifier(): instanceID, + QuotaNotificationColumnUnit.identifier(): unit, + }, + sq.Or{ + // If the relative usage is greater than the next due threshold in the current period, it's clear we can notify + sq.And{ + sq.Eq{QuotaNotificationColumnLatestDuePeriodStart.identifier(): periodStart}, + sq.LtOrEq{QuotaNotificationColumnNextDueThreshold.identifier(): usedRel}, + }, + // In case we haven't seen a due notification for this quota period, we compare against the configured percent + sq.And{ + sq.Or{ + sq.Expr(QuotaNotificationColumnLatestDuePeriodStart.identifier() + " IS NULL"), + sq.NotEq{QuotaNotificationColumnLatestDuePeriodStart.identifier(): periodStart}, + }, + sq.LtOrEq{QuotaNotificationColumnPercent.identifier(): usedRel}, + }, + }, + }, + ).ToSql() + if err != nil { + return nil, zitadel_errors.ThrowInternal(err, "QUERY-XmYn9", "Errors.Query.SQLStatement") + } + var notifications *QuotaNotifications + err = q.client.QueryContext(ctx, func(rows *sql.Rows) error { + notifications, err = scan(rows) + return err + }, stmt, args...) + if errors.Is(err, sql.ErrNoRows) { return nil, nil } - - aggregate := config.Aggregate() - wm, err := q.getQuotaNotificationsReadModel(ctx, aggregate, periodStart) if err != nil { return nil, err } - - usedRel := uint16(math.Floor(float64(usedAbs*100) / float64(config.Amount))) - - var dueNotifications []*quota.NotificationDueEvent - for _, notification := range config.Notifications { - if notification.Percent > usedRel { + for _, notification := range notifications.Configs { + reachedThreshold := calculateThreshold(usedRel, notification.Percent) + if !notification.Repeat && notification.Percent < reachedThreshold { continue } - - threshold := notification.Percent - if notification.Repeat { - threshold = uint16(math.Max(1, math.Floor(float64(usedRel)/float64(notification.Percent)))) * notification.Percent - } - - if wm.latestDueThresholds[notification.ID] < threshold { - dueNotifications = append( - dueNotifications, - quota.NewNotificationDueEvent( - ctx, - &aggregate, - config.Unit, - notification.ID, - notification.CallURL, - periodStart, - threshold, - usedAbs, - ), - ) - } + dueNotifications = append( + dueNotifications, + quota.NewNotificationDueEvent( + ctx, + "a.NewAggregate(qu.ID, instanceID).Aggregate, + unit, + notification.ID, + notification.CallURL, + periodStart, + reachedThreshold, + usedAbs, + ), + ) } - return dueNotifications, nil } -func (q *Queries) getQuotaNotificationsReadModel(ctx context.Context, aggregate eventstore.Aggregate, periodStart time.Time) (*quotaNotificationsReadModel, error) { - wm := newQuotaNotificationsReadModel(aggregate.ID, aggregate.InstanceID, aggregate.ResourceOwner, periodStart) - return wm, q.eventstore.FilterToQueryReducer(ctx, wm) +type QuotaNotification struct { + ID string + CallURL string + Percent uint16 + Repeat bool + NextDueThreshold uint16 +} + +type QuotaNotifications struct { + SearchResponse + Configs []*QuotaNotification +} + +// calculateThreshold calculates the nearest reached threshold. +// It makes sure that the percent configured on the notification is calculated within the "current" 100%, +// e.g. when configuring 80%, the thresholds are 80, 180, 280, ... +// so 170% use is always 70% of the current bucket, with the above config, the reached threshold would be 80. +func calculateThreshold(usedRel, notificationPercent uint16) uint16 { + // check how many times we reached 100% + times := math.Floor(float64(usedRel) / 100) + // check how many times we reached the percent configured with the "current" 100% + percent := math.Floor(float64(usedRel%100) / float64(notificationPercent)) + // If neither is reached, directly return 0. + // This way we don't end up in some wrong uint16 range in the calculation below. + if times == 0 && percent == 0 { + return 0 + } + return uint16(times+percent-1)*100 + notificationPercent +} + +func prepareQuotaNotificationsQuery(ctx context.Context, db prepareDatabase) (sq.SelectBuilder, func(*sql.Rows) (*QuotaNotifications, error)) { + return sq.Select( + QuotaNotificationColumnID.identifier(), + QuotaNotificationColumnCallURL.identifier(), + QuotaNotificationColumnPercent.identifier(), + QuotaNotificationColumnRepeat.identifier(), + QuotaNotificationColumnNextDueThreshold.identifier(), + ). + From(quotaNotificationsTable.identifier() + db.Timetravel(call.Took(ctx))). + PlaceholderFormat(sq.Dollar), func(rows *sql.Rows) (*QuotaNotifications, error) { + cfgs := &QuotaNotifications{Configs: []*QuotaNotification{}} + for rows.Next() { + cfg := new(QuotaNotification) + var nextDueThreshold sql.NullInt16 + err := rows.Scan(&cfg.ID, &cfg.CallURL, &cfg.Percent, &cfg.Repeat, &nextDueThreshold) + if err != nil { + if errs.Is(err, sql.ErrNoRows) { + return nil, zitadel_errors.ThrowNotFound(err, "QUERY-bbqWb", "Errors.QuotaNotification.NotExisting") + } + return nil, zitadel_errors.ThrowInternal(err, "QUERY-8copS", "Errors.Internal") + } + if nextDueThreshold.Valid { + cfg.NextDueThreshold = uint16(nextDueThreshold.Int16) + } + cfgs.Configs = append(cfgs.Configs, cfg) + } + return cfgs, nil + } } diff --git a/internal/query/quota_notifications_model.go b/internal/query/quota_notifications_model.go deleted file mode 100644 index af809c7047..0000000000 --- a/internal/query/quota_notifications_model.go +++ /dev/null @@ -1,46 +0,0 @@ -package query - -import ( - "time" - - "github.com/zitadel/zitadel/internal/eventstore" - "github.com/zitadel/zitadel/internal/repository/quota" -) - -type quotaNotificationsReadModel struct { - eventstore.ReadModel - periodStart time.Time - latestDueThresholds map[string]uint16 -} - -func newQuotaNotificationsReadModel(aggregateId, instanceId, resourceOwner string, periodStart time.Time) *quotaNotificationsReadModel { - return "aNotificationsReadModel{ - ReadModel: eventstore.ReadModel{ - AggregateID: aggregateId, - InstanceID: instanceId, - ResourceOwner: resourceOwner, - }, - periodStart: periodStart, - latestDueThresholds: make(map[string]uint16), - } -} - -func (rm *quotaNotificationsReadModel) Query() *eventstore.SearchQueryBuilder { - return eventstore.NewSearchQueryBuilder(eventstore.ColumnsEvent). - ResourceOwner(rm.ResourceOwner). - AllowTimeTravel(). - AddQuery(). - InstanceID(rm.InstanceID). - AggregateTypes(quota.AggregateType). - AggregateIDs(rm.AggregateID). - CreationDateAfter(rm.periodStart). - EventTypes(quota.NotificationDueEventType).Builder() -} - -func (rm *quotaNotificationsReadModel) Reduce() error { - for _, event := range rm.Events { - e := event.(*quota.NotificationDueEvent) - rm.latestDueThresholds[e.ID] = e.Threshold - } - return rm.ReadModel.Reduce() -} diff --git a/internal/query/quota_notifications_test.go b/internal/query/quota_notifications_test.go new file mode 100644 index 0000000000..a86b31df57 --- /dev/null +++ b/internal/query/quota_notifications_test.go @@ -0,0 +1,181 @@ +package query + +import ( + "database/sql" + "database/sql/driver" + "errors" + "fmt" + "regexp" + "testing" + + "github.com/stretchr/testify/assert" +) + +func Test_calculateThreshold(t *testing.T) { + type args struct { + usedRel uint16 + notificationPercent uint16 + } + tests := []struct { + name string + args args + want uint16 + }{ + { + name: "80 - below configuration", + args: args{ + usedRel: 70, + notificationPercent: 80, + }, + want: 0, + }, + { + name: "80 - below 100 percent use", + args: args{ + usedRel: 90, + notificationPercent: 80, + }, + want: 80, + }, + { + name: "80 - above 100 percent use", + args: args{ + usedRel: 120, + notificationPercent: 80, + }, + want: 80, + }, + { + name: "80 - more than twice the use", + args: args{ + usedRel: 190, + notificationPercent: 80, + }, + want: 180, + }, + { + name: "100 - below 100 percent use", + args: args{ + usedRel: 90, + notificationPercent: 100, + }, + want: 0, + }, + { + name: "100 - above 100 percent use", + args: args{ + usedRel: 120, + notificationPercent: 100, + }, + want: 100, + }, + { + name: "100 - more than twice the use", + args: args{ + usedRel: 210, + notificationPercent: 100, + }, + want: 200, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := calculateThreshold(tt.args.usedRel, tt.args.notificationPercent) + assert.Equal(t, int(tt.want), int(got)) + }) + } +} + +var ( + expectedQuotaNotificationsQuery = regexp.QuoteMeta(`SELECT projections.quotas_notifications.id,` + + ` projections.quotas_notifications.call_url,` + + ` projections.quotas_notifications.percent,` + + ` projections.quotas_notifications.repeat,` + + ` projections.quotas_notifications.next_due_threshold` + + ` FROM projections.quotas_notifications` + + ` AS OF SYSTEM TIME '-1 ms'`) + + quotaNotificationsCols = []string{ + "id", + "call_url", + "percent", + "repeat", + "next_due_threshold", + } +) + +func Test_prepareQuotaNotificationsQuery(t *testing.T) { + type want struct { + sqlExpectations sqlExpectation + err checkErr + } + tests := []struct { + name string + prepare interface{} + want want + object interface{} + }{ + { + name: "prepareQuotaNotificationsQuery no result", + prepare: prepareQuotaNotificationsQuery, + want: want{ + sqlExpectations: mockQueries( + expectedQuotaNotificationsQuery, + nil, + nil, + ), + }, + object: &QuotaNotifications{Configs: []*QuotaNotification{}}, + }, + { + name: "prepareQuotaNotificationsQuery", + prepare: prepareQuotaNotificationsQuery, + want: want{ + sqlExpectations: mockQuery( + expectedQuotaNotificationsQuery, + quotaNotificationsCols, + []driver.Value{ + "quota-id", + "url", + uint16(100), + true, + uint16(100), + }, + ), + }, + object: &QuotaNotifications{ + Configs: []*QuotaNotification{ + { + ID: "quota-id", + CallURL: "url", + Percent: 100, + Repeat: true, + NextDueThreshold: 100, + }, + }, + }, + }, + { + name: "prepareQuotaNotificationsQuery sql err", + prepare: prepareQuotaNotificationsQuery, + want: want{ + sqlExpectations: mockQueryErr( + expectedQuotaNotificationsQuery, + sql.ErrConnDone, + ), + err: func(err error) (error, bool) { + if !errors.Is(err, sql.ErrConnDone) { + return fmt.Errorf("err should be sql.ErrConnDone got: %w", err), false + } + return nil, true + }, + }, + object: (*Quota)(nil), + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + assertPrepare(t, tt.prepare, tt.object, tt.want.sqlExpectations, tt.want.err, defaultPrepareArgs...) + }) + } +} diff --git a/internal/query/quota_period.go b/internal/query/quota_period.go deleted file mode 100644 index b18a5f0427..0000000000 --- a/internal/query/quota_period.go +++ /dev/null @@ -1,30 +0,0 @@ -package query - -import ( - "context" - "time" - - "github.com/zitadel/zitadel/internal/repository/quota" -) - -func (q *Queries) GetCurrentQuotaPeriod(ctx context.Context, instanceID string, unit quota.Unit) (*quota.AddedEvent, time.Time, error) { - rm, err := q.getQuotaReadModel(ctx, instanceID, instanceID, unit) - if err != nil || !rm.active { - return nil, time.Time{}, err - } - - return rm.config, pushPeriodStart(rm.config.From, rm.config.ResetInterval, time.Now()), nil -} - -func pushPeriodStart(from time.Time, interval time.Duration, now time.Time) time.Time { - next := from.Add(interval) - if next.After(now) { - return from - } - return pushPeriodStart(next, interval, now) -} - -func (q *Queries) getQuotaReadModel(ctx context.Context, instanceId, resourceOwner string, unit quota.Unit) (*quotaReadModel, error) { - rm := newQuotaReadModel(instanceId, resourceOwner, unit) - return rm, q.eventstore.FilterToQueryReducer(ctx, rm) -} diff --git a/internal/query/quota_periods.go b/internal/query/quota_periods.go new file mode 100644 index 0000000000..5f540e9fcc --- /dev/null +++ b/internal/query/quota_periods.go @@ -0,0 +1,86 @@ +package query + +import ( + "context" + "database/sql" + "errors" + + sq "github.com/Masterminds/squirrel" + + "github.com/zitadel/zitadel/internal/api/call" + zitadel_errors "github.com/zitadel/zitadel/internal/errors" + "github.com/zitadel/zitadel/internal/query/projection" + "github.com/zitadel/zitadel/internal/repository/quota" + "github.com/zitadel/zitadel/internal/telemetry/tracing" +) + +var ( + quotaPeriodsTable = table{ + name: projection.QuotaPeriodsProjectionTable, + instanceIDCol: projection.QuotaColumnInstanceID, + } + QuotaPeriodColumnInstanceID = Column{ + name: projection.QuotaPeriodColumnInstanceID, + table: quotaPeriodsTable, + } + QuotaPeriodColumnUnit = Column{ + name: projection.QuotaPeriodColumnUnit, + table: quotaPeriodsTable, + } + QuotaPeriodColumnStart = Column{ + name: projection.QuotaPeriodColumnStart, + table: quotaPeriodsTable, + } + QuotaPeriodColumnUsage = Column{ + name: projection.QuotaPeriodColumnUsage, + table: quotaPeriodsTable, + } +) + +func (q *Queries) GetRemainingQuotaUsage(ctx context.Context, instanceID string, unit quota.Unit) (remaining *uint64, err error) { + ctx, span := tracing.NewSpan(ctx) + defer func() { span.EndWithError(err) }() + stmt, scan := prepareRemainingQuotaUsageQuery(ctx, q.client) + query, args, err := stmt.Where( + sq.And{ + sq.Eq{ + QuotaPeriodColumnInstanceID.identifier(): instanceID, + QuotaPeriodColumnUnit.identifier(): unit, + QuotaColumnLimit.identifier(): true, + }, + sq.Expr("age(" + QuotaPeriodColumnStart.identifier() + ") < " + QuotaColumnInterval.identifier()), + sq.Expr(QuotaPeriodColumnStart.identifier() + " < now()"), + }). + ToSql() + if err != nil { + return nil, zitadel_errors.ThrowInternal(err, "QUERY-FSA3g", "Errors.Query.SQLStatement") + } + err = q.client.QueryRowContext(ctx, func(row *sql.Row) error { + remaining, err = scan(row) + return err + }, query, args...) + if zitadel_errors.IsNotFound(err) { + return nil, nil + } + return remaining, err +} + +func prepareRemainingQuotaUsageQuery(ctx context.Context, db prepareDatabase) (sq.SelectBuilder, func(*sql.Row) (*uint64, error)) { + return sq. + Select( + "greatest(0, " + QuotaColumnAmount.identifier() + "-" + QuotaPeriodColumnUsage.identifier() + ")", + ). + From(quotaPeriodsTable.identifier()). + Join(join(QuotaColumnUnit, QuotaPeriodColumnUnit) + db.Timetravel(call.Took(ctx))). + PlaceholderFormat(sq.Dollar), func(row *sql.Row) (*uint64, error) { + usage := new(uint64) + err := row.Scan(usage) + if err != nil { + if errors.Is(err, sql.ErrNoRows) { + return nil, zitadel_errors.ThrowNotFound(err, "QUERY-quiowi2", "Errors.Internal") + } + return nil, zitadel_errors.ThrowInternal(err, "QUERY-81j1jn2", "Errors.Internal") + } + return usage, nil + } +} diff --git a/internal/query/quota_periods_test.go b/internal/query/quota_periods_test.go new file mode 100644 index 0000000000..25ea2056e9 --- /dev/null +++ b/internal/query/quota_periods_test.go @@ -0,0 +1,95 @@ +package query + +import ( + "database/sql" + "database/sql/driver" + "errors" + "fmt" + "regexp" + "testing" + + errs "github.com/zitadel/zitadel/internal/errors" +) + +var ( + expectedRemainingQuotaUsageQuery = regexp.QuoteMeta(`SELECT greatest(0, projections.quotas.amount-projections.quotas_periods.usage)` + + ` FROM projections.quotas_periods` + + ` JOIN projections.quotas ON projections.quotas_periods.unit = projections.quotas.unit AND projections.quotas_periods.instance_id = projections.quotas.instance_id` + + ` AS OF SYSTEM TIME '-1 ms'`) + remainingQuotaUsageCols = []string{ + "usage", + } +) + +func Test_prepareRemainingQuotaUsageQuery(t *testing.T) { + type want struct { + sqlExpectations sqlExpectation + err checkErr + } + tests := []struct { + name string + prepare interface{} + want want + object interface{} + }{ + { + name: "prepareRemainingQuotaUsageQuery no result", + prepare: prepareRemainingQuotaUsageQuery, + want: want{ + sqlExpectations: mockQueryScanErr( + expectedRemainingQuotaUsageQuery, + nil, + nil, + ), + err: func(err error) (error, bool) { + if !errs.IsNotFound(err) { + return fmt.Errorf("err should be zitadel.NotFoundError got: %w", err), false + } + return nil, true + }, + }, + object: (*uint64)(nil), + }, + { + name: "prepareRemainingQuotaUsageQuery", + prepare: prepareRemainingQuotaUsageQuery, + want: want{ + sqlExpectations: mockQuery( + expectedRemainingQuotaUsageQuery, + remainingQuotaUsageCols, + []driver.Value{ + uint64(100), + }, + ), + }, + object: uint64P(100), + }, + { + name: "prepareRemainingQuotaUsageQuery sql err", + prepare: prepareRemainingQuotaUsageQuery, + want: want{ + sqlExpectations: mockQueryErr( + expectedRemainingQuotaUsageQuery, + sql.ErrConnDone, + ), + err: func(err error) (error, bool) { + if !errors.Is(err, sql.ErrConnDone) { + return fmt.Errorf("err should be sql.ErrConnDone got: %w", err), false + } + return nil, true + }, + }, + object: (*uint64)(nil), + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + assertPrepare(t, tt.prepare, tt.object, tt.want.sqlExpectations, tt.want.err, defaultPrepareArgs...) + }) + } +} + +func uint64P(i int) *uint64 { + u := uint64(i) + return &u +} diff --git a/internal/query/quota_test.go b/internal/query/quota_test.go new file mode 100644 index 0000000000..05bddc8031 --- /dev/null +++ b/internal/query/quota_test.go @@ -0,0 +1,127 @@ +package query + +import ( + "database/sql" + "database/sql/driver" + "errors" + "fmt" + "regexp" + "testing" + "time" + + "github.com/jackc/pgtype" + + errs "github.com/zitadel/zitadel/internal/errors" +) + +var ( + expectedQuotaQuery = regexp.QuoteMeta(`SELECT projections.quotas.id,` + + ` projections.quotas.from_anchor,` + + ` projections.quotas.interval,` + + ` projections.quotas.amount,` + + ` projections.quotas.limit_usage,` + + ` now()` + + ` FROM projections.quotas`) + + quotaCols = []string{ + "id", + "from_anchor", + "interval", + "amount", + "limit_usage", + "now", + } +) + +func dayNow() time.Time { + return time.Now().Truncate(24 * time.Hour) +} + +func interval(t *testing.T, src time.Duration) pgtype.Interval { + interval := pgtype.Interval{} + err := interval.Set(src) + if err != nil { + t.Fatal(err) + } + return interval +} + +func Test_QuotaPrepare(t *testing.T) { + type want struct { + sqlExpectations sqlExpectation + err checkErr + } + tests := []struct { + name string + prepare interface{} + want want + object interface{} + }{ + { + name: "prepareQuotaQuery no result", + prepare: prepareQuotaQuery, + want: want{ + sqlExpectations: mockQueriesScanErr( + expectedQuotaQuery, + nil, + nil, + ), + err: func(err error) (error, bool) { + if !errs.IsNotFound(err) { + return fmt.Errorf("err should be zitadel.NotFoundError got: %w", err), false + } + return nil, true + }, + }, + object: (*Quota)(nil), + }, + { + name: "prepareQuotaQuery", + prepare: prepareQuotaQuery, + want: want{ + sqlExpectations: mockQuery( + expectedQuotaQuery, + quotaCols, + []driver.Value{ + "quota-id", + dayNow(), + interval(t, time.Hour*24), + uint64(1000), + true, + testNow, + }, + ), + }, + object: &Quota{ + ID: "quota-id", + From: dayNow(), + ResetInterval: time.Hour * 24, + CurrentPeriodStart: dayNow(), + Amount: 1000, + Limit: true, + }, + }, + { + name: "prepareQuotaQuery sql err", + prepare: prepareQuotaQuery, + want: want{ + sqlExpectations: mockQueryErr( + expectedQuotaQuery, + sql.ErrConnDone, + ), + err: func(err error) (error, bool) { + if !errors.Is(err, sql.ErrConnDone) { + return fmt.Errorf("err should be sql.ErrConnDone got: %w", err), false + } + return nil, true + }, + }, + object: (*Quota)(nil), + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + assertPrepare(t, tt.prepare, tt.object, tt.want.sqlExpectations, tt.want.err, defaultPrepareArgs...) + }) + } +} diff --git a/internal/repository/quota/aggregate.go b/internal/repository/quota/aggregate.go index 0e17faed78..c1194dc346 100644 --- a/internal/repository/quota/aggregate.go +++ b/internal/repository/quota/aggregate.go @@ -13,14 +13,14 @@ type Aggregate struct { eventstore.Aggregate } -func NewAggregate(id, instanceId, resourceOwner string) *Aggregate { +func NewAggregate(id, instanceId string) *Aggregate { return &Aggregate{ Aggregate: eventstore.Aggregate{ Type: AggregateType, Version: AggregateVersion, ID: id, InstanceID: instanceId, - ResourceOwner: resourceOwner, + ResourceOwner: instanceId, }, } } diff --git a/internal/repository/quota/events.go b/internal/repository/quota/events.go index 4d635e2bd3..5bdb8f0c44 100644 --- a/internal/repository/quota/events.go +++ b/internal/repository/quota/events.go @@ -14,13 +14,12 @@ import ( type Unit uint const ( - UniqueQuotaNameType = "quota_units" - UniqueQuotaNotificationIDType = "quota_notification" - eventTypePrefix = eventstore.EventType("quota.") - AddedEventType = eventTypePrefix + "added" - NotifiedEventType = eventTypePrefix + "notified" - NotificationDueEventType = eventTypePrefix + "notificationdue" - RemovedEventType = eventTypePrefix + "removed" + UniqueQuotaNameType = "quota_units" + eventTypePrefix = eventstore.EventType("quota.") + AddedEventType = eventTypePrefix + "added" + NotifiedEventType = eventTypePrefix + "notified" + NotificationDueEventType = eventTypePrefix + "notificationdue" + RemovedEventType = eventTypePrefix + "removed" ) const (