perf: project quotas and usages (#6441)

* project quota added

* project quota removed

* add periods table

* make log record generic

* accumulate usage

* query usage

* count action run seconds

* fix filter in ReportQuotaUsage

* fix existing tests

* fix logstore tests

* fix typo

* fix: add quota unit tests command side

* fix: add quota unit tests command side

* fix: add quota unit tests command side

* move notifications into debouncer and improve limit querying

* cleanup

* comment

* fix: add quota unit tests command side

* fix remaining quota usage query

* implement InmemLogStorage

* cleanup and linting

* improve test

* fix: add quota unit tests command side

* fix: add quota unit tests command side

* fix: add quota unit tests command side

* fix: add quota unit tests command side

* action notifications and fixes for notifications query

* revert console prefix

* fix: add quota unit tests command side

* fix: add quota integration tests

* improve accountable requests

* improve accountable requests

* fix: add quota integration tests

* fix: add quota integration tests

* fix: add quota integration tests

* comment

* remove ability to store logs in db and other changes requested from review

* changes requested from review

* changes requested from review

* Update internal/api/http/middleware/access_interceptor.go

Co-authored-by: Silvan <silvan.reusser@gmail.com>

* tests: fix quotas integration tests

* improve incrementUsageStatement

* linting

* fix: delete e2e tests as intergation tests cover functionality

* Update internal/api/http/middleware/access_interceptor.go

Co-authored-by: Silvan <silvan.reusser@gmail.com>

* backup

* fix conflict

* create rc

* create prerelease

* remove issue release labeling

* fix tracing

---------

Co-authored-by: Livio Spring <livio.a@gmail.com>
Co-authored-by: Stefan Benz <stefan@caos.ch>
Co-authored-by: adlerhurst <silvan.reusser@gmail.com>
This commit is contained in:
Elio Bischof
2023-09-15 16:58:45 +02:00
committed by GitHub
parent b4d0d2c9a7
commit 1a49b7d298
66 changed files with 3423 additions and 1413 deletions

View File

@@ -13,6 +13,7 @@ import (
key_repo "github.com/zitadel/zitadel/internal/repository/keypair"
"github.com/zitadel/zitadel/internal/repository/org"
proj_repo "github.com/zitadel/zitadel/internal/repository/project"
quota_repo "github.com/zitadel/zitadel/internal/repository/quota"
usr_repo "github.com/zitadel/zitadel/internal/repository/user"
"github.com/zitadel/zitadel/internal/repository/usergrant"
)
@@ -29,6 +30,7 @@ func eventstoreExpect(t *testing.T, expects ...expect) *eventstore.Eventstore {
org.RegisterEventMappers(es)
usr_repo.RegisterEventMappers(es)
proj_repo.RegisterEventMappers(es)
quota_repo.RegisterEventMappers(es)
usergrant.RegisterEventMappers(es)
key_repo.RegisterEventMappers(es)
action_repo.RegisterEventMappers(es)

View File

@@ -69,6 +69,7 @@ var (
SessionProjection *sessionProjection
AuthRequestProjection *authRequestProjection
MilestoneProjection *milestoneProjection
QuotaProjection *quotaProjection
)
type projection interface {
@@ -148,6 +149,7 @@ func Create(ctx context.Context, sqlClient *database.DB, es *eventstore.Eventsto
SessionProjection = newSessionProjection(ctx, applyCustomConfig(projectionConfig, config.Customizations["sessions"]))
AuthRequestProjection = newAuthRequestProjection(ctx, applyCustomConfig(projectionConfig, config.Customizations["auth_requests"]))
MilestoneProjection = newMilestoneProjection(ctx, applyCustomConfig(projectionConfig, config.Customizations["milestones"]))
QuotaProjection = newQuotaProjection(ctx, applyCustomConfig(projectionConfig, config.Customizations["quotas"]))
newProjectionsList()
return nil
}
@@ -247,5 +249,6 @@ func newProjectionsList() {
SessionProjection,
AuthRequestProjection,
MilestoneProjection,
QuotaProjection,
}
}

View File

@@ -0,0 +1,285 @@
package projection
import (
"context"
"time"
"github.com/zitadel/zitadel/internal/database"
"github.com/zitadel/zitadel/internal/errors"
"github.com/zitadel/zitadel/internal/eventstore"
"github.com/zitadel/zitadel/internal/eventstore/handler"
"github.com/zitadel/zitadel/internal/eventstore/handler/crdb"
"github.com/zitadel/zitadel/internal/repository/instance"
"github.com/zitadel/zitadel/internal/repository/quota"
)
const (
QuotasProjectionTable = "projections.quotas"
QuotaPeriodsProjectionTable = QuotasProjectionTable + "_" + quotaPeriodsTableSuffix
QuotaNotificationsTable = QuotasProjectionTable + "_" + quotaNotificationsTableSuffix
QuotaColumnID = "id"
QuotaColumnInstanceID = "instance_id"
QuotaColumnUnit = "unit"
QuotaColumnAmount = "amount"
QuotaColumnFrom = "from_anchor"
QuotaColumnInterval = "interval"
QuotaColumnLimit = "limit_usage"
quotaPeriodsTableSuffix = "periods"
QuotaPeriodColumnInstanceID = "instance_id"
QuotaPeriodColumnUnit = "unit"
QuotaPeriodColumnStart = "start"
QuotaPeriodColumnUsage = "usage"
quotaNotificationsTableSuffix = "notifications"
QuotaNotificationColumnInstanceID = "instance_id"
QuotaNotificationColumnUnit = "unit"
QuotaNotificationColumnID = "id"
QuotaNotificationColumnCallURL = "call_url"
QuotaNotificationColumnPercent = "percent"
QuotaNotificationColumnRepeat = "repeat"
QuotaNotificationColumnLatestDuePeriodStart = "latest_due_period_start"
QuotaNotificationColumnNextDueThreshold = "next_due_threshold"
)
const (
incrementQuotaStatement = `INSERT INTO projections.quotas_periods` +
` (instance_id, unit, start, usage)` +
` VALUES ($1, $2, $3, $4) ON CONFLICT (instance_id, unit, start)` +
` DO UPDATE SET usage = projections.quotas_periods.usage + excluded.usage RETURNING usage`
)
type quotaProjection struct {
crdb.StatementHandler
client *database.DB
}
func newQuotaProjection(ctx context.Context, config crdb.StatementHandlerConfig) *quotaProjection {
p := new(quotaProjection)
config.ProjectionName = QuotasProjectionTable
config.Reducers = p.reducers()
config.InitCheck = crdb.NewMultiTableCheck(
crdb.NewTable(
[]*crdb.Column{
crdb.NewColumn(QuotaColumnID, crdb.ColumnTypeText),
crdb.NewColumn(QuotaColumnInstanceID, crdb.ColumnTypeText),
crdb.NewColumn(QuotaColumnUnit, crdb.ColumnTypeEnum),
crdb.NewColumn(QuotaColumnAmount, crdb.ColumnTypeInt64),
crdb.NewColumn(QuotaColumnFrom, crdb.ColumnTypeTimestamp),
crdb.NewColumn(QuotaColumnInterval, crdb.ColumnTypeInterval),
crdb.NewColumn(QuotaColumnLimit, crdb.ColumnTypeBool),
},
crdb.NewPrimaryKey(QuotaColumnInstanceID, QuotaColumnUnit),
),
crdb.NewSuffixedTable(
[]*crdb.Column{
crdb.NewColumn(QuotaPeriodColumnInstanceID, crdb.ColumnTypeText),
crdb.NewColumn(QuotaPeriodColumnUnit, crdb.ColumnTypeEnum),
crdb.NewColumn(QuotaPeriodColumnStart, crdb.ColumnTypeTimestamp),
crdb.NewColumn(QuotaPeriodColumnUsage, crdb.ColumnTypeInt64),
},
crdb.NewPrimaryKey(QuotaPeriodColumnInstanceID, QuotaPeriodColumnUnit, QuotaPeriodColumnStart),
quotaPeriodsTableSuffix,
),
crdb.NewSuffixedTable(
[]*crdb.Column{
crdb.NewColumn(QuotaNotificationColumnInstanceID, crdb.ColumnTypeText),
crdb.NewColumn(QuotaNotificationColumnUnit, crdb.ColumnTypeEnum),
crdb.NewColumn(QuotaNotificationColumnID, crdb.ColumnTypeText),
crdb.NewColumn(QuotaNotificationColumnCallURL, crdb.ColumnTypeText),
crdb.NewColumn(QuotaNotificationColumnPercent, crdb.ColumnTypeInt64),
crdb.NewColumn(QuotaNotificationColumnRepeat, crdb.ColumnTypeBool),
crdb.NewColumn(QuotaNotificationColumnLatestDuePeriodStart, crdb.ColumnTypeTimestamp, crdb.Nullable()),
crdb.NewColumn(QuotaNotificationColumnNextDueThreshold, crdb.ColumnTypeInt64, crdb.Nullable()),
},
crdb.NewPrimaryKey(QuotaNotificationColumnInstanceID, QuotaNotificationColumnUnit, QuotaNotificationColumnID),
quotaNotificationsTableSuffix,
),
)
p.StatementHandler = crdb.NewStatementHandler(ctx, config)
p.client = config.Client
return p
}
func (q *quotaProjection) reducers() []handler.AggregateReducer {
return []handler.AggregateReducer{
{
Aggregate: instance.AggregateType,
EventRedusers: []handler.EventReducer{
{
Event: instance.InstanceRemovedEventType,
Reduce: q.reduceInstanceRemoved,
},
},
},
{
Aggregate: quota.AggregateType,
EventRedusers: []handler.EventReducer{
{
Event: quota.AddedEventType,
Reduce: q.reduceQuotaAdded,
},
},
},
{
Aggregate: quota.AggregateType,
EventRedusers: []handler.EventReducer{
{
Event: quota.RemovedEventType,
Reduce: q.reduceQuotaRemoved,
},
},
},
{
Aggregate: quota.AggregateType,
EventRedusers: []handler.EventReducer{
{
Event: quota.NotificationDueEventType,
Reduce: q.reduceQuotaNotificationDue,
},
},
},
{
Aggregate: quota.AggregateType,
EventRedusers: []handler.EventReducer{
{
Event: quota.NotifiedEventType,
Reduce: q.reduceQuotaNotified,
},
},
},
}
}
func (q *quotaProjection) reduceQuotaNotified(event eventstore.Event) (*handler.Statement, error) {
return crdb.NewNoOpStatement(event), nil
}
func (q *quotaProjection) reduceQuotaAdded(event eventstore.Event) (*handler.Statement, error) {
e, err := assertEvent[*quota.AddedEvent](event)
if err != nil {
return nil, err
}
createStatements := make([]func(e eventstore.Event) crdb.Exec, len(e.Notifications)+1)
createStatements[0] = crdb.AddCreateStatement(
[]handler.Column{
handler.NewCol(QuotaColumnID, e.Aggregate().ID),
handler.NewCol(QuotaColumnInstanceID, e.Aggregate().InstanceID),
handler.NewCol(QuotaColumnUnit, e.Unit),
handler.NewCol(QuotaColumnAmount, e.Amount),
handler.NewCol(QuotaColumnFrom, e.From),
handler.NewCol(QuotaColumnInterval, e.ResetInterval),
handler.NewCol(QuotaColumnLimit, e.Limit),
})
for i := range e.Notifications {
notification := e.Notifications[i]
createStatements[i+1] = crdb.AddCreateStatement(
[]handler.Column{
handler.NewCol(QuotaNotificationColumnInstanceID, e.Aggregate().InstanceID),
handler.NewCol(QuotaNotificationColumnUnit, e.Unit),
handler.NewCol(QuotaNotificationColumnID, notification.ID),
handler.NewCol(QuotaNotificationColumnCallURL, notification.CallURL),
handler.NewCol(QuotaNotificationColumnPercent, notification.Percent),
handler.NewCol(QuotaNotificationColumnRepeat, notification.Repeat),
},
crdb.WithTableSuffix(quotaNotificationsTableSuffix),
)
}
return crdb.NewMultiStatement(e, createStatements...), nil
}
func (q *quotaProjection) reduceQuotaNotificationDue(event eventstore.Event) (*handler.Statement, error) {
e, err := assertEvent[*quota.NotificationDueEvent](event)
if err != nil {
return nil, err
}
return crdb.NewUpdateStatement(e,
[]handler.Column{
handler.NewCol(QuotaNotificationColumnLatestDuePeriodStart, e.PeriodStart),
handler.NewCol(QuotaNotificationColumnNextDueThreshold, e.Threshold+100), // next due_threshold is always the reached + 100 => percent (e.g. 90) in the next bucket (e.g. 190)
},
[]handler.Condition{
handler.NewCond(QuotaNotificationColumnInstanceID, e.Aggregate().InstanceID),
handler.NewCond(QuotaNotificationColumnUnit, e.Unit),
handler.NewCond(QuotaNotificationColumnID, e.ID),
},
crdb.WithTableSuffix(quotaNotificationsTableSuffix),
), nil
}
func (q *quotaProjection) reduceQuotaRemoved(event eventstore.Event) (*handler.Statement, error) {
e, err := assertEvent[*quota.RemovedEvent](event)
if err != nil {
return nil, err
}
return crdb.NewMultiStatement(
e,
crdb.AddDeleteStatement(
[]handler.Condition{
handler.NewCond(QuotaPeriodColumnInstanceID, e.Aggregate().InstanceID),
handler.NewCond(QuotaPeriodColumnUnit, e.Unit),
},
crdb.WithTableSuffix(quotaPeriodsTableSuffix),
),
crdb.AddDeleteStatement(
[]handler.Condition{
handler.NewCond(QuotaNotificationColumnInstanceID, e.Aggregate().InstanceID),
handler.NewCond(QuotaNotificationColumnUnit, e.Unit),
},
crdb.WithTableSuffix(quotaNotificationsTableSuffix),
),
crdb.AddDeleteStatement(
[]handler.Condition{
handler.NewCond(QuotaColumnInstanceID, e.Aggregate().InstanceID),
handler.NewCond(QuotaColumnUnit, e.Unit),
},
),
), nil
}
func (q *quotaProjection) reduceInstanceRemoved(event eventstore.Event) (*handler.Statement, error) {
// we only assert the event to make sure it is the correct type
e, err := assertEvent[*instance.InstanceRemovedEvent](event)
if err != nil {
return nil, err
}
return crdb.NewMultiStatement(
e,
crdb.AddDeleteStatement(
[]handler.Condition{
handler.NewCond(QuotaPeriodColumnInstanceID, e.Aggregate().InstanceID),
},
crdb.WithTableSuffix(quotaPeriodsTableSuffix),
),
crdb.AddDeleteStatement(
[]handler.Condition{
handler.NewCond(QuotaNotificationColumnInstanceID, e.Aggregate().InstanceID),
},
crdb.WithTableSuffix(quotaNotificationsTableSuffix),
),
crdb.AddDeleteStatement(
[]handler.Condition{
handler.NewCond(QuotaColumnInstanceID, e.Aggregate().InstanceID),
},
),
), nil
}
func (q *quotaProjection) IncrementUsage(ctx context.Context, unit quota.Unit, instanceID string, periodStart time.Time, count uint64) (sum uint64, err error) {
if count == 0 {
return 0, nil
}
err = q.client.DB.QueryRowContext(
ctx,
incrementQuotaStatement,
instanceID, unit, periodStart, count,
).Scan(&sum)
if err != nil {
return 0, errors.ThrowInternalf(err, "PROJ-SJL3h", "incrementing usage for unit %d failed for at least one quota period", unit)
}
return sum, err
}

View File

@@ -0,0 +1,321 @@
package projection
import (
"context"
"regexp"
"testing"
"time"
"github.com/DATA-DOG/go-sqlmock"
"github.com/stretchr/testify/assert"
"github.com/zitadel/zitadel/internal/database"
"github.com/zitadel/zitadel/internal/errors"
"github.com/zitadel/zitadel/internal/eventstore"
"github.com/zitadel/zitadel/internal/eventstore/handler"
"github.com/zitadel/zitadel/internal/eventstore/repository"
"github.com/zitadel/zitadel/internal/repository/instance"
"github.com/zitadel/zitadel/internal/repository/quota"
)
func TestQuotasProjection_reduces(t *testing.T) {
type args struct {
event func(t *testing.T) eventstore.Event
}
tests := []struct {
name string
args args
reduce func(event eventstore.Event) (*handler.Statement, error)
want wantReduce
}{
{
name: "reduceQuotaAdded",
args: args{
event: getEvent(testEvent(
repository.EventType(quota.AddedEventType),
quota.AggregateType,
[]byte(`{
"unit": 1,
"amount": 10,
"limit": true,
"from": "2023-01-01T00:00:00Z",
"interval": 300000000000
}`),
), quota.AddedEventMapper),
},
reduce: (&quotaProjection{}).reduceQuotaAdded,
want: wantReduce{
aggregateType: eventstore.AggregateType("quota"),
sequence: 15,
previousSequence: 10,
executer: &testExecuter{
executions: []execution{
{
expectedStmt: "INSERT INTO projections.quotas (id, instance_id, unit, amount, from_anchor, interval, limit_usage) VALUES ($1, $2, $3, $4, $5, $6, $7)",
expectedArgs: []interface{}{
"agg-id",
"instance-id",
quota.RequestsAllAuthenticated,
uint64(10),
time.Date(2023, 1, 1, 0, 0, 0, 0, time.UTC),
time.Minute * 5,
true,
},
},
},
},
},
},
{
name: "reduceQuotaAdded with notification",
args: args{
event: getEvent(testEvent(
repository.EventType(quota.AddedEventType),
quota.AggregateType,
[]byte(`{
"unit": 1,
"amount": 10,
"limit": true,
"from": "2023-01-01T00:00:00Z",
"interval": 300000000000,
"notifications": [
{
"id": "id",
"percent": 100,
"repeat": true,
"callURL": "url"
}
]
}`),
), quota.AddedEventMapper),
},
reduce: (&quotaProjection{}).reduceQuotaAdded,
want: wantReduce{
aggregateType: eventstore.AggregateType("quota"),
sequence: 15,
previousSequence: 10,
executer: &testExecuter{
executions: []execution{
{
expectedStmt: "INSERT INTO projections.quotas (id, instance_id, unit, amount, from_anchor, interval, limit_usage) VALUES ($1, $2, $3, $4, $5, $6, $7)",
expectedArgs: []interface{}{
"agg-id",
"instance-id",
quota.RequestsAllAuthenticated,
uint64(10),
time.Date(2023, 1, 1, 0, 0, 0, 0, time.UTC),
time.Minute * 5,
true,
},
},
{
expectedStmt: "INSERT INTO projections.quotas_notifications (instance_id, unit, id, call_url, percent, repeat) VALUES ($1, $2, $3, $4, $5, $6)",
expectedArgs: []interface{}{
"instance-id",
quota.RequestsAllAuthenticated,
"id",
"url",
uint16(100),
true,
},
},
},
},
},
},
{
name: "reduceQuotaNotificationDue",
args: args{
event: getEvent(testEvent(
repository.EventType(quota.NotificationDueEventType),
quota.AggregateType,
[]byte(`{
"id": "id",
"unit": 1,
"callURL": "url",
"periodStart": "2023-01-01T00:00:00Z",
"threshold": 200,
"usage": 100
}`),
), quota.NotificationDueEventMapper),
},
reduce: (&quotaProjection{}).reduceQuotaNotificationDue,
want: wantReduce{
aggregateType: eventstore.AggregateType("quota"),
sequence: 15,
previousSequence: 10,
executer: &testExecuter{
executions: []execution{
{
expectedStmt: "UPDATE projections.quotas_notifications SET (latest_due_period_start, next_due_threshold) = ($1, $2) WHERE (instance_id = $3) AND (unit = $4) AND (id = $5)",
expectedArgs: []interface{}{
time.Date(2023, 1, 1, 0, 0, 0, 0, time.UTC),
uint16(300),
"instance-id",
quota.RequestsAllAuthenticated,
"id",
},
},
},
},
},
},
{
name: "reduceQuotaRemoved",
args: args{
event: getEvent(testEvent(
repository.EventType(quota.RemovedEventType),
quota.AggregateType,
[]byte(`{
"unit": 1
}`),
), quota.RemovedEventMapper),
},
reduce: (&quotaProjection{}).reduceQuotaRemoved,
want: wantReduce{
aggregateType: eventstore.AggregateType("quota"),
sequence: 15,
previousSequence: 10,
executer: &testExecuter{
executions: []execution{
{
expectedStmt: "DELETE FROM projections.quotas_periods WHERE (instance_id = $1) AND (unit = $2)",
expectedArgs: []interface{}{
"instance-id",
quota.RequestsAllAuthenticated,
},
},
{
expectedStmt: "DELETE FROM projections.quotas_notifications WHERE (instance_id = $1) AND (unit = $2)",
expectedArgs: []interface{}{
"instance-id",
quota.RequestsAllAuthenticated,
},
},
{
expectedStmt: "DELETE FROM projections.quotas WHERE (instance_id = $1) AND (unit = $2)",
expectedArgs: []interface{}{
"instance-id",
quota.RequestsAllAuthenticated,
},
},
},
},
},
}, {
name: "reduceInstanceRemoved",
args: args{
event: getEvent(testEvent(
repository.EventType(instance.InstanceRemovedEventType),
instance.AggregateType,
[]byte(`{
"name": "name"
}`),
), instance.InstanceRemovedEventMapper),
},
reduce: (&quotaProjection{}).reduceInstanceRemoved,
want: wantReduce{
aggregateType: eventstore.AggregateType("instance"),
sequence: 15,
previousSequence: 10,
executer: &testExecuter{
executions: []execution{
{
expectedStmt: "DELETE FROM projections.quotas_periods WHERE (instance_id = $1)",
expectedArgs: []interface{}{
"instance-id",
},
},
{
expectedStmt: "DELETE FROM projections.quotas_notifications WHERE (instance_id = $1)",
expectedArgs: []interface{}{
"instance-id",
},
},
{
expectedStmt: "DELETE FROM projections.quotas WHERE (instance_id = $1)",
expectedArgs: []interface{}{
"instance-id",
},
},
},
},
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
event := baseEvent(t)
got, err := tt.reduce(event)
if !errors.IsErrorInvalidArgument(err) {
t.Errorf("no wrong event mapping: %v, got: %v", err, got)
}
event = tt.args.event(t)
got, err = tt.reduce(event)
assertReduce(t, got, err, QuotasProjectionTable, tt.want)
})
}
}
func Test_quotaProjection_IncrementUsage(t *testing.T) {
testNow := time.Now()
type fields struct {
client *database.DB
}
type args struct {
ctx context.Context
unit quota.Unit
instanceID string
periodStart time.Time
count uint64
}
type res struct {
sum uint64
err error
}
tests := []struct {
name string
fields fields
args args
res res
}{
{
name: "",
fields: fields{
client: func() *database.DB {
db, mock, _ := sqlmock.New()
mock.ExpectQuery(regexp.QuoteMeta(incrementQuotaStatement)).
WithArgs(
"instance_id",
1,
testNow,
2,
).
WillReturnRows(sqlmock.NewRows([]string{"key"}).
AddRow(3))
return &database.DB{DB: db}
}(),
},
args: args{
ctx: context.Background(),
unit: quota.RequestsAllAuthenticated,
instanceID: "instance_id",
periodStart: testNow,
count: 2,
},
res: res{
sum: 3,
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
q := &quotaProjection{
client: tt.fields.client,
}
gotSum, err := q.IncrementUsage(tt.args.ctx, tt.args.unit, tt.args.instanceID, tt.args.periodStart, tt.args.count)
assert.Equal(t, tt.res.sum, gotSum)
assert.ErrorIs(t, err, tt.res.err)
})
}
}

View File

@@ -26,6 +26,7 @@ import (
"github.com/zitadel/zitadel/internal/repository/oidcsession"
"github.com/zitadel/zitadel/internal/repository/org"
"github.com/zitadel/zitadel/internal/repository/project"
"github.com/zitadel/zitadel/internal/repository/quota"
"github.com/zitadel/zitadel/internal/repository/session"
usr_repo "github.com/zitadel/zitadel/internal/repository/user"
"github.com/zitadel/zitadel/internal/repository/usergrant"
@@ -93,6 +94,7 @@ func StartQueries(
idpintent.RegisterEventMappers(repo.eventstore)
authrequest.RegisterEventMappers(repo.eventstore)
oidcsession.RegisterEventMappers(repo.eventstore)
quota.RegisterEventMappers(repo.eventstore)
repo.idpConfigEncryption = idpConfigEncryption
repo.multifactors = domain.MultifactorConfigs{

121
internal/query/quota.go Normal file
View File

@@ -0,0 +1,121 @@
package query
import (
"context"
"database/sql"
errs "errors"
"time"
sq "github.com/Masterminds/squirrel"
"github.com/zitadel/zitadel/internal/database"
"github.com/zitadel/zitadel/internal/errors"
"github.com/zitadel/zitadel/internal/query/projection"
"github.com/zitadel/zitadel/internal/repository/quota"
"github.com/zitadel/zitadel/internal/telemetry/tracing"
)
var (
quotasTable = table{
name: projection.QuotasProjectionTable,
instanceIDCol: projection.QuotaColumnInstanceID,
}
QuotaColumnID = Column{
name: projection.QuotaColumnID,
table: quotasTable,
}
QuotaColumnInstanceID = Column{
name: projection.QuotaColumnInstanceID,
table: quotasTable,
}
QuotaColumnUnit = Column{
name: projection.QuotaColumnUnit,
table: quotasTable,
}
QuotaColumnAmount = Column{
name: projection.QuotaColumnAmount,
table: quotasTable,
}
QuotaColumnLimit = Column{
name: projection.QuotaColumnLimit,
table: quotasTable,
}
QuotaColumnInterval = Column{
name: projection.QuotaColumnInterval,
table: quotasTable,
}
QuotaColumnFrom = Column{
name: projection.QuotaColumnFrom,
table: quotasTable,
}
)
type Quota struct {
ID string
From time.Time
ResetInterval time.Duration
Amount uint64
Limit bool
CurrentPeriodStart time.Time
}
func (q *Queries) GetQuota(ctx context.Context, instanceID string, unit quota.Unit) (qu *Quota, err error) {
ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }()
query, scan := prepareQuotaQuery(ctx, q.client)
stmt, args, err := query.Where(
sq.Eq{
QuotaColumnInstanceID.identifier(): instanceID,
QuotaColumnUnit.identifier(): unit,
},
).ToSql()
if err != nil {
return nil, errors.ThrowInternal(err, "QUERY-XmYn9", "Errors.Query.SQLStatement")
}
err = q.client.QueryRowContext(ctx, func(row *sql.Row) error {
qu, err = scan(row)
return err
}, stmt, args...)
return qu, err
}
func prepareQuotaQuery(ctx context.Context, db prepareDatabase) (sq.SelectBuilder, func(*sql.Row) (*Quota, error)) {
return sq.
Select(
QuotaColumnID.identifier(),
QuotaColumnFrom.identifier(),
QuotaColumnInterval.identifier(),
QuotaColumnAmount.identifier(),
QuotaColumnLimit.identifier(),
"now()",
).
From(quotasTable.identifier()).
PlaceholderFormat(sq.Dollar), func(row *sql.Row) (*Quota, error) {
q := new(Quota)
var interval database.Duration
var now time.Time
err := row.Scan(&q.ID, &q.From, &interval, &q.Amount, &q.Limit, &now)
if err != nil {
if errs.Is(err, sql.ErrNoRows) {
return nil, errors.ThrowNotFound(err, "QUERY-rDTM6", "Errors.Quota.NotExisting")
}
return nil, errors.ThrowInternal(err, "QUERY-LqySK", "Errors.Internal")
}
q.ResetInterval = time.Duration(interval)
q.CurrentPeriodStart = pushPeriodStart(q.From, q.ResetInterval, now)
return q, nil
}
}
func pushPeriodStart(from time.Time, interval time.Duration, now time.Time) time.Time {
if now.IsZero() {
now = time.Now()
}
for {
next := from.Add(interval)
if next.After(now) {
return from
}
from = next
}
}

View File

@@ -1,55 +0,0 @@
package query
import (
"github.com/zitadel/zitadel/internal/eventstore"
"github.com/zitadel/zitadel/internal/repository/quota"
)
type quotaReadModel struct {
eventstore.ReadModel
unit quota.Unit
active bool
config *quota.AddedEvent
}
// newQuotaReadModel aggregateId is filled by reducing unit matching events
func newQuotaReadModel(instanceId, resourceOwner string, unit quota.Unit) *quotaReadModel {
return &quotaReadModel{
ReadModel: eventstore.ReadModel{
InstanceID: instanceId,
ResourceOwner: resourceOwner,
},
unit: unit,
}
}
func (rm *quotaReadModel) Query() *eventstore.SearchQueryBuilder {
query := eventstore.NewSearchQueryBuilder(eventstore.ColumnsEvent).
ResourceOwner(rm.ResourceOwner).
AllowTimeTravel().
AddQuery().
InstanceID(rm.InstanceID).
AggregateTypes(quota.AggregateType).
EventTypes(
quota.AddedEventType,
quota.RemovedEventType,
).EventData(map[string]interface{}{"unit": rm.unit})
return query.Builder()
}
func (rm *quotaReadModel) Reduce() error {
for _, event := range rm.Events {
switch e := event.(type) {
case *quota.AddedEvent:
rm.AggregateID = e.Aggregate().ID
rm.active = true
rm.config = e
case *quota.RemovedEvent:
rm.AggregateID = e.Aggregate().ID
rm.active = false
rm.config = nil
}
}
return rm.ReadModel.Reduce()
}

View File

@@ -2,58 +2,180 @@ package query
import (
"context"
"database/sql"
errs "errors"
"math"
"time"
"github.com/zitadel/zitadel/internal/eventstore"
sq "github.com/Masterminds/squirrel"
"github.com/pkg/errors"
"github.com/zitadel/zitadel/internal/api/call"
zitadel_errors "github.com/zitadel/zitadel/internal/errors"
"github.com/zitadel/zitadel/internal/query/projection"
"github.com/zitadel/zitadel/internal/repository/quota"
"github.com/zitadel/zitadel/internal/telemetry/tracing"
)
func (q *Queries) GetDueQuotaNotifications(ctx context.Context, config *quota.AddedEvent, periodStart time.Time, usedAbs uint64) ([]*quota.NotificationDueEvent, error) {
if len(config.Notifications) == 0 {
var (
quotaNotificationsTable = table{
name: projection.QuotaNotificationsTable,
instanceIDCol: projection.QuotaNotificationColumnInstanceID,
}
QuotaNotificationColumnInstanceID = Column{
name: projection.QuotaNotificationColumnInstanceID,
table: quotaNotificationsTable,
}
QuotaNotificationColumnUnit = Column{
name: projection.QuotaNotificationColumnUnit,
table: quotaNotificationsTable,
}
QuotaNotificationColumnID = Column{
name: projection.QuotaNotificationColumnID,
table: quotaNotificationsTable,
}
QuotaNotificationColumnCallURL = Column{
name: projection.QuotaNotificationColumnCallURL,
table: quotaNotificationsTable,
}
QuotaNotificationColumnPercent = Column{
name: projection.QuotaNotificationColumnPercent,
table: quotaNotificationsTable,
}
QuotaNotificationColumnRepeat = Column{
name: projection.QuotaNotificationColumnRepeat,
table: quotaNotificationsTable,
}
QuotaNotificationColumnLatestDuePeriodStart = Column{
name: projection.QuotaNotificationColumnLatestDuePeriodStart,
table: quotaNotificationsTable,
}
QuotaNotificationColumnNextDueThreshold = Column{
name: projection.QuotaNotificationColumnNextDueThreshold,
table: quotaNotificationsTable,
}
)
func (q *Queries) GetDueQuotaNotifications(ctx context.Context, instanceID string, unit quota.Unit, qu *Quota, periodStart time.Time, usedAbs uint64) (dueNotifications []*quota.NotificationDueEvent, err error) {
ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }()
usedRel := uint16(math.Floor(float64(usedAbs*100) / float64(qu.Amount)))
query, scan := prepareQuotaNotificationsQuery(ctx, q.client)
stmt, args, err := query.Where(
sq.And{
sq.Eq{
QuotaNotificationColumnInstanceID.identifier(): instanceID,
QuotaNotificationColumnUnit.identifier(): unit,
},
sq.Or{
// If the relative usage is greater than the next due threshold in the current period, it's clear we can notify
sq.And{
sq.Eq{QuotaNotificationColumnLatestDuePeriodStart.identifier(): periodStart},
sq.LtOrEq{QuotaNotificationColumnNextDueThreshold.identifier(): usedRel},
},
// In case we haven't seen a due notification for this quota period, we compare against the configured percent
sq.And{
sq.Or{
sq.Expr(QuotaNotificationColumnLatestDuePeriodStart.identifier() + " IS NULL"),
sq.NotEq{QuotaNotificationColumnLatestDuePeriodStart.identifier(): periodStart},
},
sq.LtOrEq{QuotaNotificationColumnPercent.identifier(): usedRel},
},
},
},
).ToSql()
if err != nil {
return nil, zitadel_errors.ThrowInternal(err, "QUERY-XmYn9", "Errors.Query.SQLStatement")
}
var notifications *QuotaNotifications
err = q.client.QueryContext(ctx, func(rows *sql.Rows) error {
notifications, err = scan(rows)
return err
}, stmt, args...)
if errors.Is(err, sql.ErrNoRows) {
return nil, nil
}
aggregate := config.Aggregate()
wm, err := q.getQuotaNotificationsReadModel(ctx, aggregate, periodStart)
if err != nil {
return nil, err
}
usedRel := uint16(math.Floor(float64(usedAbs*100) / float64(config.Amount)))
var dueNotifications []*quota.NotificationDueEvent
for _, notification := range config.Notifications {
if notification.Percent > usedRel {
for _, notification := range notifications.Configs {
reachedThreshold := calculateThreshold(usedRel, notification.Percent)
if !notification.Repeat && notification.Percent < reachedThreshold {
continue
}
threshold := notification.Percent
if notification.Repeat {
threshold = uint16(math.Max(1, math.Floor(float64(usedRel)/float64(notification.Percent)))) * notification.Percent
}
if wm.latestDueThresholds[notification.ID] < threshold {
dueNotifications = append(
dueNotifications,
quota.NewNotificationDueEvent(
ctx,
&aggregate,
config.Unit,
notification.ID,
notification.CallURL,
periodStart,
threshold,
usedAbs,
),
)
}
dueNotifications = append(
dueNotifications,
quota.NewNotificationDueEvent(
ctx,
&quota.NewAggregate(qu.ID, instanceID).Aggregate,
unit,
notification.ID,
notification.CallURL,
periodStart,
reachedThreshold,
usedAbs,
),
)
}
return dueNotifications, nil
}
func (q *Queries) getQuotaNotificationsReadModel(ctx context.Context, aggregate eventstore.Aggregate, periodStart time.Time) (*quotaNotificationsReadModel, error) {
wm := newQuotaNotificationsReadModel(aggregate.ID, aggregate.InstanceID, aggregate.ResourceOwner, periodStart)
return wm, q.eventstore.FilterToQueryReducer(ctx, wm)
type QuotaNotification struct {
ID string
CallURL string
Percent uint16
Repeat bool
NextDueThreshold uint16
}
type QuotaNotifications struct {
SearchResponse
Configs []*QuotaNotification
}
// calculateThreshold calculates the nearest reached threshold.
// It makes sure that the percent configured on the notification is calculated within the "current" 100%,
// e.g. when configuring 80%, the thresholds are 80, 180, 280, ...
// so 170% use is always 70% of the current bucket, with the above config, the reached threshold would be 80.
func calculateThreshold(usedRel, notificationPercent uint16) uint16 {
// check how many times we reached 100%
times := math.Floor(float64(usedRel) / 100)
// check how many times we reached the percent configured with the "current" 100%
percent := math.Floor(float64(usedRel%100) / float64(notificationPercent))
// If neither is reached, directly return 0.
// This way we don't end up in some wrong uint16 range in the calculation below.
if times == 0 && percent == 0 {
return 0
}
return uint16(times+percent-1)*100 + notificationPercent
}
func prepareQuotaNotificationsQuery(ctx context.Context, db prepareDatabase) (sq.SelectBuilder, func(*sql.Rows) (*QuotaNotifications, error)) {
return sq.Select(
QuotaNotificationColumnID.identifier(),
QuotaNotificationColumnCallURL.identifier(),
QuotaNotificationColumnPercent.identifier(),
QuotaNotificationColumnRepeat.identifier(),
QuotaNotificationColumnNextDueThreshold.identifier(),
).
From(quotaNotificationsTable.identifier() + db.Timetravel(call.Took(ctx))).
PlaceholderFormat(sq.Dollar), func(rows *sql.Rows) (*QuotaNotifications, error) {
cfgs := &QuotaNotifications{Configs: []*QuotaNotification{}}
for rows.Next() {
cfg := new(QuotaNotification)
var nextDueThreshold sql.NullInt16
err := rows.Scan(&cfg.ID, &cfg.CallURL, &cfg.Percent, &cfg.Repeat, &nextDueThreshold)
if err != nil {
if errs.Is(err, sql.ErrNoRows) {
return nil, zitadel_errors.ThrowNotFound(err, "QUERY-bbqWb", "Errors.QuotaNotification.NotExisting")
}
return nil, zitadel_errors.ThrowInternal(err, "QUERY-8copS", "Errors.Internal")
}
if nextDueThreshold.Valid {
cfg.NextDueThreshold = uint16(nextDueThreshold.Int16)
}
cfgs.Configs = append(cfgs.Configs, cfg)
}
return cfgs, nil
}
}

View File

@@ -1,46 +0,0 @@
package query
import (
"time"
"github.com/zitadel/zitadel/internal/eventstore"
"github.com/zitadel/zitadel/internal/repository/quota"
)
type quotaNotificationsReadModel struct {
eventstore.ReadModel
periodStart time.Time
latestDueThresholds map[string]uint16
}
func newQuotaNotificationsReadModel(aggregateId, instanceId, resourceOwner string, periodStart time.Time) *quotaNotificationsReadModel {
return &quotaNotificationsReadModel{
ReadModel: eventstore.ReadModel{
AggregateID: aggregateId,
InstanceID: instanceId,
ResourceOwner: resourceOwner,
},
periodStart: periodStart,
latestDueThresholds: make(map[string]uint16),
}
}
func (rm *quotaNotificationsReadModel) Query() *eventstore.SearchQueryBuilder {
return eventstore.NewSearchQueryBuilder(eventstore.ColumnsEvent).
ResourceOwner(rm.ResourceOwner).
AllowTimeTravel().
AddQuery().
InstanceID(rm.InstanceID).
AggregateTypes(quota.AggregateType).
AggregateIDs(rm.AggregateID).
CreationDateAfter(rm.periodStart).
EventTypes(quota.NotificationDueEventType).Builder()
}
func (rm *quotaNotificationsReadModel) Reduce() error {
for _, event := range rm.Events {
e := event.(*quota.NotificationDueEvent)
rm.latestDueThresholds[e.ID] = e.Threshold
}
return rm.ReadModel.Reduce()
}

View File

@@ -0,0 +1,181 @@
package query
import (
"database/sql"
"database/sql/driver"
"errors"
"fmt"
"regexp"
"testing"
"github.com/stretchr/testify/assert"
)
func Test_calculateThreshold(t *testing.T) {
type args struct {
usedRel uint16
notificationPercent uint16
}
tests := []struct {
name string
args args
want uint16
}{
{
name: "80 - below configuration",
args: args{
usedRel: 70,
notificationPercent: 80,
},
want: 0,
},
{
name: "80 - below 100 percent use",
args: args{
usedRel: 90,
notificationPercent: 80,
},
want: 80,
},
{
name: "80 - above 100 percent use",
args: args{
usedRel: 120,
notificationPercent: 80,
},
want: 80,
},
{
name: "80 - more than twice the use",
args: args{
usedRel: 190,
notificationPercent: 80,
},
want: 180,
},
{
name: "100 - below 100 percent use",
args: args{
usedRel: 90,
notificationPercent: 100,
},
want: 0,
},
{
name: "100 - above 100 percent use",
args: args{
usedRel: 120,
notificationPercent: 100,
},
want: 100,
},
{
name: "100 - more than twice the use",
args: args{
usedRel: 210,
notificationPercent: 100,
},
want: 200,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := calculateThreshold(tt.args.usedRel, tt.args.notificationPercent)
assert.Equal(t, int(tt.want), int(got))
})
}
}
var (
expectedQuotaNotificationsQuery = regexp.QuoteMeta(`SELECT projections.quotas_notifications.id,` +
` projections.quotas_notifications.call_url,` +
` projections.quotas_notifications.percent,` +
` projections.quotas_notifications.repeat,` +
` projections.quotas_notifications.next_due_threshold` +
` FROM projections.quotas_notifications` +
` AS OF SYSTEM TIME '-1 ms'`)
quotaNotificationsCols = []string{
"id",
"call_url",
"percent",
"repeat",
"next_due_threshold",
}
)
func Test_prepareQuotaNotificationsQuery(t *testing.T) {
type want struct {
sqlExpectations sqlExpectation
err checkErr
}
tests := []struct {
name string
prepare interface{}
want want
object interface{}
}{
{
name: "prepareQuotaNotificationsQuery no result",
prepare: prepareQuotaNotificationsQuery,
want: want{
sqlExpectations: mockQueries(
expectedQuotaNotificationsQuery,
nil,
nil,
),
},
object: &QuotaNotifications{Configs: []*QuotaNotification{}},
},
{
name: "prepareQuotaNotificationsQuery",
prepare: prepareQuotaNotificationsQuery,
want: want{
sqlExpectations: mockQuery(
expectedQuotaNotificationsQuery,
quotaNotificationsCols,
[]driver.Value{
"quota-id",
"url",
uint16(100),
true,
uint16(100),
},
),
},
object: &QuotaNotifications{
Configs: []*QuotaNotification{
{
ID: "quota-id",
CallURL: "url",
Percent: 100,
Repeat: true,
NextDueThreshold: 100,
},
},
},
},
{
name: "prepareQuotaNotificationsQuery sql err",
prepare: prepareQuotaNotificationsQuery,
want: want{
sqlExpectations: mockQueryErr(
expectedQuotaNotificationsQuery,
sql.ErrConnDone,
),
err: func(err error) (error, bool) {
if !errors.Is(err, sql.ErrConnDone) {
return fmt.Errorf("err should be sql.ErrConnDone got: %w", err), false
}
return nil, true
},
},
object: (*Quota)(nil),
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
assertPrepare(t, tt.prepare, tt.object, tt.want.sqlExpectations, tt.want.err, defaultPrepareArgs...)
})
}
}

View File

@@ -1,30 +0,0 @@
package query
import (
"context"
"time"
"github.com/zitadel/zitadel/internal/repository/quota"
)
func (q *Queries) GetCurrentQuotaPeriod(ctx context.Context, instanceID string, unit quota.Unit) (*quota.AddedEvent, time.Time, error) {
rm, err := q.getQuotaReadModel(ctx, instanceID, instanceID, unit)
if err != nil || !rm.active {
return nil, time.Time{}, err
}
return rm.config, pushPeriodStart(rm.config.From, rm.config.ResetInterval, time.Now()), nil
}
func pushPeriodStart(from time.Time, interval time.Duration, now time.Time) time.Time {
next := from.Add(interval)
if next.After(now) {
return from
}
return pushPeriodStart(next, interval, now)
}
func (q *Queries) getQuotaReadModel(ctx context.Context, instanceId, resourceOwner string, unit quota.Unit) (*quotaReadModel, error) {
rm := newQuotaReadModel(instanceId, resourceOwner, unit)
return rm, q.eventstore.FilterToQueryReducer(ctx, rm)
}

View File

@@ -0,0 +1,86 @@
package query
import (
"context"
"database/sql"
"errors"
sq "github.com/Masterminds/squirrel"
"github.com/zitadel/zitadel/internal/api/call"
zitadel_errors "github.com/zitadel/zitadel/internal/errors"
"github.com/zitadel/zitadel/internal/query/projection"
"github.com/zitadel/zitadel/internal/repository/quota"
"github.com/zitadel/zitadel/internal/telemetry/tracing"
)
var (
quotaPeriodsTable = table{
name: projection.QuotaPeriodsProjectionTable,
instanceIDCol: projection.QuotaColumnInstanceID,
}
QuotaPeriodColumnInstanceID = Column{
name: projection.QuotaPeriodColumnInstanceID,
table: quotaPeriodsTable,
}
QuotaPeriodColumnUnit = Column{
name: projection.QuotaPeriodColumnUnit,
table: quotaPeriodsTable,
}
QuotaPeriodColumnStart = Column{
name: projection.QuotaPeriodColumnStart,
table: quotaPeriodsTable,
}
QuotaPeriodColumnUsage = Column{
name: projection.QuotaPeriodColumnUsage,
table: quotaPeriodsTable,
}
)
func (q *Queries) GetRemainingQuotaUsage(ctx context.Context, instanceID string, unit quota.Unit) (remaining *uint64, err error) {
ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }()
stmt, scan := prepareRemainingQuotaUsageQuery(ctx, q.client)
query, args, err := stmt.Where(
sq.And{
sq.Eq{
QuotaPeriodColumnInstanceID.identifier(): instanceID,
QuotaPeriodColumnUnit.identifier(): unit,
QuotaColumnLimit.identifier(): true,
},
sq.Expr("age(" + QuotaPeriodColumnStart.identifier() + ") < " + QuotaColumnInterval.identifier()),
sq.Expr(QuotaPeriodColumnStart.identifier() + " < now()"),
}).
ToSql()
if err != nil {
return nil, zitadel_errors.ThrowInternal(err, "QUERY-FSA3g", "Errors.Query.SQLStatement")
}
err = q.client.QueryRowContext(ctx, func(row *sql.Row) error {
remaining, err = scan(row)
return err
}, query, args...)
if zitadel_errors.IsNotFound(err) {
return nil, nil
}
return remaining, err
}
func prepareRemainingQuotaUsageQuery(ctx context.Context, db prepareDatabase) (sq.SelectBuilder, func(*sql.Row) (*uint64, error)) {
return sq.
Select(
"greatest(0, " + QuotaColumnAmount.identifier() + "-" + QuotaPeriodColumnUsage.identifier() + ")",
).
From(quotaPeriodsTable.identifier()).
Join(join(QuotaColumnUnit, QuotaPeriodColumnUnit) + db.Timetravel(call.Took(ctx))).
PlaceholderFormat(sq.Dollar), func(row *sql.Row) (*uint64, error) {
usage := new(uint64)
err := row.Scan(usage)
if err != nil {
if errors.Is(err, sql.ErrNoRows) {
return nil, zitadel_errors.ThrowNotFound(err, "QUERY-quiowi2", "Errors.Internal")
}
return nil, zitadel_errors.ThrowInternal(err, "QUERY-81j1jn2", "Errors.Internal")
}
return usage, nil
}
}

View File

@@ -0,0 +1,95 @@
package query
import (
"database/sql"
"database/sql/driver"
"errors"
"fmt"
"regexp"
"testing"
errs "github.com/zitadel/zitadel/internal/errors"
)
var (
expectedRemainingQuotaUsageQuery = regexp.QuoteMeta(`SELECT greatest(0, projections.quotas.amount-projections.quotas_periods.usage)` +
` FROM projections.quotas_periods` +
` JOIN projections.quotas ON projections.quotas_periods.unit = projections.quotas.unit AND projections.quotas_periods.instance_id = projections.quotas.instance_id` +
` AS OF SYSTEM TIME '-1 ms'`)
remainingQuotaUsageCols = []string{
"usage",
}
)
func Test_prepareRemainingQuotaUsageQuery(t *testing.T) {
type want struct {
sqlExpectations sqlExpectation
err checkErr
}
tests := []struct {
name string
prepare interface{}
want want
object interface{}
}{
{
name: "prepareRemainingQuotaUsageQuery no result",
prepare: prepareRemainingQuotaUsageQuery,
want: want{
sqlExpectations: mockQueryScanErr(
expectedRemainingQuotaUsageQuery,
nil,
nil,
),
err: func(err error) (error, bool) {
if !errs.IsNotFound(err) {
return fmt.Errorf("err should be zitadel.NotFoundError got: %w", err), false
}
return nil, true
},
},
object: (*uint64)(nil),
},
{
name: "prepareRemainingQuotaUsageQuery",
prepare: prepareRemainingQuotaUsageQuery,
want: want{
sqlExpectations: mockQuery(
expectedRemainingQuotaUsageQuery,
remainingQuotaUsageCols,
[]driver.Value{
uint64(100),
},
),
},
object: uint64P(100),
},
{
name: "prepareRemainingQuotaUsageQuery sql err",
prepare: prepareRemainingQuotaUsageQuery,
want: want{
sqlExpectations: mockQueryErr(
expectedRemainingQuotaUsageQuery,
sql.ErrConnDone,
),
err: func(err error) (error, bool) {
if !errors.Is(err, sql.ErrConnDone) {
return fmt.Errorf("err should be sql.ErrConnDone got: %w", err), false
}
return nil, true
},
},
object: (*uint64)(nil),
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
assertPrepare(t, tt.prepare, tt.object, tt.want.sqlExpectations, tt.want.err, defaultPrepareArgs...)
})
}
}
func uint64P(i int) *uint64 {
u := uint64(i)
return &u
}

View File

@@ -0,0 +1,127 @@
package query
import (
"database/sql"
"database/sql/driver"
"errors"
"fmt"
"regexp"
"testing"
"time"
"github.com/jackc/pgtype"
errs "github.com/zitadel/zitadel/internal/errors"
)
var (
expectedQuotaQuery = regexp.QuoteMeta(`SELECT projections.quotas.id,` +
` projections.quotas.from_anchor,` +
` projections.quotas.interval,` +
` projections.quotas.amount,` +
` projections.quotas.limit_usage,` +
` now()` +
` FROM projections.quotas`)
quotaCols = []string{
"id",
"from_anchor",
"interval",
"amount",
"limit_usage",
"now",
}
)
func dayNow() time.Time {
return time.Now().Truncate(24 * time.Hour)
}
func interval(t *testing.T, src time.Duration) pgtype.Interval {
interval := pgtype.Interval{}
err := interval.Set(src)
if err != nil {
t.Fatal(err)
}
return interval
}
func Test_QuotaPrepare(t *testing.T) {
type want struct {
sqlExpectations sqlExpectation
err checkErr
}
tests := []struct {
name string
prepare interface{}
want want
object interface{}
}{
{
name: "prepareQuotaQuery no result",
prepare: prepareQuotaQuery,
want: want{
sqlExpectations: mockQueriesScanErr(
expectedQuotaQuery,
nil,
nil,
),
err: func(err error) (error, bool) {
if !errs.IsNotFound(err) {
return fmt.Errorf("err should be zitadel.NotFoundError got: %w", err), false
}
return nil, true
},
},
object: (*Quota)(nil),
},
{
name: "prepareQuotaQuery",
prepare: prepareQuotaQuery,
want: want{
sqlExpectations: mockQuery(
expectedQuotaQuery,
quotaCols,
[]driver.Value{
"quota-id",
dayNow(),
interval(t, time.Hour*24),
uint64(1000),
true,
testNow,
},
),
},
object: &Quota{
ID: "quota-id",
From: dayNow(),
ResetInterval: time.Hour * 24,
CurrentPeriodStart: dayNow(),
Amount: 1000,
Limit: true,
},
},
{
name: "prepareQuotaQuery sql err",
prepare: prepareQuotaQuery,
want: want{
sqlExpectations: mockQueryErr(
expectedQuotaQuery,
sql.ErrConnDone,
),
err: func(err error) (error, bool) {
if !errors.Is(err, sql.ErrConnDone) {
return fmt.Errorf("err should be sql.ErrConnDone got: %w", err), false
}
return nil, true
},
},
object: (*Quota)(nil),
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
assertPrepare(t, tt.prepare, tt.object, tt.want.sqlExpectations, tt.want.err, defaultPrepareArgs...)
})
}
}