package access import ( "context" "database/sql" "errors" "time" "github.com/zitadel/zitadel/internal/api/authz" "github.com/zitadel/zitadel/internal/command" "github.com/zitadel/zitadel/internal/database" "github.com/zitadel/zitadel/internal/logstore" "github.com/zitadel/zitadel/internal/logstore/record" "github.com/zitadel/zitadel/internal/query" "github.com/zitadel/zitadel/internal/query/projection" "github.com/zitadel/zitadel/internal/repository/quota" "github.com/zitadel/zitadel/internal/telemetry/tracing" ) var _ logstore.UsageStorer[*record.AccessLog] = (*databaseLogStorage)(nil) type databaseLogStorage struct { dbClient *database.DB commands *command.Commands queries *query.Queries } func NewDatabaseLogStorage(dbClient *database.DB, commands *command.Commands, queries *query.Queries) *databaseLogStorage { return &databaseLogStorage{dbClient: dbClient, commands: commands, queries: queries} } func (l *databaseLogStorage) QuotaUnit() quota.Unit { return quota.RequestsAllAuthenticated } func (l *databaseLogStorage) Emit(ctx context.Context, bulk []*record.AccessLog) error { if len(bulk) == 0 { return nil } return l.incrementUsage(ctx, bulk) } func (l *databaseLogStorage) incrementUsage(ctx context.Context, bulk []*record.AccessLog) (err error) { ctx, span := tracing.NewSpan(ctx) defer func() { span.EndWithError(err) }() byInstance := make(map[string][]*record.AccessLog) for _, r := range bulk { if r.InstanceID != "" { byInstance[r.InstanceID] = append(byInstance[r.InstanceID], r) } } for instanceID, instanceBulk := range byInstance { q, getQuotaErr := l.queries.GetQuota(ctx, instanceID, quota.RequestsAllAuthenticated) if errors.Is(getQuotaErr, sql.ErrNoRows) { continue } err = errors.Join(err, getQuotaErr) if getQuotaErr != nil { continue } sum, incrementErr := l.incrementUsageFromAccessLogs(ctx, instanceID, q.CurrentPeriodStart, instanceBulk) err = errors.Join(err, incrementErr) if incrementErr != nil { continue } notifications, getNotificationErr := l.queries.GetDueQuotaNotifications(ctx, instanceID, quota.RequestsAllAuthenticated, q, q.CurrentPeriodStart, sum) err = errors.Join(err, getNotificationErr) if getNotificationErr != nil || len(notifications) == 0 { continue } ctx = authz.WithInstanceID(ctx, instanceID) reportErr := l.commands.ReportQuotaUsage(ctx, notifications) err = errors.Join(err, reportErr) if reportErr != nil { continue } } return err } func (l *databaseLogStorage) incrementUsageFromAccessLogs(ctx context.Context, instanceID string, periodStart time.Time, records []*record.AccessLog) (sum uint64, err error) { ctx, span := tracing.NewSpan(ctx) defer func() { span.EndWithError(err) }() var count uint64 for _, r := range records { if r.IsAuthenticated() { count++ } } return projection.QuotaProjection.IncrementUsage(ctx, quota.RequestsAllAuthenticated, instanceID, periodStart, count) }