perf(milestones): refactor (#8788)
Some checks are pending
ZITADEL CI/CD / core (push) Waiting to run
ZITADEL CI/CD / console (push) Waiting to run
ZITADEL CI/CD / version (push) Waiting to run
ZITADEL CI/CD / compile (push) Blocked by required conditions
ZITADEL CI/CD / core-unit-test (push) Blocked by required conditions
ZITADEL CI/CD / core-integration-test (push) Blocked by required conditions
ZITADEL CI/CD / lint (push) Blocked by required conditions
ZITADEL CI/CD / container (push) Blocked by required conditions
ZITADEL CI/CD / e2e (push) Blocked by required conditions
ZITADEL CI/CD / release (push) Blocked by required conditions
Code Scanning / CodeQL-Build (go) (push) Waiting to run
Code Scanning / CodeQL-Build (javascript) (push) Waiting to run

# Which Problems Are Solved

Milestones used existing events from a number of aggregates. OIDC
session is one of them. We noticed in load-tests that the reduction of
the oidc_session.added event into the milestone projection is a costly
business with payload based conditionals. A milestone is reached once,
but even then we remain subscribed to the OIDC events. This requires the
projections.current_states to be updated continuously.


# How the Problems Are Solved

The milestone creation is refactored to use dedicated events instead.
The command side decides when a milestone is reached and creates the
reached event once for each milestone when required.

# Additional Changes

In order to prevent reached milestones being created twice, a migration
script is provided. When the old `projections.milestones` table exist,
the state is read from there and `v2` milestone aggregate events are
created, with the original reached and pushed dates.

# Additional Context

- Closes https://github.com/zitadel/zitadel/issues/8800
This commit is contained in:
Tim Möhlmann
2024-10-28 09:29:34 +01:00
committed by GitHub
parent 54f1c0bc50
commit 32bad3feb3
46 changed files with 1612 additions and 756 deletions

View File

@@ -163,6 +163,7 @@ func projections(
} }
commands, err := command.StartCommands( commands, err := command.StartCommands(
es, es,
config.Caches,
config.SystemDefaults, config.SystemDefaults,
config.InternalAuthZ.RolePermissionMappings, config.InternalAuthZ.RolePermissionMappings,
staticStorage, staticStorage,

View File

@@ -65,6 +65,7 @@ func (mig *FirstInstance) Execute(ctx context.Context, _ eventstore.Event) error
} }
cmd, err := command.StartCommands(mig.es, cmd, err := command.StartCommands(mig.es,
nil,
mig.defaults, mig.defaults,
mig.zitadelRoles, mig.zitadelRoles,
nil, nil,

118
cmd/setup/36.go Normal file
View File

@@ -0,0 +1,118 @@
package setup
import (
"context"
_ "embed"
"errors"
"fmt"
"slices"
"time"
"github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgconn"
"github.com/zitadel/logging"
"github.com/zitadel/zitadel/internal/command"
"github.com/zitadel/zitadel/internal/database"
"github.com/zitadel/zitadel/internal/eventstore"
"github.com/zitadel/zitadel/internal/repository/milestone"
)
var (
//go:embed 36.sql
getProjectedMilestones string
)
type FillV2Milestones struct {
dbClient *database.DB
eventstore *eventstore.Eventstore
}
type instanceMilestone struct {
Type milestone.Type
Reached time.Time
Pushed *time.Time
}
func (mig *FillV2Milestones) Execute(ctx context.Context, _ eventstore.Event) error {
im, err := mig.getProjectedMilestones(ctx)
if err != nil {
return err
}
return mig.pushEventsByInstance(ctx, im)
}
func (mig *FillV2Milestones) getProjectedMilestones(ctx context.Context) (map[string][]instanceMilestone, error) {
type row struct {
InstanceID string
Type milestone.Type
Reached time.Time
Pushed *time.Time
}
rows, _ := mig.dbClient.Pool.Query(ctx, getProjectedMilestones)
scanned, err := pgx.CollectRows(rows, pgx.RowToStructByPos[row])
var pgError *pgconn.PgError
// catch ERROR: relation "projections.milestones" does not exist
if errors.As(err, &pgError) && pgError.SQLState() == "42P01" {
return nil, nil
}
if err != nil {
return nil, fmt.Errorf("milestones get: %w", err)
}
milestoneMap := make(map[string][]instanceMilestone)
for _, s := range scanned {
milestoneMap[s.InstanceID] = append(milestoneMap[s.InstanceID], instanceMilestone{
Type: s.Type,
Reached: s.Reached,
Pushed: s.Pushed,
})
}
return milestoneMap, nil
}
// pushEventsByInstance creates the v2 milestone events by instance.
// This prevents we will try to push 6*N(instance) events in one push.
func (mig *FillV2Milestones) pushEventsByInstance(ctx context.Context, milestoneMap map[string][]instanceMilestone) error {
// keep a deterministic order by instance ID.
order := make([]string, 0, len(milestoneMap))
for k := range milestoneMap {
order = append(order, k)
}
slices.Sort(order)
for _, instanceID := range order {
logging.WithFields("instance_id", instanceID, "migration", mig.String()).Info("filter existing milestone events")
// because each Push runs in a separate TX, we need to make sure that events
// from a partially executed migration are pushed again.
model := command.NewMilestonesReachedWriteModel(instanceID)
if err := mig.eventstore.FilterToQueryReducer(ctx, model); err != nil {
return fmt.Errorf("milestones filter: %w", err)
}
if model.InstanceCreated {
logging.WithFields("instance_id", instanceID, "migration", mig.String()).Info("milestone events already migrated")
continue // This instance was migrated, skip
}
logging.WithFields("instance_id", instanceID, "migration", mig.String()).Info("push milestone events")
aggregate := milestone.NewInstanceAggregate(instanceID)
cmds := make([]eventstore.Command, 0, len(milestoneMap[instanceID])*2)
for _, m := range milestoneMap[instanceID] {
cmds = append(cmds, milestone.NewReachedEventWithDate(ctx, aggregate, m.Type, &m.Reached))
if m.Pushed != nil {
cmds = append(cmds, milestone.NewPushedEventWithDate(ctx, aggregate, m.Type, nil, "", m.Pushed))
}
}
if _, err := mig.eventstore.Push(ctx, cmds...); err != nil {
return fmt.Errorf("milestones push: %w", err)
}
}
return nil
}
func (mig *FillV2Milestones) String() string {
return "36_fill_v2_milestones"
}

4
cmd/setup/36.sql Normal file
View File

@@ -0,0 +1,4 @@
SELECT instance_id, type, reached_date, last_pushed_date
FROM projections.milestones
WHERE reached_date IS NOT NULL
ORDER BY instance_id, reached_date;

View File

@@ -122,6 +122,7 @@ type Steps struct {
s33SMSConfigs3TwilioAddVerifyServiceSid *SMSConfigs3TwilioAddVerifyServiceSid s33SMSConfigs3TwilioAddVerifyServiceSid *SMSConfigs3TwilioAddVerifyServiceSid
s34AddCacheSchema *AddCacheSchema s34AddCacheSchema *AddCacheSchema
s35AddPositionToIndexEsWm *AddPositionToIndexEsWm s35AddPositionToIndexEsWm *AddPositionToIndexEsWm
s36FillV2Milestones *FillV2Milestones
} }
func MustNewSteps(v *viper.Viper) *Steps { func MustNewSteps(v *viper.Viper) *Steps {

View File

@@ -33,6 +33,7 @@ func (mig *externalConfigChange) Check(lastRun map[string]interface{}) bool {
func (mig *externalConfigChange) Execute(ctx context.Context, _ eventstore.Event) error { func (mig *externalConfigChange) Execute(ctx context.Context, _ eventstore.Event) error {
cmd, err := command.StartCommands( cmd, err := command.StartCommands(
mig.es, mig.es,
nil,
mig.defaults, mig.defaults,
nil, nil,
nil, nil,

View File

@@ -165,6 +165,7 @@ func Setup(ctx context.Context, config *Config, steps *Steps, masterKey string)
steps.s33SMSConfigs3TwilioAddVerifyServiceSid = &SMSConfigs3TwilioAddVerifyServiceSid{dbClient: esPusherDBClient} steps.s33SMSConfigs3TwilioAddVerifyServiceSid = &SMSConfigs3TwilioAddVerifyServiceSid{dbClient: esPusherDBClient}
steps.s34AddCacheSchema = &AddCacheSchema{dbClient: queryDBClient} steps.s34AddCacheSchema = &AddCacheSchema{dbClient: queryDBClient}
steps.s35AddPositionToIndexEsWm = &AddPositionToIndexEsWm{dbClient: esPusherDBClient} steps.s35AddPositionToIndexEsWm = &AddPositionToIndexEsWm{dbClient: esPusherDBClient}
steps.s36FillV2Milestones = &FillV2Milestones{dbClient: queryDBClient, eventstore: eventstoreClient}
err = projection.Create(ctx, projectionDBClient, eventstoreClient, config.Projections, nil, nil, nil) err = projection.Create(ctx, projectionDBClient, eventstoreClient, config.Projections, nil, nil, nil)
logging.OnError(err).Fatal("unable to start projections") logging.OnError(err).Fatal("unable to start projections")
@@ -209,6 +210,7 @@ func Setup(ctx context.Context, config *Config, steps *Steps, masterKey string)
steps.s30FillFieldsForOrgDomainVerified, steps.s30FillFieldsForOrgDomainVerified,
steps.s34AddCacheSchema, steps.s34AddCacheSchema,
steps.s35AddPositionToIndexEsWm, steps.s35AddPositionToIndexEsWm,
steps.s36FillV2Milestones,
} { } {
mustExecuteMigration(ctx, eventstoreClient, step, "migration failed") mustExecuteMigration(ctx, eventstoreClient, step, "migration failed")
} }
@@ -390,6 +392,7 @@ func initProjections(
} }
commands, err := command.StartCommands( commands, err := command.StartCommands(
eventstoreClient, eventstoreClient,
config.Caches,
config.SystemDefaults, config.SystemDefaults,
config.InternalAuthZ.RolePermissionMappings, config.InternalAuthZ.RolePermissionMappings,
staticStorage, staticStorage,

View File

@@ -224,6 +224,7 @@ func startZitadel(ctx context.Context, config *Config, masterKey string, server
} }
commands, err := command.StartCommands( commands, err := command.StartCommands(
eventstoreClient, eventstoreClient,
config.Caches,
config.SystemDefaults, config.SystemDefaults,
config.InternalAuthZ.RolePermissionMappings, config.InternalAuthZ.RolePermissionMappings,
storage, storage,

View File

@@ -114,7 +114,15 @@ func WithConsole(ctx context.Context, projectID, appID string) context.Context {
i.projectID = projectID i.projectID = projectID
i.appID = appID i.appID = appID
//i.clientID = clientID return context.WithValue(ctx, instanceKey, i)
}
func WithConsoleClientID(ctx context.Context, clientID string) context.Context {
i, ok := ctx.Value(instanceKey).(*instance)
if !ok {
i = new(instance)
}
i.clientID = clientID
return context.WithValue(ctx, instanceKey, i) return context.WithValue(ctx, instanceKey, i)
} }

View File

@@ -6,6 +6,7 @@ import (
"time" "time"
"github.com/zitadel/logging" "github.com/zitadel/logging"
"github.com/zitadel/zitadel/internal/database/postgres" "github.com/zitadel/zitadel/internal/database/postgres"
) )
@@ -77,7 +78,8 @@ type CachesConfig struct {
Postgres PostgresConnectorConfig Postgres PostgresConnectorConfig
// Redis redis.Config? // Redis redis.Config?
} }
Instance *CacheConfig Instance *CacheConfig
Milestones *CacheConfig
} }
type CacheConfig struct { type CacheConfig struct {

82
internal/command/cache.go Normal file
View File

@@ -0,0 +1,82 @@
package command
import (
"context"
"fmt"
"strings"
"github.com/zitadel/zitadel/internal/cache"
"github.com/zitadel/zitadel/internal/cache/gomap"
"github.com/zitadel/zitadel/internal/cache/noop"
"github.com/zitadel/zitadel/internal/cache/pg"
"github.com/zitadel/zitadel/internal/database"
)
type Caches struct {
connectors *cacheConnectors
milestones cache.Cache[milestoneIndex, string, *MilestonesReached]
}
func startCaches(background context.Context, conf *cache.CachesConfig, client *database.DB) (_ *Caches, err error) {
caches := &Caches{
milestones: noop.NewCache[milestoneIndex, string, *MilestonesReached](),
}
if conf == nil {
return caches, nil
}
caches.connectors, err = startCacheConnectors(background, conf, client)
if err != nil {
return nil, err
}
caches.milestones, err = startCache[milestoneIndex, string, *MilestonesReached](background, []milestoneIndex{milestoneIndexInstanceID}, "milestones", conf.Instance, caches.connectors)
if err != nil {
return nil, err
}
return caches, nil
}
type cacheConnectors struct {
memory *cache.AutoPruneConfig
postgres *pgxPoolCacheConnector
}
type pgxPoolCacheConnector struct {
*cache.AutoPruneConfig
client *database.DB
}
func startCacheConnectors(_ context.Context, conf *cache.CachesConfig, client *database.DB) (_ *cacheConnectors, err error) {
connectors := new(cacheConnectors)
if conf.Connectors.Memory.Enabled {
connectors.memory = &conf.Connectors.Memory.AutoPrune
}
if conf.Connectors.Postgres.Enabled {
connectors.postgres = &pgxPoolCacheConnector{
AutoPruneConfig: &conf.Connectors.Postgres.AutoPrune,
client: client,
}
}
return connectors, nil
}
func startCache[I ~int, K ~string, V cache.Entry[I, K]](background context.Context, indices []I, name string, conf *cache.CacheConfig, connectors *cacheConnectors) (cache.Cache[I, K, V], error) {
if conf == nil || conf.Connector == "" {
return noop.NewCache[I, K, V](), nil
}
if strings.EqualFold(conf.Connector, "memory") && connectors.memory != nil {
c := gomap.NewCache[I, K, V](background, indices, *conf)
connectors.memory.StartAutoPrune(background, c, name)
return c, nil
}
if strings.EqualFold(conf.Connector, "postgres") && connectors.postgres != nil {
client := connectors.postgres.client
c, err := pg.NewCache[I, K, V](background, name, *conf, indices, client.Pool, client.Type())
if err != nil {
return nil, fmt.Errorf("query start cache: %w", err)
}
connectors.postgres.StartAutoPrune(background, c, name)
return c, nil
}
return nil, fmt.Errorf("cache connector %q not enabled", conf.Connector)
}

View File

@@ -18,6 +18,7 @@ import (
"github.com/zitadel/zitadel/internal/api/authz" "github.com/zitadel/zitadel/internal/api/authz"
api_http "github.com/zitadel/zitadel/internal/api/http" api_http "github.com/zitadel/zitadel/internal/api/http"
"github.com/zitadel/zitadel/internal/cache"
"github.com/zitadel/zitadel/internal/command/preparation" "github.com/zitadel/zitadel/internal/command/preparation"
sd "github.com/zitadel/zitadel/internal/config/systemdefaults" sd "github.com/zitadel/zitadel/internal/config/systemdefaults"
"github.com/zitadel/zitadel/internal/crypto" "github.com/zitadel/zitadel/internal/crypto"
@@ -88,10 +89,17 @@ type Commands struct {
EventGroupExisting func(group string) bool EventGroupExisting func(group string) bool
GenerateDomain func(instanceName, domain string) (string, error) GenerateDomain func(instanceName, domain string) (string, error)
caches *Caches
// Store instance IDs where all milestones are reached (except InstanceDeleted).
// These instance's milestones never need to be invalidated,
// so the query and cache overhead can completely eliminated.
milestonesCompleted sync.Map
} }
func StartCommands( func StartCommands(
es *eventstore.Eventstore, es *eventstore.Eventstore,
cachesConfig *cache.CachesConfig,
defaults sd.SystemDefaults, defaults sd.SystemDefaults,
zitadelRoles []authz.RoleMapping, zitadelRoles []authz.RoleMapping,
staticStore static.Storage, staticStore static.Storage,
@@ -123,6 +131,10 @@ func StartCommands(
if err != nil { if err != nil {
return nil, fmt.Errorf("password hasher: %w", err) return nil, fmt.Errorf("password hasher: %w", err)
} }
caches, err := startCaches(context.TODO(), cachesConfig, es.Client())
if err != nil {
return nil, fmt.Errorf("caches: %w", err)
}
repo = &Commands{ repo = &Commands{
eventstore: es, eventstore: es,
static: staticStore, static: staticStore,
@@ -176,6 +188,7 @@ func StartCommands(
}, },
}, },
GenerateDomain: domain.NewGeneratedInstanceDomain, GenerateDomain: domain.NewGeneratedInstanceDomain,
caches: caches,
} }
if defaultSecretGenerators != nil && defaultSecretGenerators.ClientSecret != nil { if defaultSecretGenerators != nil && defaultSecretGenerators.ClientSecret != nil {

View File

@@ -6,6 +6,7 @@ import (
"golang.org/x/text/language" "golang.org/x/text/language"
"github.com/zitadel/logging"
"github.com/zitadel/zitadel/internal/api/authz" "github.com/zitadel/zitadel/internal/api/authz"
"github.com/zitadel/zitadel/internal/api/http" "github.com/zitadel/zitadel/internal/api/http"
"github.com/zitadel/zitadel/internal/command/preparation" "github.com/zitadel/zitadel/internal/command/preparation"
@@ -17,6 +18,7 @@ import (
"github.com/zitadel/zitadel/internal/notification/channels/smtp" "github.com/zitadel/zitadel/internal/notification/channels/smtp"
"github.com/zitadel/zitadel/internal/repository/instance" "github.com/zitadel/zitadel/internal/repository/instance"
"github.com/zitadel/zitadel/internal/repository/limits" "github.com/zitadel/zitadel/internal/repository/limits"
"github.com/zitadel/zitadel/internal/repository/milestone"
"github.com/zitadel/zitadel/internal/repository/org" "github.com/zitadel/zitadel/internal/repository/org"
"github.com/zitadel/zitadel/internal/repository/project" "github.com/zitadel/zitadel/internal/repository/project"
"github.com/zitadel/zitadel/internal/repository/quota" "github.com/zitadel/zitadel/internal/repository/quota"
@@ -292,7 +294,7 @@ func setUpInstance(ctx context.Context, c *Commands, setup *InstanceSetup) (vali
setupFeatures(&validations, setup.Features, setup.zitadel.instanceID) setupFeatures(&validations, setup.Features, setup.zitadel.instanceID)
setupLimits(c, &validations, limits.NewAggregate(setup.zitadel.limitsID, setup.zitadel.instanceID), setup.Limits) setupLimits(c, &validations, limits.NewAggregate(setup.zitadel.limitsID, setup.zitadel.instanceID), setup.Limits)
setupRestrictions(c, &validations, restrictions.NewAggregate(setup.zitadel.restrictionsID, setup.zitadel.instanceID, setup.zitadel.instanceID), setup.Restrictions) setupRestrictions(c, &validations, restrictions.NewAggregate(setup.zitadel.restrictionsID, setup.zitadel.instanceID, setup.zitadel.instanceID), setup.Restrictions)
setupInstanceCreatedMilestone(&validations, setup.zitadel.instanceID)
return validations, pat, machineKey, nil return validations, pat, machineKey, nil
} }
@@ -890,7 +892,8 @@ func (c *Commands) RemoveInstance(ctx context.Context, id string) (*domain.Objec
if err != nil { if err != nil {
return nil, err return nil, err
} }
err = c.caches.milestones.Invalidate(ctx, milestoneIndexInstanceID, id)
logging.OnError(err).Error("milestone invalidate")
return &domain.ObjectDetails{ return &domain.ObjectDetails{
Sequence: events[len(events)-1].Sequence(), Sequence: events[len(events)-1].Sequence(),
EventDate: events[len(events)-1].CreatedAt(), EventDate: events[len(events)-1].CreatedAt(),
@@ -908,10 +911,16 @@ func (c *Commands) prepareRemoveInstance(a *instance.Aggregate) preparation.Vali
if !writeModel.State.Exists() { if !writeModel.State.Exists() {
return nil, zerrors.ThrowNotFound(err, "COMMA-AE3GS", "Errors.Instance.NotFound") return nil, zerrors.ThrowNotFound(err, "COMMA-AE3GS", "Errors.Instance.NotFound")
} }
return []eventstore.Command{instance.NewInstanceRemovedEvent(ctx, milestoneAggregate := milestone.NewInstanceAggregate(a.ID)
&a.Aggregate, return []eventstore.Command{
writeModel.Name, instance.NewInstanceRemovedEvent(ctx,
writeModel.Domains)}, &a.Aggregate,
writeModel.Name,
writeModel.Domains),
milestone.NewReachedEvent(ctx,
milestoneAggregate,
milestone.InstanceDeleted),
},
nil nil
}, nil }, nil
} }

View File

@@ -13,6 +13,7 @@ import (
"golang.org/x/text/language" "golang.org/x/text/language"
"github.com/zitadel/zitadel/internal/api/authz" "github.com/zitadel/zitadel/internal/api/authz"
"github.com/zitadel/zitadel/internal/cache/noop"
"github.com/zitadel/zitadel/internal/command/preparation" "github.com/zitadel/zitadel/internal/command/preparation"
"github.com/zitadel/zitadel/internal/crypto" "github.com/zitadel/zitadel/internal/crypto"
"github.com/zitadel/zitadel/internal/domain" "github.com/zitadel/zitadel/internal/domain"
@@ -20,6 +21,7 @@ import (
"github.com/zitadel/zitadel/internal/id" "github.com/zitadel/zitadel/internal/id"
id_mock "github.com/zitadel/zitadel/internal/id/mock" id_mock "github.com/zitadel/zitadel/internal/id/mock"
"github.com/zitadel/zitadel/internal/repository/instance" "github.com/zitadel/zitadel/internal/repository/instance"
"github.com/zitadel/zitadel/internal/repository/milestone"
"github.com/zitadel/zitadel/internal/repository/org" "github.com/zitadel/zitadel/internal/repository/org"
"github.com/zitadel/zitadel/internal/repository/project" "github.com/zitadel/zitadel/internal/repository/project"
"github.com/zitadel/zitadel/internal/repository/user" "github.com/zitadel/zitadel/internal/repository/user"
@@ -372,6 +374,7 @@ func setupInstanceEvents(ctx context.Context, instanceID, orgID, projectID, appI
setupInstanceElementsEvents(ctx, instanceID, instanceName, defaultLanguage), setupInstanceElementsEvents(ctx, instanceID, instanceName, defaultLanguage),
orgEvents(ctx, instanceID, orgID, orgName, projectID, domain, externalSecure, true, true), orgEvents(ctx, instanceID, orgID, orgName, projectID, domain, externalSecure, true, true),
generatedDomainEvents(ctx, instanceID, orgID, projectID, appID, domain), generatedDomainEvents(ctx, instanceID, orgID, projectID, appID, domain),
instanceCreatedMilestoneEvent(ctx, instanceID),
) )
} }
@@ -401,6 +404,12 @@ func generatedDomainEvents(ctx context.Context, instanceID, orgID, projectID, ap
} }
} }
func instanceCreatedMilestoneEvent(ctx context.Context, instanceID string) []eventstore.Command {
return []eventstore.Command{
milestone.NewReachedEvent(ctx, milestone.NewInstanceAggregate(instanceID), milestone.InstanceCreated),
}
}
func generatedDomainFilters(instanceID, orgID, projectID, appID, generatedDomain string) []expect { func generatedDomainFilters(instanceID, orgID, projectID, appID, generatedDomain string) []expect {
return []expect{ return []expect{
expectFilter(), expectFilter(),
@@ -1378,7 +1387,7 @@ func TestCommandSide_UpdateInstance(t *testing.T) {
func TestCommandSide_RemoveInstance(t *testing.T) { func TestCommandSide_RemoveInstance(t *testing.T) {
type fields struct { type fields struct {
eventstore *eventstore.Eventstore eventstore func(t *testing.T) *eventstore.Eventstore
} }
type args struct { type args struct {
ctx context.Context ctx context.Context
@@ -1397,8 +1406,7 @@ func TestCommandSide_RemoveInstance(t *testing.T) {
{ {
name: "instance not existing, not found error", name: "instance not existing, not found error",
fields: fields{ fields: fields{
eventstore: eventstoreExpect( eventstore: expectEventstore(
t,
expectFilter(), expectFilter(),
), ),
}, },
@@ -1413,8 +1421,7 @@ func TestCommandSide_RemoveInstance(t *testing.T) {
{ {
name: "instance removed, not found error", name: "instance removed, not found error",
fields: fields{ fields: fields{
eventstore: eventstoreExpect( eventstore: expectEventstore(
t,
expectFilter( expectFilter(
eventFromEventPusher( eventFromEventPusher(
instance.NewInstanceAddedEvent( instance.NewInstanceAddedEvent(
@@ -1444,8 +1451,7 @@ func TestCommandSide_RemoveInstance(t *testing.T) {
{ {
name: "instance remove, ok", name: "instance remove, ok",
fields: fields{ fields: fields{
eventstore: eventstoreExpect( eventstore: expectEventstore(
t,
expectFilter( expectFilter(
eventFromEventPusherWithInstanceID( eventFromEventPusherWithInstanceID(
"INSTANCE", "INSTANCE",
@@ -1480,6 +1486,10 @@ func TestCommandSide_RemoveInstance(t *testing.T) {
"custom.domain", "custom.domain",
}, },
), ),
milestone.NewReachedEvent(context.Background(),
milestone.NewInstanceAggregate("INSTANCE"),
milestone.InstanceDeleted,
),
), ),
), ),
}, },
@@ -1497,7 +1507,10 @@ func TestCommandSide_RemoveInstance(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
r := &Commands{ r := &Commands{
eventstore: tt.fields.eventstore, eventstore: tt.fields.eventstore(t),
caches: &Caches{
milestones: noop.NewCache[milestoneIndex, string, *MilestonesReached](),
},
} }
got, err := r.RemoveInstance(tt.args.ctx, tt.args.instanceID) got, err := r.RemoveInstance(tt.args.ctx, tt.args.instanceID)
if tt.res.err == nil { if tt.res.err == nil {

View File

@@ -3,20 +3,176 @@ package command
import ( import (
"context" "context"
"github.com/zitadel/logging"
"github.com/zitadel/zitadel/internal/api/authz"
"github.com/zitadel/zitadel/internal/command/preparation"
"github.com/zitadel/zitadel/internal/eventstore"
"github.com/zitadel/zitadel/internal/repository/milestone" "github.com/zitadel/zitadel/internal/repository/milestone"
) )
// MilestonePushed writes a new milestone.PushedEvent with a new milestone.Aggregate to the eventstore type milestoneIndex int
const (
milestoneIndexInstanceID milestoneIndex = iota
)
type MilestonesReached struct {
InstanceID string
InstanceCreated bool
AuthenticationSucceededOnInstance bool
ProjectCreated bool
ApplicationCreated bool
AuthenticationSucceededOnApplication bool
InstanceDeleted bool
}
// complete returns true if all milestones except InstanceDeleted are reached.
func (m *MilestonesReached) complete() bool {
return m.InstanceCreated &&
m.AuthenticationSucceededOnInstance &&
m.ProjectCreated &&
m.ApplicationCreated &&
m.AuthenticationSucceededOnApplication
}
// GetMilestonesReached finds the milestone state for the current instance.
func (c *Commands) GetMilestonesReached(ctx context.Context) (*MilestonesReached, error) {
milestones, ok := c.getCachedMilestonesReached(ctx)
if ok {
return milestones, nil
}
model := NewMilestonesReachedWriteModel(authz.GetInstance(ctx).InstanceID())
if err := c.eventstore.FilterToQueryReducer(ctx, model); err != nil {
return nil, err
}
milestones = &model.MilestonesReached
c.setCachedMilestonesReached(ctx, milestones)
return milestones, nil
}
// getCachedMilestonesReached checks for milestone completeness on an instance and returns a filled
// [MilestonesReached] object.
// Otherwise it looks for the object in the milestone cache.
func (c *Commands) getCachedMilestonesReached(ctx context.Context) (*MilestonesReached, bool) {
instanceID := authz.GetInstance(ctx).InstanceID()
if _, ok := c.milestonesCompleted.Load(instanceID); ok {
return &MilestonesReached{
InstanceID: instanceID,
InstanceCreated: true,
AuthenticationSucceededOnInstance: true,
ProjectCreated: true,
ApplicationCreated: true,
AuthenticationSucceededOnApplication: true,
InstanceDeleted: false,
}, ok
}
return c.caches.milestones.Get(ctx, milestoneIndexInstanceID, instanceID)
}
// setCachedMilestonesReached stores the current milestones state in the milestones cache.
// If the milestones are complete, the instance ID is stored in milestonesCompleted instead.
func (c *Commands) setCachedMilestonesReached(ctx context.Context, milestones *MilestonesReached) {
if milestones.complete() {
c.milestonesCompleted.Store(milestones.InstanceID, struct{}{})
return
}
c.caches.milestones.Set(ctx, milestones)
}
// Keys implements cache.Entry
func (c *MilestonesReached) Keys(i milestoneIndex) []string {
if i == milestoneIndexInstanceID {
return []string{c.InstanceID}
}
return nil
}
// MilestonePushed writes a new milestone.PushedEvent with the milestone.Aggregate to the eventstore
func (c *Commands) MilestonePushed( func (c *Commands) MilestonePushed(
ctx context.Context, ctx context.Context,
instanceID string,
msType milestone.Type, msType milestone.Type,
endpoints []string, endpoints []string,
primaryDomain string,
) error { ) error {
id, err := c.idGenerator.Next() _, err := c.eventstore.Push(ctx, milestone.NewPushedEvent(ctx, milestone.NewInstanceAggregate(instanceID), msType, endpoints, c.externalDomain))
if err != nil {
return err
}
_, err = c.eventstore.Push(ctx, milestone.NewPushedEvent(ctx, milestone.NewAggregate(ctx, id), msType, endpoints, primaryDomain, c.externalDomain))
return err return err
} }
func setupInstanceCreatedMilestone(validations *[]preparation.Validation, instanceID string) {
*validations = append(*validations, func() (preparation.CreateCommands, error) {
return func(ctx context.Context, _ preparation.FilterToQueryReducer) ([]eventstore.Command, error) {
return []eventstore.Command{
milestone.NewReachedEvent(ctx, milestone.NewInstanceAggregate(instanceID), milestone.InstanceCreated),
}, nil
}, nil
})
}
func (s *OIDCSessionEvents) SetMilestones(ctx context.Context, clientID string, isHuman bool) (postCommit func(ctx context.Context), err error) {
postCommit = func(ctx context.Context) {}
milestones, err := s.commands.GetMilestonesReached(ctx)
if err != nil {
return postCommit, err
}
instance := authz.GetInstance(ctx)
aggregate := milestone.NewAggregate(ctx)
var invalidate bool
if !milestones.AuthenticationSucceededOnInstance {
s.events = append(s.events, milestone.NewReachedEvent(ctx, aggregate, milestone.AuthenticationSucceededOnInstance))
invalidate = true
}
if !milestones.AuthenticationSucceededOnApplication && isHuman && clientID != instance.ConsoleClientID() {
s.events = append(s.events, milestone.NewReachedEvent(ctx, aggregate, milestone.AuthenticationSucceededOnApplication))
invalidate = true
}
if invalidate {
postCommit = s.commands.invalidateMilestoneCachePostCommit(instance.InstanceID())
}
return postCommit, nil
}
func (c *Commands) projectCreatedMilestone(ctx context.Context, cmds *[]eventstore.Command) (postCommit func(ctx context.Context), err error) {
postCommit = func(ctx context.Context) {}
if isSystemUser(ctx) {
return postCommit, nil
}
milestones, err := c.GetMilestonesReached(ctx)
if err != nil {
return postCommit, err
}
if milestones.ProjectCreated {
return postCommit, nil
}
aggregate := milestone.NewAggregate(ctx)
*cmds = append(*cmds, milestone.NewReachedEvent(ctx, aggregate, milestone.ProjectCreated))
return c.invalidateMilestoneCachePostCommit(aggregate.InstanceID), nil
}
func (c *Commands) applicationCreatedMilestone(ctx context.Context, cmds *[]eventstore.Command) (postCommit func(ctx context.Context), err error) {
postCommit = func(ctx context.Context) {}
if isSystemUser(ctx) {
return postCommit, nil
}
milestones, err := c.GetMilestonesReached(ctx)
if err != nil {
return postCommit, err
}
if milestones.ApplicationCreated {
return postCommit, nil
}
aggregate := milestone.NewAggregate(ctx)
*cmds = append(*cmds, milestone.NewReachedEvent(ctx, aggregate, milestone.ApplicationCreated))
return c.invalidateMilestoneCachePostCommit(aggregate.InstanceID), nil
}
func (c *Commands) invalidateMilestoneCachePostCommit(instanceID string) func(ctx context.Context) {
return func(ctx context.Context) {
err := c.caches.milestones.Invalidate(ctx, milestoneIndexInstanceID, instanceID)
logging.WithFields("instance_id", instanceID).OnError(err).Error("failed to invalidate milestone cache")
}
}
func isSystemUser(ctx context.Context) bool {
return authz.GetCtxData(ctx).SystemMemberships != nil
}

View File

@@ -0,0 +1,58 @@
package command
import (
"github.com/zitadel/zitadel/internal/eventstore"
"github.com/zitadel/zitadel/internal/repository/milestone"
)
type MilestonesReachedWriteModel struct {
eventstore.WriteModel
MilestonesReached
}
func NewMilestonesReachedWriteModel(instanceID string) *MilestonesReachedWriteModel {
return &MilestonesReachedWriteModel{
WriteModel: eventstore.WriteModel{
AggregateID: instanceID,
InstanceID: instanceID,
},
MilestonesReached: MilestonesReached{
InstanceID: instanceID,
},
}
}
func (m *MilestonesReachedWriteModel) Query() *eventstore.SearchQueryBuilder {
return eventstore.NewSearchQueryBuilder(eventstore.ColumnsEvent).
AddQuery().
AggregateTypes(milestone.AggregateType).
AggregateIDs(m.AggregateID).
EventTypes(milestone.ReachedEventType, milestone.PushedEventType).
Builder()
}
func (m *MilestonesReachedWriteModel) Reduce() error {
for _, event := range m.Events {
if e, ok := event.(*milestone.ReachedEvent); ok {
m.reduceReachedEvent(e)
}
}
return m.WriteModel.Reduce()
}
func (m *MilestonesReachedWriteModel) reduceReachedEvent(e *milestone.ReachedEvent) {
switch e.MilestoneType {
case milestone.InstanceCreated:
m.InstanceCreated = true
case milestone.AuthenticationSucceededOnInstance:
m.AuthenticationSucceededOnInstance = true
case milestone.ProjectCreated:
m.ProjectCreated = true
case milestone.ApplicationCreated:
m.ApplicationCreated = true
case milestone.AuthenticationSucceededOnApplication:
m.AuthenticationSucceededOnApplication = true
case milestone.InstanceDeleted:
m.InstanceDeleted = true
}
}

View File

@@ -0,0 +1,629 @@
package command
import (
"context"
"io"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/zitadel/zitadel/internal/api/authz"
"github.com/zitadel/zitadel/internal/cache"
"github.com/zitadel/zitadel/internal/cache/gomap"
"github.com/zitadel/zitadel/internal/cache/noop"
"github.com/zitadel/zitadel/internal/eventstore"
"github.com/zitadel/zitadel/internal/repository/milestone"
)
func TestCommands_GetMilestonesReached(t *testing.T) {
cached := &MilestonesReached{
InstanceID: "cached-id",
InstanceCreated: true,
AuthenticationSucceededOnInstance: true,
}
ctx := authz.WithInstanceID(context.Background(), "instanceID")
aggregate := milestone.NewAggregate(ctx)
type fields struct {
eventstore func(*testing.T) *eventstore.Eventstore
}
type args struct {
ctx context.Context
}
tests := []struct {
name string
fields fields
args args
want *MilestonesReached
wantErr error
}{
{
name: "cached",
fields: fields{
eventstore: expectEventstore(),
},
args: args{
ctx: authz.WithInstanceID(context.Background(), "cached-id"),
},
want: cached,
},
{
name: "filter error",
fields: fields{
eventstore: expectEventstore(
expectFilterError(io.ErrClosedPipe),
),
},
args: args{
ctx: ctx,
},
wantErr: io.ErrClosedPipe,
},
{
name: "no events, all false",
fields: fields{
eventstore: expectEventstore(
expectFilter(),
),
},
args: args{
ctx: ctx,
},
want: &MilestonesReached{
InstanceID: "instanceID",
},
},
{
name: "instance created",
fields: fields{
eventstore: expectEventstore(
expectFilter(
eventFromEventPusher(milestone.NewReachedEvent(ctx, aggregate, milestone.InstanceCreated)),
),
),
},
args: args{
ctx: ctx,
},
want: &MilestonesReached{
InstanceID: "instanceID",
InstanceCreated: true,
},
},
{
name: "instance auth",
fields: fields{
eventstore: expectEventstore(
expectFilter(
eventFromEventPusher(milestone.NewReachedEvent(ctx, aggregate, milestone.AuthenticationSucceededOnInstance)),
),
),
},
args: args{
ctx: ctx,
},
want: &MilestonesReached{
InstanceID: "instanceID",
AuthenticationSucceededOnInstance: true,
},
},
{
name: "project created",
fields: fields{
eventstore: expectEventstore(
expectFilter(
eventFromEventPusher(milestone.NewReachedEvent(ctx, aggregate, milestone.ProjectCreated)),
),
),
},
args: args{
ctx: ctx,
},
want: &MilestonesReached{
InstanceID: "instanceID",
ProjectCreated: true,
},
},
{
name: "app created",
fields: fields{
eventstore: expectEventstore(
expectFilter(
eventFromEventPusher(milestone.NewReachedEvent(ctx, aggregate, milestone.ApplicationCreated)),
),
),
},
args: args{
ctx: ctx,
},
want: &MilestonesReached{
InstanceID: "instanceID",
ApplicationCreated: true,
},
},
{
name: "app auth",
fields: fields{
eventstore: expectEventstore(
expectFilter(
eventFromEventPusher(milestone.NewReachedEvent(ctx, aggregate, milestone.AuthenticationSucceededOnApplication)),
),
),
},
args: args{
ctx: ctx,
},
want: &MilestonesReached{
InstanceID: "instanceID",
AuthenticationSucceededOnApplication: true,
},
},
{
name: "instance deleted",
fields: fields{
eventstore: expectEventstore(
expectFilter(
eventFromEventPusher(milestone.NewReachedEvent(ctx, aggregate, milestone.InstanceDeleted)),
),
),
},
args: args{
ctx: ctx,
},
want: &MilestonesReached{
InstanceID: "instanceID",
InstanceDeleted: true,
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
cache := gomap.NewCache[milestoneIndex, string, *MilestonesReached](
context.Background(),
[]milestoneIndex{milestoneIndexInstanceID},
cache.CacheConfig{Connector: "memory"},
)
cache.Set(context.Background(), cached)
c := &Commands{
eventstore: tt.fields.eventstore(t),
caches: &Caches{
milestones: cache,
},
}
got, err := c.GetMilestonesReached(tt.args.ctx)
require.ErrorIs(t, err, tt.wantErr)
assert.Equal(t, tt.want, got)
})
}
}
func TestCommands_milestonesCompleted(t *testing.T) {
c := &Commands{
caches: &Caches{
milestones: noop.NewCache[milestoneIndex, string, *MilestonesReached](),
},
}
ctx := authz.WithInstanceID(context.Background(), "instanceID")
arg := &MilestonesReached{
InstanceID: "instanceID",
InstanceCreated: true,
AuthenticationSucceededOnInstance: true,
ProjectCreated: true,
ApplicationCreated: true,
AuthenticationSucceededOnApplication: true,
InstanceDeleted: false,
}
c.setCachedMilestonesReached(ctx, arg)
got, ok := c.getCachedMilestonesReached(ctx)
assert.True(t, ok)
assert.Equal(t, arg, got)
}
func TestCommands_MilestonePushed(t *testing.T) {
aggregate := milestone.NewInstanceAggregate("instanceID")
type fields struct {
eventstore func(*testing.T) *eventstore.Eventstore
}
type args struct {
ctx context.Context
instanceID string
msType milestone.Type
endpoints []string
}
tests := []struct {
name string
fields fields
args args
wantErr error
}{
{
name: "milestone pushed",
fields: fields{
eventstore: expectEventstore(
expectPush(
milestone.NewPushedEvent(
context.Background(),
aggregate,
milestone.ApplicationCreated,
[]string{"foo.com", "bar.com"},
"example.com",
),
),
),
},
args: args{
ctx: context.Background(),
instanceID: "instanceID",
msType: milestone.ApplicationCreated,
endpoints: []string{"foo.com", "bar.com"},
},
wantErr: nil,
},
{
name: "pusher error",
fields: fields{
eventstore: expectEventstore(
expectPushFailed(
io.ErrClosedPipe,
milestone.NewPushedEvent(
context.Background(),
aggregate,
milestone.ApplicationCreated,
[]string{"foo.com", "bar.com"},
"example.com",
),
),
),
},
args: args{
ctx: context.Background(),
instanceID: "instanceID",
msType: milestone.ApplicationCreated,
endpoints: []string{"foo.com", "bar.com"},
},
wantErr: io.ErrClosedPipe,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
c := &Commands{
eventstore: tt.fields.eventstore(t),
externalDomain: "example.com",
}
err := c.MilestonePushed(tt.args.ctx, tt.args.instanceID, tt.args.msType, tt.args.endpoints)
assert.ErrorIs(t, err, tt.wantErr)
})
}
}
func TestOIDCSessionEvents_SetMilestones(t *testing.T) {
ctx := authz.WithInstanceID(context.Background(), "instanceID")
ctx = authz.WithConsoleClientID(ctx, "console")
aggregate := milestone.NewAggregate(ctx)
type fields struct {
eventstore func(*testing.T) *eventstore.Eventstore
}
type args struct {
ctx context.Context
clientID string
isHuman bool
}
tests := []struct {
name string
fields fields
args args
wantEvents []eventstore.Command
wantErr error
}{
{
name: "get error",
fields: fields{
eventstore: expectEventstore(
expectFilterError(io.ErrClosedPipe),
),
},
args: args{
ctx: ctx,
clientID: "client",
isHuman: true,
},
wantErr: io.ErrClosedPipe,
},
{
name: "milestones already reached",
fields: fields{
eventstore: expectEventstore(
expectFilter(
eventFromEventPusher(milestone.NewReachedEvent(ctx, aggregate, milestone.AuthenticationSucceededOnInstance)),
eventFromEventPusher(milestone.NewReachedEvent(ctx, aggregate, milestone.AuthenticationSucceededOnApplication)),
),
),
},
args: args{
ctx: ctx,
clientID: "client",
isHuman: true,
},
wantErr: nil,
},
{
name: "auth on instance",
fields: fields{
eventstore: expectEventstore(
expectFilter(),
),
},
args: args{
ctx: ctx,
clientID: "console",
isHuman: true,
},
wantEvents: []eventstore.Command{
milestone.NewReachedEvent(ctx, aggregate, milestone.AuthenticationSucceededOnInstance),
},
wantErr: nil,
},
{
name: "subsequent console login, no milestone",
fields: fields{
eventstore: expectEventstore(
expectFilter(
eventFromEventPusher(milestone.NewReachedEvent(ctx, aggregate, milestone.AuthenticationSucceededOnInstance)),
),
),
},
args: args{
ctx: ctx,
clientID: "console",
isHuman: true,
},
wantErr: nil,
},
{
name: "subsequent machine login, no milestone",
fields: fields{
eventstore: expectEventstore(
expectFilter(
eventFromEventPusher(milestone.NewReachedEvent(ctx, aggregate, milestone.AuthenticationSucceededOnInstance)),
),
),
},
args: args{
ctx: ctx,
clientID: "client",
isHuman: false,
},
wantErr: nil,
},
{
name: "auth on app",
fields: fields{
eventstore: expectEventstore(
expectFilter(
eventFromEventPusher(milestone.NewReachedEvent(ctx, aggregate, milestone.AuthenticationSucceededOnInstance)),
),
),
},
args: args{
ctx: ctx,
clientID: "client",
isHuman: true,
},
wantEvents: []eventstore.Command{
milestone.NewReachedEvent(ctx, aggregate, milestone.AuthenticationSucceededOnApplication),
},
wantErr: nil,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
c := &Commands{
eventstore: tt.fields.eventstore(t),
caches: &Caches{
milestones: noop.NewCache[milestoneIndex, string, *MilestonesReached](),
},
}
s := &OIDCSessionEvents{
commands: c,
}
postCommit, err := s.SetMilestones(tt.args.ctx, tt.args.clientID, tt.args.isHuman)
postCommit(tt.args.ctx)
require.ErrorIs(t, err, tt.wantErr)
assert.Equal(t, tt.wantEvents, s.events)
})
}
}
func TestCommands_projectCreatedMilestone(t *testing.T) {
ctx := authz.WithInstanceID(context.Background(), "instanceID")
systemCtx := authz.SetCtxData(ctx, authz.CtxData{
SystemMemberships: authz.Memberships{
&authz.Membership{
MemberType: authz.MemberTypeSystem,
},
},
})
aggregate := milestone.NewAggregate(ctx)
type fields struct {
eventstore func(*testing.T) *eventstore.Eventstore
}
type args struct {
ctx context.Context
}
tests := []struct {
name string
fields fields
args args
wantEvents []eventstore.Command
wantErr error
}{
{
name: "system user",
fields: fields{
eventstore: expectEventstore(),
},
args: args{
ctx: systemCtx,
},
wantErr: nil,
},
{
name: "get error",
fields: fields{
eventstore: expectEventstore(
expectFilterError(io.ErrClosedPipe),
),
},
args: args{
ctx: ctx,
},
wantErr: io.ErrClosedPipe,
},
{
name: "milestone already reached",
fields: fields{
eventstore: expectEventstore(
expectFilter(
eventFromEventPusher(milestone.NewReachedEvent(ctx, aggregate, milestone.ProjectCreated)),
),
),
},
args: args{
ctx: ctx,
},
wantErr: nil,
},
{
name: "milestone reached event",
fields: fields{
eventstore: expectEventstore(
expectFilter(),
),
},
args: args{
ctx: ctx,
},
wantEvents: []eventstore.Command{
milestone.NewReachedEvent(ctx, aggregate, milestone.ProjectCreated),
},
wantErr: nil,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
c := &Commands{
eventstore: tt.fields.eventstore(t),
caches: &Caches{
milestones: noop.NewCache[milestoneIndex, string, *MilestonesReached](),
},
}
var cmds []eventstore.Command
postCommit, err := c.projectCreatedMilestone(tt.args.ctx, &cmds)
postCommit(tt.args.ctx)
require.ErrorIs(t, err, tt.wantErr)
assert.Equal(t, tt.wantEvents, cmds)
})
}
}
func TestCommands_applicationCreatedMilestone(t *testing.T) {
ctx := authz.WithInstanceID(context.Background(), "instanceID")
systemCtx := authz.SetCtxData(ctx, authz.CtxData{
SystemMemberships: authz.Memberships{
&authz.Membership{
MemberType: authz.MemberTypeSystem,
},
},
})
aggregate := milestone.NewAggregate(ctx)
type fields struct {
eventstore func(*testing.T) *eventstore.Eventstore
}
type args struct {
ctx context.Context
}
tests := []struct {
name string
fields fields
args args
wantEvents []eventstore.Command
wantErr error
}{
{
name: "system user",
fields: fields{
eventstore: expectEventstore(),
},
args: args{
ctx: systemCtx,
},
wantErr: nil,
},
{
name: "get error",
fields: fields{
eventstore: expectEventstore(
expectFilterError(io.ErrClosedPipe),
),
},
args: args{
ctx: ctx,
},
wantErr: io.ErrClosedPipe,
},
{
name: "milestone already reached",
fields: fields{
eventstore: expectEventstore(
expectFilter(
eventFromEventPusher(milestone.NewReachedEvent(ctx, aggregate, milestone.ApplicationCreated)),
),
),
},
args: args{
ctx: ctx,
},
wantErr: nil,
},
{
name: "milestone reached event",
fields: fields{
eventstore: expectEventstore(
expectFilter(),
),
},
args: args{
ctx: ctx,
},
wantEvents: []eventstore.Command{
milestone.NewReachedEvent(ctx, aggregate, milestone.ApplicationCreated),
},
wantErr: nil,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
c := &Commands{
eventstore: tt.fields.eventstore(t),
caches: &Caches{
milestones: noop.NewCache[milestoneIndex, string, *MilestonesReached](),
},
}
var cmds []eventstore.Command
postCommit, err := c.applicationCreatedMilestone(tt.args.ctx, &cmds)
postCommit(tt.args.ctx)
require.ErrorIs(t, err, tt.wantErr)
assert.Equal(t, tt.wantEvents, cmds)
})
}
}
func (c *Commands) setMilestonesCompletedForTest(instanceID string) {
c.milestonesCompleted.Store(instanceID, struct{}{})
}

View File

@@ -71,7 +71,8 @@ func (c *Commands) CreateOIDCSessionFromAuthRequest(ctx context.Context, authReq
return nil, "", zerrors.ThrowPreconditionFailed(nil, "COMMAND-Iung5", "Errors.AuthRequest.NoCode") return nil, "", zerrors.ThrowPreconditionFailed(nil, "COMMAND-Iung5", "Errors.AuthRequest.NoCode")
} }
sessionModel := NewSessionWriteModel(authReqModel.SessionID, authz.GetInstance(ctx).InstanceID()) instanceID := authz.GetInstance(ctx).InstanceID()
sessionModel := NewSessionWriteModel(authReqModel.SessionID, instanceID)
err = c.eventstore.FilterToQueryReducer(ctx, sessionModel) err = c.eventstore.FilterToQueryReducer(ctx, sessionModel)
if err != nil { if err != nil {
return nil, "", err return nil, "", err
@@ -118,8 +119,15 @@ func (c *Commands) CreateOIDCSessionFromAuthRequest(ctx context.Context, authReq
} }
} }
cmd.SetAuthRequestSuccessful(ctx, authReqModel.aggregate) cmd.SetAuthRequestSuccessful(ctx, authReqModel.aggregate)
session, err = cmd.PushEvents(ctx) postCommit, err := cmd.SetMilestones(ctx, authReqModel.ClientID, true)
return session, authReqModel.State, err if err != nil {
return nil, "", err
}
if session, err = cmd.PushEvents(ctx); err != nil {
return nil, "", err
}
postCommit(ctx)
return session, authReqModel.State, nil
} }
func (c *Commands) CreateOIDCSession(ctx context.Context, func (c *Commands) CreateOIDCSession(ctx context.Context,
@@ -161,7 +169,15 @@ func (c *Commands) CreateOIDCSession(ctx context.Context,
return nil, err return nil, err
} }
} }
return cmd.PushEvents(ctx) postCommit, err := cmd.SetMilestones(ctx, clientID, sessionID != "")
if err != nil {
return nil, err
}
if session, err = cmd.PushEvents(ctx); err != nil {
return nil, err
}
postCommit(ctx)
return session, nil
} }
type RefreshTokenComplianceChecker func(ctx context.Context, wm *OIDCSessionWriteModel, requestedScope []string) (scope []string, err error) type RefreshTokenComplianceChecker func(ctx context.Context, wm *OIDCSessionWriteModel, requestedScope []string) (scope []string, err error)
@@ -283,7 +299,7 @@ func (c *Commands) newOIDCSessionAddEvents(ctx context.Context, userID, resource
} }
sessionID = IDPrefixV2 + sessionID sessionID = IDPrefixV2 + sessionID
return &OIDCSessionEvents{ return &OIDCSessionEvents{
eventstore: c.eventstore, commands: c,
idGenerator: c.idGenerator, idGenerator: c.idGenerator,
encryptionAlg: c.keyAlgorithm, encryptionAlg: c.keyAlgorithm,
events: pending, events: pending,
@@ -341,7 +357,7 @@ func (c *Commands) newOIDCSessionUpdateEvents(ctx context.Context, refreshToken
return nil, err return nil, err
} }
return &OIDCSessionEvents{ return &OIDCSessionEvents{
eventstore: c.eventstore, commands: c,
idGenerator: c.idGenerator, idGenerator: c.idGenerator,
encryptionAlg: c.keyAlgorithm, encryptionAlg: c.keyAlgorithm,
oidcSessionWriteModel: sessionWriteModel, oidcSessionWriteModel: sessionWriteModel,
@@ -352,7 +368,7 @@ func (c *Commands) newOIDCSessionUpdateEvents(ctx context.Context, refreshToken
} }
type OIDCSessionEvents struct { type OIDCSessionEvents struct {
eventstore *eventstore.Eventstore commands *Commands
idGenerator id.Generator idGenerator id.Generator
encryptionAlg crypto.EncryptionAlgorithm encryptionAlg crypto.EncryptionAlgorithm
events []eventstore.Command events []eventstore.Command
@@ -467,7 +483,7 @@ func (c *OIDCSessionEvents) generateRefreshToken(userID string) (refreshTokenID,
} }
func (c *OIDCSessionEvents) PushEvents(ctx context.Context) (*OIDCSession, error) { func (c *OIDCSessionEvents) PushEvents(ctx context.Context) (*OIDCSession, error) {
pushedEvents, err := c.eventstore.Push(ctx, c.events...) pushedEvents, err := c.commands.eventstore.Push(ctx, c.events...)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -496,7 +512,7 @@ func (c *OIDCSessionEvents) PushEvents(ctx context.Context) (*OIDCSession, error
// we need to use `-` as a delimiter because the OIDC library uses `:` and will check for a length of 2 parts // we need to use `-` as a delimiter because the OIDC library uses `:` and will check for a length of 2 parts
session.TokenID = c.oidcSessionWriteModel.AggregateID + TokenDelimiter + c.accessTokenID session.TokenID = c.oidcSessionWriteModel.AggregateID + TokenDelimiter + c.accessTokenID
} }
activity.Trigger(ctx, c.oidcSessionWriteModel.UserResourceOwner, c.oidcSessionWriteModel.UserID, tokenReasonToActivityMethodType(c.oidcSessionWriteModel.AccessTokenReason), c.eventstore.FilterToQueryReducer) activity.Trigger(ctx, c.oidcSessionWriteModel.UserResourceOwner, c.oidcSessionWriteModel.UserID, tokenReasonToActivityMethodType(c.oidcSessionWriteModel.AccessTokenReason), c.commands.eventstore.FilterToQueryReducer)
return session, nil return session, nil
} }

View File

@@ -70,7 +70,7 @@ func TestCommands_CreateOIDCSessionFromAuthRequest(t *testing.T) {
eventstore: expectEventstore(), eventstore: expectEventstore(),
}, },
args{ args{
ctx: context.Background(), ctx: authz.WithInstanceID(context.Background(), "instanceID"),
authRequestID: "", authRequestID: "",
complianceCheck: mockAuthRequestComplianceChecker(nil), complianceCheck: mockAuthRequestComplianceChecker(nil),
}, },
@@ -86,7 +86,7 @@ func TestCommands_CreateOIDCSessionFromAuthRequest(t *testing.T) {
), ),
}, },
args{ args{
ctx: context.Background(), ctx: authz.WithInstanceID(context.Background(), "instanceID"),
authRequestID: "V2_authRequestID", authRequestID: "V2_authRequestID",
complianceCheck: mockAuthRequestComplianceChecker(nil), complianceCheck: mockAuthRequestComplianceChecker(nil),
}, },
@@ -102,7 +102,7 @@ func TestCommands_CreateOIDCSessionFromAuthRequest(t *testing.T) {
), ),
}, },
args{ args{
ctx: context.Background(), ctx: authz.WithInstanceID(context.Background(), "instanceID"),
authRequestID: "V2_authRequestID", authRequestID: "V2_authRequestID",
complianceCheck: mockAuthRequestComplianceChecker(nil), complianceCheck: mockAuthRequestComplianceChecker(nil),
}, },
@@ -706,6 +706,7 @@ func TestCommands_CreateOIDCSessionFromAuthRequest(t *testing.T) {
defaultRefreshTokenIdleLifetime: tt.fields.defaultRefreshTokenIdleLifetime, defaultRefreshTokenIdleLifetime: tt.fields.defaultRefreshTokenIdleLifetime,
keyAlgorithm: tt.fields.keyAlgorithm, keyAlgorithm: tt.fields.keyAlgorithm,
} }
c.setMilestonesCompletedForTest("instanceID")
gotSession, gotState, err := c.CreateOIDCSessionFromAuthRequest(tt.args.ctx, tt.args.authRequestID, tt.args.complianceCheck, tt.args.needRefreshToken) gotSession, gotState, err := c.CreateOIDCSessionFromAuthRequest(tt.args.ctx, tt.args.authRequestID, tt.args.complianceCheck, tt.args.needRefreshToken)
require.ErrorIs(t, err, tt.res.err) require.ErrorIs(t, err, tt.res.err)
@@ -762,7 +763,7 @@ func TestCommands_CreateOIDCSession(t *testing.T) {
), ),
}, },
args: args{ args: args{
ctx: context.Background(), ctx: authz.WithInstanceID(context.Background(), "instanceID"),
userID: "userID", userID: "userID",
resourceOwner: "orgID", resourceOwner: "orgID",
clientID: "clientID", clientID: "clientID",
@@ -818,7 +819,7 @@ func TestCommands_CreateOIDCSession(t *testing.T) {
keyAlgorithm: crypto.CreateMockEncryptionAlg(gomock.NewController(t)), keyAlgorithm: crypto.CreateMockEncryptionAlg(gomock.NewController(t)),
}, },
args: args{ args: args{
ctx: context.Background(), ctx: authz.WithInstanceID(context.Background(), "instanceID"),
userID: "userID", userID: "userID",
resourceOwner: "org1", resourceOwner: "org1",
clientID: "clientID", clientID: "clientID",
@@ -892,7 +893,7 @@ func TestCommands_CreateOIDCSession(t *testing.T) {
keyAlgorithm: crypto.CreateMockEncryptionAlg(gomock.NewController(t)), keyAlgorithm: crypto.CreateMockEncryptionAlg(gomock.NewController(t)),
}, },
args: args{ args: args{
ctx: context.Background(), ctx: authz.WithInstanceID(context.Background(), "instanceID"),
userID: "userID", userID: "userID",
resourceOwner: "org1", resourceOwner: "org1",
clientID: "clientID", clientID: "clientID",
@@ -1089,7 +1090,7 @@ func TestCommands_CreateOIDCSession(t *testing.T) {
keyAlgorithm: crypto.CreateMockEncryptionAlg(gomock.NewController(t)), keyAlgorithm: crypto.CreateMockEncryptionAlg(gomock.NewController(t)),
}, },
args: args{ args: args{
ctx: context.Background(), ctx: authz.WithInstanceID(context.Background(), "instanceID"),
userID: "userID", userID: "userID",
resourceOwner: "org1", resourceOwner: "org1",
clientID: "clientID", clientID: "clientID",
@@ -1186,7 +1187,7 @@ func TestCommands_CreateOIDCSession(t *testing.T) {
keyAlgorithm: crypto.CreateMockEncryptionAlg(gomock.NewController(t)), keyAlgorithm: crypto.CreateMockEncryptionAlg(gomock.NewController(t)),
}, },
args: args{ args: args{
ctx: context.Background(), ctx: authz.WithInstanceID(context.Background(), "instanceID"),
userID: "userID", userID: "userID",
resourceOwner: "org1", resourceOwner: "org1",
clientID: "clientID", clientID: "clientID",
@@ -1266,7 +1267,7 @@ func TestCommands_CreateOIDCSession(t *testing.T) {
}), }),
}, },
args: args{ args: args{
ctx: context.Background(), ctx: authz.WithInstanceID(context.Background(), "instanceID"),
userID: "userID", userID: "userID",
resourceOwner: "org1", resourceOwner: "org1",
clientID: "clientID", clientID: "clientID",
@@ -1347,7 +1348,7 @@ func TestCommands_CreateOIDCSession(t *testing.T) {
}), }),
}, },
args: args{ args: args{
ctx: context.Background(), ctx: authz.WithInstanceID(context.Background(), "instanceID"),
userID: "userID", userID: "userID",
resourceOwner: "org1", resourceOwner: "org1",
clientID: "clientID", clientID: "clientID",
@@ -1406,6 +1407,7 @@ func TestCommands_CreateOIDCSession(t *testing.T) {
keyAlgorithm: tt.fields.keyAlgorithm, keyAlgorithm: tt.fields.keyAlgorithm,
checkPermission: tt.fields.checkPermission, checkPermission: tt.fields.checkPermission,
} }
c.setMilestonesCompletedForTest("instanceID")
got, err := c.CreateOIDCSession(tt.args.ctx, got, err := c.CreateOIDCSession(tt.args.ctx,
tt.args.userID, tt.args.userID,
tt.args.resourceOwner, tt.args.resourceOwner,

View File

@@ -34,7 +34,11 @@ func (c *Commands) AddProjectWithID(ctx context.Context, project *domain.Project
if existingProject.State != domain.ProjectStateUnspecified { if existingProject.State != domain.ProjectStateUnspecified {
return nil, zerrors.ThrowInvalidArgument(nil, "COMMAND-opamwu", "Errors.Project.AlreadyExisting") return nil, zerrors.ThrowInvalidArgument(nil, "COMMAND-opamwu", "Errors.Project.AlreadyExisting")
} }
return c.addProjectWithID(ctx, project, resourceOwner, projectID) project, err = c.addProjectWithID(ctx, project, resourceOwner, projectID)
if err != nil {
return nil, err
}
return project, nil
} }
func (c *Commands) AddProject(ctx context.Context, project *domain.Project, resourceOwner, ownerUserID string) (_ *domain.Project, err error) { func (c *Commands) AddProject(ctx context.Context, project *domain.Project, resourceOwner, ownerUserID string) (_ *domain.Project, err error) {
@@ -53,7 +57,11 @@ func (c *Commands) AddProject(ctx context.Context, project *domain.Project, reso
return nil, err return nil, err
} }
return c.addProjectWithIDWithOwner(ctx, project, resourceOwner, ownerUserID, projectID) project, err = c.addProjectWithIDWithOwner(ctx, project, resourceOwner, ownerUserID, projectID)
if err != nil {
return nil, err
}
return project, nil
} }
func (c *Commands) addProjectWithID(ctx context.Context, projectAdd *domain.Project, resourceOwner, projectID string) (_ *domain.Project, err error) { func (c *Commands) addProjectWithID(ctx context.Context, projectAdd *domain.Project, resourceOwner, projectID string) (_ *domain.Project, err error) {
@@ -71,11 +79,15 @@ func (c *Commands) addProjectWithID(ctx context.Context, projectAdd *domain.Proj
projectAdd.HasProjectCheck, projectAdd.HasProjectCheck,
projectAdd.PrivateLabelingSetting), projectAdd.PrivateLabelingSetting),
} }
postCommit, err := c.projectCreatedMilestone(ctx, &events)
if err != nil {
return nil, err
}
pushedEvents, err := c.eventstore.Push(ctx, events...) pushedEvents, err := c.eventstore.Push(ctx, events...)
if err != nil { if err != nil {
return nil, err return nil, err
} }
postCommit(ctx)
err = AppendAndReduce(addedProject, pushedEvents...) err = AppendAndReduce(addedProject, pushedEvents...)
if err != nil { if err != nil {
return nil, err return nil, err
@@ -103,11 +115,15 @@ func (c *Commands) addProjectWithIDWithOwner(ctx context.Context, projectAdd *do
projectAdd.PrivateLabelingSetting), projectAdd.PrivateLabelingSetting),
project.NewProjectMemberAddedEvent(ctx, projectAgg, ownerUserID, projectRole), project.NewProjectMemberAddedEvent(ctx, projectAgg, ownerUserID, projectRole),
} }
postCommit, err := c.projectCreatedMilestone(ctx, &events)
if err != nil {
return nil, err
}
pushedEvents, err := c.eventstore.Push(ctx, events...) pushedEvents, err := c.eventstore.Push(ctx, events...)
if err != nil { if err != nil {
return nil, err return nil, err
} }
postCommit(ctx)
err = AppendAndReduce(addedProject, pushedEvents...) err = AppendAndReduce(addedProject, pushedEvents...)
if err != nil { if err != nil {
return nil, err return nil, err

View File

@@ -202,10 +202,15 @@ func (c *Commands) addOIDCApplicationWithID(ctx context.Context, oidcApp *domain
)) ))
addedApplication.AppID = oidcApp.AppID addedApplication.AppID = oidcApp.AppID
postCommit, err := c.applicationCreatedMilestone(ctx, &events)
if err != nil {
return nil, err
}
pushedEvents, err := c.eventstore.Push(ctx, events...) pushedEvents, err := c.eventstore.Push(ctx, events...)
if err != nil { if err != nil {
return nil, err return nil, err
} }
postCommit(ctx)
err = AppendAndReduce(addedApplication, pushedEvents...) err = AppendAndReduce(addedApplication, pushedEvents...)
if err != nil { if err != nil {
return nil, err return nil, err

View File

@@ -11,6 +11,7 @@ import (
"github.com/zitadel/passwap" "github.com/zitadel/passwap"
"github.com/zitadel/passwap/bcrypt" "github.com/zitadel/passwap/bcrypt"
"github.com/zitadel/zitadel/internal/api/authz"
"github.com/zitadel/zitadel/internal/command/preparation" "github.com/zitadel/zitadel/internal/command/preparation"
"github.com/zitadel/zitadel/internal/crypto" "github.com/zitadel/zitadel/internal/crypto"
"github.com/zitadel/zitadel/internal/domain" "github.com/zitadel/zitadel/internal/domain"
@@ -418,7 +419,7 @@ func TestCommandSide_AddOIDCApplication(t *testing.T) {
eventstore: expectEventstore(), eventstore: expectEventstore(),
}, },
args: args{ args: args{
ctx: context.Background(), ctx: authz.WithInstanceID(context.Background(), "instanceID"),
oidcApp: &domain.OIDCApp{}, oidcApp: &domain.OIDCApp{},
resourceOwner: "org1", resourceOwner: "org1",
}, },
@@ -434,7 +435,7 @@ func TestCommandSide_AddOIDCApplication(t *testing.T) {
), ),
}, },
args: args{ args: args{
ctx: context.Background(), ctx: authz.WithInstanceID(context.Background(), "instanceID"),
oidcApp: &domain.OIDCApp{ oidcApp: &domain.OIDCApp{
ObjectRoot: models.ObjectRoot{ ObjectRoot: models.ObjectRoot{
AggregateID: "project1", AggregateID: "project1",
@@ -463,7 +464,7 @@ func TestCommandSide_AddOIDCApplication(t *testing.T) {
), ),
}, },
args: args{ args: args{
ctx: context.Background(), ctx: authz.WithInstanceID(context.Background(), "instanceID"),
oidcApp: &domain.OIDCApp{ oidcApp: &domain.OIDCApp{
ObjectRoot: models.ObjectRoot{ ObjectRoot: models.ObjectRoot{
AggregateID: "project1", AggregateID: "project1",
@@ -521,7 +522,7 @@ func TestCommandSide_AddOIDCApplication(t *testing.T) {
idGenerator: id_mock.NewIDGeneratorExpectIDs(t, "app1", "client1"), idGenerator: id_mock.NewIDGeneratorExpectIDs(t, "app1", "client1"),
}, },
args: args{ args: args{
ctx: context.Background(), ctx: authz.WithInstanceID(context.Background(), "instanceID"),
oidcApp: &domain.OIDCApp{ oidcApp: &domain.OIDCApp{
ObjectRoot: models.ObjectRoot{ ObjectRoot: models.ObjectRoot{
AggregateID: "project1", AggregateID: "project1",
@@ -619,7 +620,7 @@ func TestCommandSide_AddOIDCApplication(t *testing.T) {
idGenerator: id_mock.NewIDGeneratorExpectIDs(t, "app1", "client1"), idGenerator: id_mock.NewIDGeneratorExpectIDs(t, "app1", "client1"),
}, },
args: args{ args: args{
ctx: context.Background(), ctx: authz.WithInstanceID(context.Background(), "instanceID"),
oidcApp: &domain.OIDCApp{ oidcApp: &domain.OIDCApp{
ObjectRoot: models.ObjectRoot{ ObjectRoot: models.ObjectRoot{
AggregateID: "project1", AggregateID: "project1",
@@ -676,7 +677,7 @@ func TestCommandSide_AddOIDCApplication(t *testing.T) {
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
r := &Commands{ c := &Commands{
eventstore: tt.fields.eventstore(t), eventstore: tt.fields.eventstore(t),
idGenerator: tt.fields.idGenerator, idGenerator: tt.fields.idGenerator,
newHashedSecret: mockHashedSecret("secret"), newHashedSecret: mockHashedSecret("secret"),
@@ -684,7 +685,8 @@ func TestCommandSide_AddOIDCApplication(t *testing.T) {
ClientSecret: emptyConfig, ClientSecret: emptyConfig,
}, },
} }
got, err := r.AddOIDCApplication(tt.args.ctx, tt.args.oidcApp, tt.args.resourceOwner) c.setMilestonesCompletedForTest("instanceID")
got, err := c.AddOIDCApplication(tt.args.ctx, tt.args.oidcApp, tt.args.resourceOwner)
if tt.res.err == nil { if tt.res.err == nil {
assert.NoError(t, err) assert.NoError(t, err)
} }

View File

@@ -28,10 +28,15 @@ func (c *Commands) AddSAMLApplication(ctx context.Context, application *domain.S
return nil, err return nil, err
} }
addedApplication.AppID = application.AppID addedApplication.AppID = application.AppID
postCommit, err := c.applicationCreatedMilestone(ctx, &events)
if err != nil {
return nil, err
}
pushedEvents, err := c.eventstore.Push(ctx, events...) pushedEvents, err := c.eventstore.Push(ctx, events...)
if err != nil { if err != nil {
return nil, err return nil, err
} }
postCommit(ctx)
err = AppendAndReduce(addedApplication, pushedEvents...) err = AppendAndReduce(addedApplication, pushedEvents...)
if err != nil { if err != nil {
return nil, err return nil, err

View File

@@ -9,6 +9,7 @@ import (
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/zitadel/zitadel/internal/api/authz"
"github.com/zitadel/zitadel/internal/domain" "github.com/zitadel/zitadel/internal/domain"
"github.com/zitadel/zitadel/internal/eventstore" "github.com/zitadel/zitadel/internal/eventstore"
"github.com/zitadel/zitadel/internal/eventstore/v1/models" "github.com/zitadel/zitadel/internal/eventstore/v1/models"
@@ -76,7 +77,7 @@ func TestCommandSide_AddSAMLApplication(t *testing.T) {
), ),
}, },
args: args{ args: args{
ctx: context.Background(), ctx: authz.WithInstanceID(context.Background(), "instanceID"),
samlApp: &domain.SAMLApp{}, samlApp: &domain.SAMLApp{},
resourceOwner: "org1", resourceOwner: "org1",
}, },
@@ -93,7 +94,7 @@ func TestCommandSide_AddSAMLApplication(t *testing.T) {
), ),
}, },
args: args{ args: args{
ctx: context.Background(), ctx: authz.WithInstanceID(context.Background(), "instanceID"),
samlApp: &domain.SAMLApp{ samlApp: &domain.SAMLApp{
ObjectRoot: models.ObjectRoot{ ObjectRoot: models.ObjectRoot{
AggregateID: "project1", AggregateID: "project1",
@@ -123,7 +124,7 @@ func TestCommandSide_AddSAMLApplication(t *testing.T) {
), ),
}, },
args: args{ args: args{
ctx: context.Background(), ctx: authz.WithInstanceID(context.Background(), "instanceID"),
samlApp: &domain.SAMLApp{ samlApp: &domain.SAMLApp{
ObjectRoot: models.ObjectRoot{ ObjectRoot: models.ObjectRoot{
AggregateID: "project1", AggregateID: "project1",
@@ -154,7 +155,7 @@ func TestCommandSide_AddSAMLApplication(t *testing.T) {
idGenerator: id_mock.NewIDGeneratorExpectIDs(t), idGenerator: id_mock.NewIDGeneratorExpectIDs(t),
}, },
args: args{ args: args{
ctx: context.Background(), ctx: authz.WithInstanceID(context.Background(), "instanceID"),
samlApp: &domain.SAMLApp{ samlApp: &domain.SAMLApp{
ObjectRoot: models.ObjectRoot{ ObjectRoot: models.ObjectRoot{
AggregateID: "project1", AggregateID: "project1",
@@ -201,7 +202,7 @@ func TestCommandSide_AddSAMLApplication(t *testing.T) {
idGenerator: id_mock.NewIDGeneratorExpectIDs(t, "app1"), idGenerator: id_mock.NewIDGeneratorExpectIDs(t, "app1"),
}, },
args: args{ args: args{
ctx: context.Background(), ctx: authz.WithInstanceID(context.Background(), "instanceID"),
samlApp: &domain.SAMLApp{ samlApp: &domain.SAMLApp{
ObjectRoot: models.ObjectRoot{ ObjectRoot: models.ObjectRoot{
AggregateID: "project1", AggregateID: "project1",
@@ -260,7 +261,7 @@ func TestCommandSide_AddSAMLApplication(t *testing.T) {
httpClient: newTestClient(200, testMetadata), httpClient: newTestClient(200, testMetadata),
}, },
args: args{ args: args{
ctx: context.Background(), ctx: authz.WithInstanceID(context.Background(), "instanceID"),
samlApp: &domain.SAMLApp{ samlApp: &domain.SAMLApp{
ObjectRoot: models.ObjectRoot{ ObjectRoot: models.ObjectRoot{
AggregateID: "project1", AggregateID: "project1",
@@ -305,7 +306,7 @@ func TestCommandSide_AddSAMLApplication(t *testing.T) {
httpClient: newTestClient(http.StatusNotFound, nil), httpClient: newTestClient(http.StatusNotFound, nil),
}, },
args: args{ args: args{
ctx: context.Background(), ctx: authz.WithInstanceID(context.Background(), "instanceID"),
samlApp: &domain.SAMLApp{ samlApp: &domain.SAMLApp{
ObjectRoot: models.ObjectRoot{ ObjectRoot: models.ObjectRoot{
AggregateID: "project1", AggregateID: "project1",
@@ -325,13 +326,13 @@ func TestCommandSide_AddSAMLApplication(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
r := &Commands{ c := &Commands{
eventstore: tt.fields.eventstore, eventstore: tt.fields.eventstore,
idGenerator: tt.fields.idGenerator, idGenerator: tt.fields.idGenerator,
httpClient: tt.fields.httpClient, httpClient: tt.fields.httpClient,
} }
c.setMilestonesCompletedForTest("instanceID")
got, err := r.AddSAMLApplication(tt.args.ctx, tt.args.samlApp, tt.args.resourceOwner) got, err := c.AddSAMLApplication(tt.args.ctx, tt.args.samlApp, tt.args.resourceOwner)
if tt.res.err == nil { if tt.res.err == nil {
assert.NoError(t, err) assert.NoError(t, err)
} }

View File

@@ -6,6 +6,7 @@ import (
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/zitadel/zitadel/internal/api/authz"
"github.com/zitadel/zitadel/internal/domain" "github.com/zitadel/zitadel/internal/domain"
"github.com/zitadel/zitadel/internal/eventstore" "github.com/zitadel/zitadel/internal/eventstore"
"github.com/zitadel/zitadel/internal/eventstore/v1/models" "github.com/zitadel/zitadel/internal/eventstore/v1/models"
@@ -44,7 +45,7 @@ func TestCommandSide_AddProject(t *testing.T) {
), ),
}, },
args: args{ args: args{
ctx: context.Background(), ctx: authz.WithInstanceID(context.Background(), "instanceID"),
project: &domain.Project{}, project: &domain.Project{},
resourceOwner: "org1", resourceOwner: "org1",
}, },
@@ -60,7 +61,7 @@ func TestCommandSide_AddProject(t *testing.T) {
), ),
}, },
args: args{ args: args{
ctx: context.Background(), ctx: authz.WithInstanceID(context.Background(), "instanceID"),
project: &domain.Project{ project: &domain.Project{
Name: "project", Name: "project",
ProjectRoleAssertion: true, ProjectRoleAssertion: true,
@@ -121,7 +122,7 @@ func TestCommandSide_AddProject(t *testing.T) {
idGenerator: id_mock.NewIDGeneratorExpectIDs(t, "project1"), idGenerator: id_mock.NewIDGeneratorExpectIDs(t, "project1"),
}, },
args: args{ args: args{
ctx: context.Background(), ctx: authz.WithInstanceID(context.Background(), "instanceID"),
project: &domain.Project{ project: &domain.Project{
Name: "project", Name: "project",
ProjectRoleAssertion: true, ProjectRoleAssertion: true,
@@ -159,7 +160,7 @@ func TestCommandSide_AddProject(t *testing.T) {
idGenerator: id_mock.NewIDGeneratorExpectIDs(t, "project1"), idGenerator: id_mock.NewIDGeneratorExpectIDs(t, "project1"),
}, },
args: args{ args: args{
ctx: context.Background(), ctx: authz.WithInstanceID(context.Background(), "instanceID"),
project: &domain.Project{ project: &domain.Project{
Name: "project", Name: "project",
ProjectRoleAssertion: true, ProjectRoleAssertion: true,
@@ -187,11 +188,12 @@ func TestCommandSide_AddProject(t *testing.T) {
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
r := &Commands{ c := &Commands{
eventstore: tt.fields.eventstore, eventstore: tt.fields.eventstore,
idGenerator: tt.fields.idGenerator, idGenerator: tt.fields.idGenerator,
} }
got, err := r.AddProject(tt.args.ctx, tt.args.project, tt.args.resourceOwner, tt.args.ownerID) c.setMilestonesCompletedForTest("instanceID")
got, err := c.AddProject(tt.args.ctx, tt.args.project, tt.args.resourceOwner, tt.args.ownerID)
if tt.res.err == nil { if tt.res.err == nil {
assert.NoError(t, err) assert.NoError(t, err)
} }

View File

@@ -11,6 +11,7 @@ import (
"github.com/zitadel/logging" "github.com/zitadel/logging"
"github.com/zitadel/zitadel/internal/api/authz" "github.com/zitadel/zitadel/internal/api/authz"
"github.com/zitadel/zitadel/internal/database"
"github.com/zitadel/zitadel/internal/zerrors" "github.com/zitadel/zitadel/internal/zerrors"
) )
@@ -247,6 +248,10 @@ func (es *Eventstore) InstanceIDs(ctx context.Context, maxAge time.Duration, for
return instances, nil return instances, nil
} }
func (es *Eventstore) Client() *database.DB {
return es.querier.Client()
}
type QueryReducer interface { type QueryReducer interface {
reducer reducer
//Query returns the SearchQueryFactory for the events needed in reducer //Query returns the SearchQueryFactory for the events needed in reducer
@@ -270,6 +275,8 @@ type Querier interface {
LatestSequence(ctx context.Context, queryFactory *SearchQueryBuilder) (float64, error) LatestSequence(ctx context.Context, queryFactory *SearchQueryBuilder) (float64, error)
// InstanceIDs returns the instance ids found by the search query // InstanceIDs returns the instance ids found by the search query
InstanceIDs(ctx context.Context, queryFactory *SearchQueryBuilder) ([]string, error) InstanceIDs(ctx context.Context, queryFactory *SearchQueryBuilder) ([]string, error)
// Client returns the underlying database connection
Client() *database.DB
} }
type Pusher interface { type Pusher interface {

View File

@@ -12,6 +12,7 @@ import (
"github.com/zitadel/zitadel/internal/api/authz" "github.com/zitadel/zitadel/internal/api/authz"
"github.com/zitadel/zitadel/internal/api/service" "github.com/zitadel/zitadel/internal/api/service"
"github.com/zitadel/zitadel/internal/database"
"github.com/zitadel/zitadel/internal/zerrors" "github.com/zitadel/zitadel/internal/zerrors"
) )
@@ -437,6 +438,10 @@ func (repo *testQuerier) InstanceIDs(ctx context.Context, queryFactory *SearchQu
return repo.instances, nil return repo.instances, nil
} }
func (*testQuerier) Client() *database.DB {
return nil
}
func TestEventstore_Push(t *testing.T) { func TestEventstore_Push(t *testing.T) {
type args struct { type args struct {
events []Command events []Command

View File

@@ -13,6 +13,7 @@ import (
context "context" context "context"
reflect "reflect" reflect "reflect"
database "github.com/zitadel/zitadel/internal/database"
eventstore "github.com/zitadel/zitadel/internal/eventstore" eventstore "github.com/zitadel/zitadel/internal/eventstore"
gomock "go.uber.org/mock/gomock" gomock "go.uber.org/mock/gomock"
) )
@@ -40,6 +41,20 @@ func (m *MockQuerier) EXPECT() *MockQuerierMockRecorder {
return m.recorder return m.recorder
} }
// Client mocks base method.
func (m *MockQuerier) Client() *database.DB {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Client")
ret0, _ := ret[0].(*database.DB)
return ret0
}
// Client indicates an expected call of Client.
func (mr *MockQuerierMockRecorder) Client() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Client", reflect.TypeOf((*MockQuerier)(nil).Client))
}
// FilterToReducer mocks base method. // FilterToReducer mocks base method.
func (m *MockQuerier) FilterToReducer(arg0 context.Context, arg1 *eventstore.SearchQueryBuilder, arg2 eventstore.Reducer) error { func (m *MockQuerier) FilterToReducer(arg0 context.Context, arg1 *eventstore.SearchQueryBuilder, arg2 eventstore.Reducer) error {
m.ctrl.T.Helper() m.ctrl.T.Helper()

View File

@@ -282,7 +282,7 @@ func (db *CRDB) InstanceIDs(ctx context.Context, searchQuery *eventstore.SearchQ
return ids, nil return ids, nil
} }
func (db *CRDB) db() *database.DB { func (db *CRDB) Client() *database.DB {
return db.DB return db.DB
} }

View File

@@ -27,7 +27,7 @@ type querier interface {
eventQuery(useV1 bool) string eventQuery(useV1 bool) string
maxSequenceQuery(useV1 bool) string maxSequenceQuery(useV1 bool) string
instanceIDsQuery(useV1 bool) string instanceIDsQuery(useV1 bool) string
db() *database.DB Client() *database.DB
orderByEventSequence(desc, shouldOrderBySequence, useV1 bool) string orderByEventSequence(desc, shouldOrderBySequence, useV1 bool) string
dialect.Database dialect.Database
} }
@@ -110,7 +110,7 @@ func query(ctx context.Context, criteria querier, searchQuery *eventstore.Search
var contextQuerier interface { var contextQuerier interface {
QueryContext(context.Context, func(rows *sql.Rows) error, string, ...interface{}) error QueryContext(context.Context, func(rows *sql.Rows) error, string, ...interface{}) error
} }
contextQuerier = criteria.db() contextQuerier = criteria.Client()
if q.Tx != nil { if q.Tx != nil {
contextQuerier = &tx{Tx: q.Tx} contextQuerier = &tx{Tx: q.Tx}
} }

View File

@@ -22,5 +22,5 @@ type Commands interface {
HumanPhoneVerificationCodeSent(ctx context.Context, orgID, userID string, generatorInfo *senders.CodeGeneratorInfo) error HumanPhoneVerificationCodeSent(ctx context.Context, orgID, userID string, generatorInfo *senders.CodeGeneratorInfo) error
InviteCodeSent(ctx context.Context, orgID, userID string) error InviteCodeSent(ctx context.Context, orgID, userID string) error
UsageNotificationSent(ctx context.Context, dueEvent *quota.NotificationDueEvent) error UsageNotificationSent(ctx context.Context, dueEvent *quota.NotificationDueEvent) error
MilestonePushed(ctx context.Context, msType milestone.Type, endpoints []string, primaryDomain string) error MilestonePushed(ctx context.Context, instanceID string, msType milestone.Type, endpoints []string) error
} }

View File

@@ -16,6 +16,7 @@ import (
"github.com/zitadel/zitadel/internal/integration" "github.com/zitadel/zitadel/internal/integration"
"github.com/zitadel/zitadel/internal/integration/sink" "github.com/zitadel/zitadel/internal/integration/sink"
"github.com/zitadel/zitadel/internal/repository/milestone"
"github.com/zitadel/zitadel/pkg/grpc/app" "github.com/zitadel/zitadel/pkg/grpc/app"
"github.com/zitadel/zitadel/pkg/grpc/management" "github.com/zitadel/zitadel/pkg/grpc/management"
"github.com/zitadel/zitadel/pkg/grpc/object" "github.com/zitadel/zitadel/pkg/grpc/object"
@@ -32,12 +33,12 @@ func TestServer_TelemetryPushMilestones(t *testing.T) {
instance := integration.NewInstance(CTX) instance := integration.NewInstance(CTX)
iamOwnerCtx := instance.WithAuthorization(CTX, integration.UserTypeIAMOwner) iamOwnerCtx := instance.WithAuthorization(CTX, integration.UserTypeIAMOwner)
t.Log("testing against instance with primary domain", instance.Domain) t.Log("testing against instance", instance.ID())
awaitMilestone(t, sub, instance.Domain, "InstanceCreated") awaitMilestone(t, sub, instance.ID(), milestone.InstanceCreated)
projectAdded, err := instance.Client.Mgmt.AddProject(iamOwnerCtx, &management.AddProjectRequest{Name: "integration"}) projectAdded, err := instance.Client.Mgmt.AddProject(iamOwnerCtx, &management.AddProjectRequest{Name: "integration"})
require.NoError(t, err) require.NoError(t, err)
awaitMilestone(t, sub, instance.Domain, "ProjectCreated") awaitMilestone(t, sub, instance.ID(), milestone.ProjectCreated)
redirectURI := "http://localhost:8888" redirectURI := "http://localhost:8888"
application, err := instance.Client.Mgmt.AddOIDCApp(iamOwnerCtx, &management.AddOIDCAppRequest{ application, err := instance.Client.Mgmt.AddOIDCApp(iamOwnerCtx, &management.AddOIDCAppRequest{
@@ -52,14 +53,14 @@ func TestServer_TelemetryPushMilestones(t *testing.T) {
AccessTokenType: app.OIDCTokenType_OIDC_TOKEN_TYPE_JWT, AccessTokenType: app.OIDCTokenType_OIDC_TOKEN_TYPE_JWT,
}) })
require.NoError(t, err) require.NoError(t, err)
awaitMilestone(t, sub, instance.Domain, "ApplicationCreated") awaitMilestone(t, sub, instance.ID(), milestone.ApplicationCreated)
// create the session to be used for the authN of the clients // create the session to be used for the authN of the clients
sessionID, sessionToken, _, _ := instance.CreatePasswordSession(t, iamOwnerCtx, instance.AdminUserID, "Password1!") sessionID, sessionToken, _, _ := instance.CreatePasswordSession(t, iamOwnerCtx, instance.AdminUserID, "Password1!")
console := consoleOIDCConfig(t, instance) console := consoleOIDCConfig(t, instance)
loginToClient(t, instance, console.GetClientId(), console.GetRedirectUris()[0], sessionID, sessionToken) loginToClient(t, instance, console.GetClientId(), console.GetRedirectUris()[0], sessionID, sessionToken)
awaitMilestone(t, sub, instance.Domain, "AuthenticationSucceededOnInstance") awaitMilestone(t, sub, instance.ID(), milestone.AuthenticationSucceededOnInstance)
// make sure the client has been projected // make sure the client has been projected
require.EventuallyWithT(t, func(collectT *assert.CollectT) { require.EventuallyWithT(t, func(collectT *assert.CollectT) {
@@ -70,11 +71,11 @@ func TestServer_TelemetryPushMilestones(t *testing.T) {
assert.NoError(collectT, err) assert.NoError(collectT, err)
}, time.Minute, time.Second, "app not found") }, time.Minute, time.Second, "app not found")
loginToClient(t, instance, application.GetClientId(), redirectURI, sessionID, sessionToken) loginToClient(t, instance, application.GetClientId(), redirectURI, sessionID, sessionToken)
awaitMilestone(t, sub, instance.Domain, "AuthenticationSucceededOnApplication") awaitMilestone(t, sub, instance.ID(), milestone.AuthenticationSucceededOnApplication)
_, err = integration.SystemClient().RemoveInstance(CTX, &system.RemoveInstanceRequest{InstanceId: instance.ID()}) _, err = integration.SystemClient().RemoveInstance(CTX, &system.RemoveInstanceRequest{InstanceId: instance.ID()})
require.NoError(t, err) require.NoError(t, err)
awaitMilestone(t, sub, instance.Domain, "InstanceDeleted") awaitMilestone(t, sub, instance.ID(), milestone.InstanceDeleted)
} }
func loginToClient(t *testing.T, instance *integration.Instance, clientID, redirectURI, sessionID, sessionToken string) { func loginToClient(t *testing.T, instance *integration.Instance, clientID, redirectURI, sessionID, sessionToken string) {
@@ -134,7 +135,7 @@ func consoleOIDCConfig(t *testing.T, instance *integration.Instance) *app.OIDCCo
return apps.GetResult()[0].GetOidcConfig() return apps.GetResult()[0].GetOidcConfig()
} }
func awaitMilestone(t *testing.T, sub *sink.Subscription, primaryDomain, expectMilestoneType string) { func awaitMilestone(t *testing.T, sub *sink.Subscription, instanceID string, expectMilestoneType milestone.Type) {
for { for {
select { select {
case req := <-sub.Recv(): case req := <-sub.Recv():
@@ -144,17 +145,17 @@ func awaitMilestone(t *testing.T, sub *sink.Subscription, primaryDomain, expectM
} }
t.Log("received milestone", plain.String()) t.Log("received milestone", plain.String())
milestone := struct { milestone := struct {
Type string `json:"type"` InstanceID string `json:"instanceId"`
PrimaryDomain string `json:"primaryDomain"` Type milestone.Type `json:"type"`
}{} }{}
if err := json.Unmarshal(req.Body, &milestone); err != nil { if err := json.Unmarshal(req.Body, &milestone); err != nil {
t.Error(err) t.Error(err)
} }
if milestone.Type == expectMilestoneType && milestone.PrimaryDomain == primaryDomain { if milestone.Type == expectMilestoneType && milestone.InstanceID == instanceID {
return return
} }
case <-time.After(2 * time.Minute): // why does it take so long to get a milestone !? case <-time.After(20 * time.Second):
t.Fatalf("timed out waiting for milestone %s in domain %s", expectMilestoneType, primaryDomain) t.Fatalf("timed out waiting for milestone %s for instance %s", expectMilestoneType, instanceID)
} }
} }
} }

View File

@@ -141,7 +141,7 @@ func (mr *MockCommandsMockRecorder) InviteCodeSent(arg0, arg1, arg2 any) *gomock
} }
// MilestonePushed mocks base method. // MilestonePushed mocks base method.
func (m *MockCommands) MilestonePushed(arg0 context.Context, arg1 milestone.Type, arg2 []string, arg3 string) error { func (m *MockCommands) MilestonePushed(arg0 context.Context, arg1 string, arg2 milestone.Type, arg3 []string) error {
m.ctrl.T.Helper() m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "MilestonePushed", arg0, arg1, arg2, arg3) ret := m.ctrl.Call(m, "MilestonePushed", arg0, arg1, arg2, arg3)
ret0, _ := ret[0].(error) ret0, _ := ret[0].(error)

View File

@@ -2,13 +2,9 @@ package handlers
import ( import (
"context" "context"
"fmt"
"net/http" "net/http"
"time" "time"
"github.com/zitadel/logging"
"github.com/zitadel/zitadel/internal/api/authz"
"github.com/zitadel/zitadel/internal/api/call" "github.com/zitadel/zitadel/internal/api/call"
"github.com/zitadel/zitadel/internal/command" "github.com/zitadel/zitadel/internal/command"
"github.com/zitadel/zitadel/internal/eventstore" "github.com/zitadel/zitadel/internal/eventstore"
@@ -16,9 +12,7 @@ import (
"github.com/zitadel/zitadel/internal/notification/channels/webhook" "github.com/zitadel/zitadel/internal/notification/channels/webhook"
_ "github.com/zitadel/zitadel/internal/notification/statik" _ "github.com/zitadel/zitadel/internal/notification/statik"
"github.com/zitadel/zitadel/internal/notification/types" "github.com/zitadel/zitadel/internal/notification/types"
"github.com/zitadel/zitadel/internal/query"
"github.com/zitadel/zitadel/internal/repository/milestone" "github.com/zitadel/zitadel/internal/repository/milestone"
"github.com/zitadel/zitadel/internal/repository/pseudo"
"github.com/zitadel/zitadel/internal/zerrors" "github.com/zitadel/zitadel/internal/zerrors"
) )
@@ -30,7 +24,6 @@ type TelemetryPusherConfig struct {
Enabled bool Enabled bool
Endpoints []string Endpoints []string
Headers http.Header Headers http.Header
Limit uint64
} }
type telemetryPusher struct { type telemetryPusher struct {
@@ -54,7 +47,6 @@ func NewTelemetryPusher(
queries: queries, queries: queries,
channels: channels, channels: channels,
} }
handlerCfg.TriggerWithoutEvents = pusher.pushMilestones
return handler.NewHandler( return handler.NewHandler(
ctx, ctx,
&handlerCfg, &handlerCfg,
@@ -68,9 +60,9 @@ func (u *telemetryPusher) Name() string {
func (t *telemetryPusher) Reducers() []handler.AggregateReducer { func (t *telemetryPusher) Reducers() []handler.AggregateReducer {
return []handler.AggregateReducer{{ return []handler.AggregateReducer{{
Aggregate: pseudo.AggregateType, Aggregate: milestone.AggregateType,
EventReducers: []handler.EventReducer{{ EventReducers: []handler.EventReducer{{
Event: pseudo.ScheduledEventType, Event: milestone.ReachedEventType,
Reduce: t.pushMilestones, Reduce: t.pushMilestones,
}}, }},
}} }}
@@ -78,51 +70,20 @@ func (t *telemetryPusher) Reducers() []handler.AggregateReducer {
func (t *telemetryPusher) pushMilestones(event eventstore.Event) (*handler.Statement, error) { func (t *telemetryPusher) pushMilestones(event eventstore.Event) (*handler.Statement, error) {
ctx := call.WithTimestamp(context.Background()) ctx := call.WithTimestamp(context.Background())
scheduledEvent, ok := event.(*pseudo.ScheduledEvent) e, ok := event.(*milestone.ReachedEvent)
if !ok { if !ok {
return nil, zerrors.ThrowInvalidArgumentf(nil, "HANDL-lDTs5", "reduce.wrong.event.type %s", event.Type()) return nil, zerrors.ThrowInvalidArgumentf(nil, "HANDL-lDTs5", "reduce.wrong.event.type %s", event.Type())
} }
return handler.NewStatement(event, func(handler.Executer, string) error {
return handler.NewStatement(event, func(ex handler.Executer, projectionName string) error { // Do not push the milestone again if this was a migration event.
isReached, err := query.NewNotNullQuery(query.MilestoneReachedDateColID) if e.ReachedDate != nil {
if err != nil { return nil
return err
} }
isNotPushed, err := query.NewIsNullQuery(query.MilestonePushedDateColID) return t.pushMilestone(ctx, e)
if err != nil {
return err
}
hasPrimaryDomain, err := query.NewNotNullQuery(query.MilestonePrimaryDomainColID)
if err != nil {
return err
}
unpushedMilestones, err := t.queries.Queries.SearchMilestones(ctx, scheduledEvent.InstanceIDs, &query.MilestonesSearchQueries{
SearchRequest: query.SearchRequest{
Limit: t.cfg.Limit,
SortingColumn: query.MilestoneReachedDateColID,
Asc: true,
},
Queries: []query.SearchQuery{isReached, isNotPushed, hasPrimaryDomain},
})
if err != nil {
return err
}
var errs int
for _, ms := range unpushedMilestones.Milestones {
if err = t.pushMilestone(ctx, scheduledEvent, ms); err != nil {
errs++
logging.Warnf("pushing milestone %+v failed: %s", *ms, err.Error())
}
}
if errs > 0 {
return fmt.Errorf("pushing %d of %d milestones failed", errs, unpushedMilestones.Count)
}
return nil
}), nil }), nil
} }
func (t *telemetryPusher) pushMilestone(ctx context.Context, event *pseudo.ScheduledEvent, ms *query.Milestone) error { func (t *telemetryPusher) pushMilestone(ctx context.Context, e *milestone.ReachedEvent) error {
ctx = authz.WithInstanceID(ctx, ms.InstanceID)
for _, endpoint := range t.cfg.Endpoints { for _, endpoint := range t.cfg.Endpoints {
if err := types.SendJSON( if err := types.SendJSON(
ctx, ctx,
@@ -135,20 +96,18 @@ func (t *telemetryPusher) pushMilestone(ctx context.Context, event *pseudo.Sched
&struct { &struct {
InstanceID string `json:"instanceId"` InstanceID string `json:"instanceId"`
ExternalDomain string `json:"externalDomain"` ExternalDomain string `json:"externalDomain"`
PrimaryDomain string `json:"primaryDomain"`
Type milestone.Type `json:"type"` Type milestone.Type `json:"type"`
ReachedDate time.Time `json:"reached"` ReachedDate time.Time `json:"reached"`
}{ }{
InstanceID: ms.InstanceID, InstanceID: e.Agg.InstanceID,
ExternalDomain: t.queries.externalDomain, ExternalDomain: t.queries.externalDomain,
PrimaryDomain: ms.PrimaryDomain, Type: e.MilestoneType,
Type: ms.Type, ReachedDate: e.GetReachedDate(),
ReachedDate: ms.ReachedDate,
}, },
event, e,
).WithoutTemplate(); err != nil { ).WithoutTemplate(); err != nil {
return err return err
} }
} }
return t.commands.MilestonePushed(ctx, ms.Type, t.cfg.Endpoints, ms.PrimaryDomain) return t.commands.MilestonePushed(ctx, e.Agg.InstanceID, e.MilestoneType, t.cfg.Endpoints)
} }

View File

@@ -54,10 +54,6 @@ var (
name: projection.MilestoneColumnType, name: projection.MilestoneColumnType,
table: milestonesTable, table: milestonesTable,
} }
MilestonePrimaryDomainColID = Column{
name: projection.MilestoneColumnPrimaryDomain,
table: milestonesTable,
}
MilestoneReachedDateColID = Column{ MilestoneReachedDateColID = Column{
name: projection.MilestoneColumnReachedDate, name: projection.MilestoneColumnReachedDate,
table: milestonesTable, table: milestonesTable,
@@ -76,7 +72,10 @@ func (q *Queries) SearchMilestones(ctx context.Context, instanceIDs []string, qu
if len(instanceIDs) == 0 { if len(instanceIDs) == 0 {
instanceIDs = []string{authz.GetInstance(ctx).InstanceID()} instanceIDs = []string{authz.GetInstance(ctx).InstanceID()}
} }
stmt, args, err := queries.toQuery(query).Where(sq.Eq{MilestoneInstanceIDColID.identifier(): instanceIDs}).ToSql() stmt, args, err := queries.toQuery(query).Where(
sq.Eq{MilestoneInstanceIDColID.identifier(): instanceIDs},
sq.Eq{InstanceDomainIsPrimaryCol.identifier(): true},
).ToSql()
if err != nil { if err != nil {
return nil, zerrors.ThrowInternal(err, "QUERY-A9i5k", "Errors.Query.SQLStatement") return nil, zerrors.ThrowInternal(err, "QUERY-A9i5k", "Errors.Query.SQLStatement")
} }
@@ -96,13 +95,14 @@ func (q *Queries) SearchMilestones(ctx context.Context, instanceIDs []string, qu
func prepareMilestonesQuery(ctx context.Context, db prepareDatabase) (sq.SelectBuilder, func(*sql.Rows) (*Milestones, error)) { func prepareMilestonesQuery(ctx context.Context, db prepareDatabase) (sq.SelectBuilder, func(*sql.Rows) (*Milestones, error)) {
return sq.Select( return sq.Select(
MilestoneInstanceIDColID.identifier(), MilestoneInstanceIDColID.identifier(),
MilestonePrimaryDomainColID.identifier(), InstanceDomainDomainCol.identifier(),
MilestoneReachedDateColID.identifier(), MilestoneReachedDateColID.identifier(),
MilestonePushedDateColID.identifier(), MilestonePushedDateColID.identifier(),
MilestoneTypeColID.identifier(), MilestoneTypeColID.identifier(),
countColumn.identifier(), countColumn.identifier(),
). ).
From(milestonesTable.identifier() + db.Timetravel(call.Took(ctx))). From(milestonesTable.identifier() + db.Timetravel(call.Took(ctx))).
LeftJoin(join(InstanceDomainInstanceIDCol, MilestoneInstanceIDColID)).
PlaceholderFormat(sq.Dollar), PlaceholderFormat(sq.Dollar),
func(rows *sql.Rows) (*Milestones, error) { func(rows *sql.Rows) (*Milestones, error) {
milestones := make([]*Milestone, 0) milestones := make([]*Milestone, 0)

View File

@@ -11,13 +11,14 @@ import (
var ( var (
expectedMilestoneQuery = regexp.QuoteMeta(` expectedMilestoneQuery = regexp.QuoteMeta(`
SELECT projections.milestones.instance_id, SELECT projections.milestones2.instance_id,
projections.milestones.primary_domain, projections.instance_domains.domain,
projections.milestones.reached_date, projections.milestones2.reached_date,
projections.milestones.last_pushed_date, projections.milestones2.last_pushed_date,
projections.milestones.type, projections.milestones2.type,
COUNT(*) OVER () COUNT(*) OVER ()
FROM projections.milestones AS OF SYSTEM TIME '-1 ms' FROM projections.milestones2 AS OF SYSTEM TIME '-1 ms'
LEFT JOIN projections.instance_domains ON projections.milestones2.instance_id = projections.instance_domains.instance_id
`) `)
milestoneCols = []string{ milestoneCols = []string{

View File

@@ -14,8 +14,9 @@ func testEvent(
eventType eventstore.EventType, eventType eventstore.EventType,
aggregateType eventstore.AggregateType, aggregateType eventstore.AggregateType,
data []byte, data []byte,
opts ...eventOption,
) *repository.Event { ) *repository.Event {
return timedTestEvent(eventType, aggregateType, data, time.Now()) return timedTestEvent(eventType, aggregateType, data, time.Now(), opts...)
} }
func toSystemEvent(event *repository.Event) *repository.Event { func toSystemEvent(event *repository.Event) *repository.Event {
@@ -28,8 +29,9 @@ func timedTestEvent(
aggregateType eventstore.AggregateType, aggregateType eventstore.AggregateType,
data []byte, data []byte,
creationDate time.Time, creationDate time.Time,
opts ...eventOption,
) *repository.Event { ) *repository.Event {
return &repository.Event{ e := &repository.Event{
Seq: 15, Seq: 15,
CreationDate: creationDate, CreationDate: creationDate,
Typ: eventType, Typ: eventType,
@@ -42,6 +44,18 @@ func timedTestEvent(
ID: "event-id", ID: "event-id",
EditorUser: "editor-user", EditorUser: "editor-user",
} }
for _, opt := range opts {
opt(e)
}
return e
}
type eventOption func(e *repository.Event)
func withVersion(v eventstore.Version) eventOption {
return func(e *repository.Event) {
e.Version = v
}
} }
func baseEvent(*testing.T) eventstore.Event { func baseEvent(*testing.T) eventstore.Event {

View File

@@ -2,35 +2,26 @@ package projection
import ( import (
"context" "context"
"strconv"
internal_authz "github.com/zitadel/zitadel/internal/api/authz"
"github.com/zitadel/zitadel/internal/eventstore" "github.com/zitadel/zitadel/internal/eventstore"
old_handler "github.com/zitadel/zitadel/internal/eventstore/handler" old_handler "github.com/zitadel/zitadel/internal/eventstore/handler"
"github.com/zitadel/zitadel/internal/eventstore/handler/v2" "github.com/zitadel/zitadel/internal/eventstore/handler/v2"
"github.com/zitadel/zitadel/internal/repository/instance"
"github.com/zitadel/zitadel/internal/repository/milestone" "github.com/zitadel/zitadel/internal/repository/milestone"
"github.com/zitadel/zitadel/internal/repository/oidcsession"
"github.com/zitadel/zitadel/internal/repository/project"
) )
const ( const (
MilestonesProjectionTable = "projections.milestones" MilestonesProjectionTable = "projections.milestones2"
MilestoneColumnInstanceID = "instance_id" MilestoneColumnInstanceID = "instance_id"
MilestoneColumnType = "type" MilestoneColumnType = "type"
MilestoneColumnPrimaryDomain = "primary_domain" MilestoneColumnReachedDate = "reached_date"
MilestoneColumnReachedDate = "reached_date" MilestoneColumnPushedDate = "last_pushed_date"
MilestoneColumnPushedDate = "last_pushed_date"
MilestoneColumnIgnoreClientIDs = "ignore_client_ids"
) )
type milestoneProjection struct { type milestoneProjection struct{}
systemUsers map[string]*internal_authz.SystemAPIUser
}
func newMilestoneProjection(ctx context.Context, config handler.Config, systemUsers map[string]*internal_authz.SystemAPIUser) *handler.Handler { func newMilestoneProjection(ctx context.Context, config handler.Config) *handler.Handler {
return handler.NewHandler(ctx, &config, &milestoneProjection{systemUsers: systemUsers}) return handler.NewHandler(ctx, &config, &milestoneProjection{})
} }
func (*milestoneProjection) Name() string { func (*milestoneProjection) Name() string {
@@ -44,8 +35,6 @@ func (*milestoneProjection) Init() *old_handler.Check {
handler.NewColumn(MilestoneColumnType, handler.ColumnTypeEnum), handler.NewColumn(MilestoneColumnType, handler.ColumnTypeEnum),
handler.NewColumn(MilestoneColumnReachedDate, handler.ColumnTypeTimestamp, handler.Nullable()), handler.NewColumn(MilestoneColumnReachedDate, handler.ColumnTypeTimestamp, handler.Nullable()),
handler.NewColumn(MilestoneColumnPushedDate, handler.ColumnTypeTimestamp, handler.Nullable()), handler.NewColumn(MilestoneColumnPushedDate, handler.ColumnTypeTimestamp, handler.Nullable()),
handler.NewColumn(MilestoneColumnPrimaryDomain, handler.ColumnTypeText, handler.Nullable()),
handler.NewColumn(MilestoneColumnIgnoreClientIDs, handler.ColumnTypeTextArray, handler.Nullable()),
}, },
handler.NewPrimaryKey(MilestoneColumnInstanceID, MilestoneColumnType), handler.NewPrimaryKey(MilestoneColumnInstanceID, MilestoneColumnType),
), ),
@@ -55,183 +44,47 @@ func (*milestoneProjection) Init() *old_handler.Check {
// Reducers implements handler.Projection. // Reducers implements handler.Projection.
func (p *milestoneProjection) Reducers() []handler.AggregateReducer { func (p *milestoneProjection) Reducers() []handler.AggregateReducer {
return []handler.AggregateReducer{ return []handler.AggregateReducer{
{
Aggregate: instance.AggregateType,
EventReducers: []handler.EventReducer{
{
Event: instance.InstanceAddedEventType,
Reduce: p.reduceInstanceAdded,
},
{
Event: instance.InstanceDomainPrimarySetEventType,
Reduce: p.reduceInstanceDomainPrimarySet,
},
{
Event: instance.InstanceRemovedEventType,
Reduce: p.reduceInstanceRemoved,
},
},
},
{
Aggregate: project.AggregateType,
EventReducers: []handler.EventReducer{
{
Event: project.ProjectAddedType,
Reduce: p.reduceProjectAdded,
},
{
Event: project.ApplicationAddedType,
Reduce: p.reduceApplicationAdded,
},
{
Event: project.OIDCConfigAddedType,
Reduce: p.reduceOIDCConfigAdded,
},
{
Event: project.APIConfigAddedType,
Reduce: p.reduceAPIConfigAdded,
},
},
},
{
Aggregate: oidcsession.AggregateType,
EventReducers: []handler.EventReducer{
{
Event: oidcsession.AddedType,
Reduce: p.reduceOIDCSessionAdded,
},
},
},
{ {
Aggregate: milestone.AggregateType, Aggregate: milestone.AggregateType,
EventReducers: []handler.EventReducer{ EventReducers: []handler.EventReducer{
{
Event: milestone.ReachedEventType,
Reduce: p.reduceReached,
},
{ {
Event: milestone.PushedEventType, Event: milestone.PushedEventType,
Reduce: p.reduceMilestonePushed, Reduce: p.reducePushed,
}, },
}, },
}, },
} }
} }
func (p *milestoneProjection) reduceInstanceAdded(event eventstore.Event) (*handler.Statement, error) { func (p *milestoneProjection) reduceReached(event eventstore.Event) (*handler.Statement, error) {
e, err := assertEvent[*instance.InstanceAddedEvent](event) e, err := assertEvent[*milestone.ReachedEvent](event)
if err != nil { if err != nil {
return nil, err return nil, err
} }
allTypes := milestone.AllTypes() return handler.NewCreateStatement(event, []handler.Column{
statements := make([]func(eventstore.Event) handler.Exec, 0, len(allTypes)) handler.NewCol(MilestoneColumnInstanceID, e.Agg.InstanceID),
for _, msType := range allTypes { handler.NewCol(MilestoneColumnType, e.MilestoneType),
createColumns := []handler.Column{ handler.NewCol(MilestoneColumnReachedDate, e.GetReachedDate()),
handler.NewCol(MilestoneColumnInstanceID, e.Aggregate().InstanceID), }), nil
handler.NewCol(MilestoneColumnType, msType),
}
if msType == milestone.InstanceCreated {
createColumns = append(createColumns, handler.NewCol(MilestoneColumnReachedDate, event.CreatedAt()))
}
statements = append(statements, handler.AddCreateStatement(createColumns))
}
return handler.NewMultiStatement(e, statements...), nil
} }
func (p *milestoneProjection) reduceInstanceDomainPrimarySet(event eventstore.Event) (*handler.Statement, error) { func (p *milestoneProjection) reducePushed(event eventstore.Event) (*handler.Statement, error) {
e, err := assertEvent[*instance.DomainPrimarySetEvent](event)
if err != nil {
return nil, err
}
return handler.NewUpdateStatement(
e,
[]handler.Column{
handler.NewCol(MilestoneColumnPrimaryDomain, e.Domain),
},
[]handler.Condition{
handler.NewCond(MilestoneColumnInstanceID, e.Aggregate().InstanceID),
handler.NewIsNullCond(MilestoneColumnPushedDate),
},
), nil
}
func (p *milestoneProjection) reduceProjectAdded(event eventstore.Event) (*handler.Statement, error) {
if _, err := assertEvent[*project.ProjectAddedEvent](event); err != nil {
return nil, err
}
return p.reduceReachedIfUserEventFunc(milestone.ProjectCreated)(event)
}
func (p *milestoneProjection) reduceApplicationAdded(event eventstore.Event) (*handler.Statement, error) {
if _, err := assertEvent[*project.ApplicationAddedEvent](event); err != nil {
return nil, err
}
return p.reduceReachedIfUserEventFunc(milestone.ApplicationCreated)(event)
}
func (p *milestoneProjection) reduceOIDCConfigAdded(event eventstore.Event) (*handler.Statement, error) {
e, err := assertEvent[*project.OIDCConfigAddedEvent](event)
if err != nil {
return nil, err
}
return p.reduceAppConfigAdded(e, e.ClientID)
}
func (p *milestoneProjection) reduceAPIConfigAdded(event eventstore.Event) (*handler.Statement, error) {
e, err := assertEvent[*project.APIConfigAddedEvent](event)
if err != nil {
return nil, err
}
return p.reduceAppConfigAdded(e, e.ClientID)
}
func (p *milestoneProjection) reduceOIDCSessionAdded(event eventstore.Event) (*handler.Statement, error) {
e, err := assertEvent[*oidcsession.AddedEvent](event)
if err != nil {
return nil, err
}
statements := []func(eventstore.Event) handler.Exec{
handler.AddUpdateStatement(
[]handler.Column{
handler.NewCol(MilestoneColumnReachedDate, event.CreatedAt()),
},
[]handler.Condition{
handler.NewCond(MilestoneColumnInstanceID, event.Aggregate().InstanceID),
handler.NewCond(MilestoneColumnType, milestone.AuthenticationSucceededOnInstance),
handler.NewIsNullCond(MilestoneColumnReachedDate),
},
),
}
// We ignore authentications without app, for example JWT profile or PAT
if e.ClientID != "" {
statements = append(statements, handler.AddUpdateStatement(
[]handler.Column{
handler.NewCol(MilestoneColumnReachedDate, event.CreatedAt()),
},
[]handler.Condition{
handler.NewCond(MilestoneColumnInstanceID, event.Aggregate().InstanceID),
handler.NewCond(MilestoneColumnType, milestone.AuthenticationSucceededOnApplication),
handler.Not(handler.NewTextArrayContainsCond(MilestoneColumnIgnoreClientIDs, e.ClientID)),
handler.NewIsNullCond(MilestoneColumnReachedDate),
},
))
}
return handler.NewMultiStatement(e, statements...), nil
}
func (p *milestoneProjection) reduceInstanceRemoved(event eventstore.Event) (*handler.Statement, error) {
if _, err := assertEvent[*instance.InstanceRemovedEvent](event); err != nil {
return nil, err
}
return p.reduceReachedFunc(milestone.InstanceDeleted)(event)
}
func (p *milestoneProjection) reduceMilestonePushed(event eventstore.Event) (*handler.Statement, error) {
e, err := assertEvent[*milestone.PushedEvent](event) e, err := assertEvent[*milestone.PushedEvent](event)
if err != nil { if err != nil {
return nil, err return nil, err
} }
if e.Agg.Version != milestone.AggregateVersion {
return handler.NewNoOpStatement(event), nil // Skip v1 events.
}
if e.MilestoneType != milestone.InstanceDeleted { if e.MilestoneType != milestone.InstanceDeleted {
return handler.NewUpdateStatement( return handler.NewUpdateStatement(
event, event,
[]handler.Column{ []handler.Column{
handler.NewCol(MilestoneColumnPushedDate, event.CreatedAt()), handler.NewCol(MilestoneColumnPushedDate, e.GetPushedDate()),
}, },
[]handler.Condition{ []handler.Condition{
handler.NewCond(MilestoneColumnInstanceID, event.Aggregate().InstanceID), handler.NewCond(MilestoneColumnInstanceID, event.Aggregate().InstanceID),
@@ -246,58 +99,3 @@ func (p *milestoneProjection) reduceMilestonePushed(event eventstore.Event) (*ha
}, },
), nil ), nil
} }
func (p *milestoneProjection) reduceReachedIfUserEventFunc(msType milestone.Type) func(event eventstore.Event) (*handler.Statement, error) {
return func(event eventstore.Event) (*handler.Statement, error) {
if p.isSystemEvent(event) {
return handler.NewNoOpStatement(event), nil
}
return p.reduceReachedFunc(msType)(event)
}
}
func (p *milestoneProjection) reduceReachedFunc(msType milestone.Type) func(event eventstore.Event) (*handler.Statement, error) {
return func(event eventstore.Event) (*handler.Statement, error) {
return handler.NewUpdateStatement(event, []handler.Column{
handler.NewCol(MilestoneColumnReachedDate, event.CreatedAt()),
},
[]handler.Condition{
handler.NewCond(MilestoneColumnInstanceID, event.Aggregate().InstanceID),
handler.NewCond(MilestoneColumnType, msType),
handler.NewIsNullCond(MilestoneColumnReachedDate),
}), nil
}
}
func (p *milestoneProjection) reduceAppConfigAdded(event eventstore.Event, clientID string) (*handler.Statement, error) {
if !p.isSystemEvent(event) {
return handler.NewNoOpStatement(event), nil
}
return handler.NewUpdateStatement(
event,
[]handler.Column{
handler.NewArrayAppendCol(MilestoneColumnIgnoreClientIDs, clientID),
},
[]handler.Condition{
handler.NewCond(MilestoneColumnInstanceID, event.Aggregate().InstanceID),
handler.NewCond(MilestoneColumnType, milestone.AuthenticationSucceededOnApplication),
handler.NewIsNullCond(MilestoneColumnReachedDate),
},
), nil
}
func (p *milestoneProjection) isSystemEvent(event eventstore.Event) bool {
if userId, err := strconv.Atoi(event.Creator()); err == nil && userId > 0 {
return false
}
// check if it is a hard coded event creator
for _, creator := range []string{"", "system", "OIDC", "LOGIN", "SYSTEM"} {
if creator == event.Creator() {
return true
}
}
_, ok := p.systemUsers[event.Creator()]
return ok
}

View File

@@ -4,13 +4,11 @@ import (
"testing" "testing"
"time" "time"
"github.com/zitadel/zitadel/internal/database" "github.com/stretchr/testify/require"
"github.com/zitadel/zitadel/internal/eventstore" "github.com/zitadel/zitadel/internal/eventstore"
"github.com/zitadel/zitadel/internal/eventstore/handler/v2" "github.com/zitadel/zitadel/internal/eventstore/handler/v2"
"github.com/zitadel/zitadel/internal/repository/instance"
"github.com/zitadel/zitadel/internal/repository/milestone" "github.com/zitadel/zitadel/internal/repository/milestone"
"github.com/zitadel/zitadel/internal/repository/oidcsession"
"github.com/zitadel/zitadel/internal/repository/project"
"github.com/zitadel/zitadel/internal/zerrors" "github.com/zitadel/zitadel/internal/zerrors"
) )
@@ -19,6 +17,8 @@ func TestMilestonesProjection_reduces(t *testing.T) {
event func(t *testing.T) eventstore.Event event func(t *testing.T) eventstore.Event
} }
now := time.Now() now := time.Now()
date, err := time.Parse(time.RFC3339, "2006-01-02T15:04:05Z")
require.NoError(t, err)
tests := []struct { tests := []struct {
name string name string
args args args args
@@ -29,292 +29,54 @@ func TestMilestonesProjection_reduces(t *testing.T) {
name: "reduceInstanceAdded", name: "reduceInstanceAdded",
args: args{ args: args{
event: getEvent(timedTestEvent( event: getEvent(timedTestEvent(
instance.InstanceAddedEventType, milestone.ReachedEventType,
instance.AggregateType, milestone.AggregateType,
[]byte(`{}`), []byte(`{"type": "instance_created"}`),
now, now,
), instance.InstanceAddedEventMapper), withVersion(milestone.AggregateVersion),
), milestone.ReachedEventMapper),
}, },
reduce: (&milestoneProjection{}).reduceInstanceAdded, reduce: (&milestoneProjection{}).reduceReached,
want: wantReduce{ want: wantReduce{
aggregateType: eventstore.AggregateType("instance"), aggregateType: eventstore.AggregateType("milestone"),
sequence: 15, sequence: 15,
executer: &testExecuter{ executer: &testExecuter{
executions: []execution{ executions: []execution{
{ {
expectedStmt: "INSERT INTO projections.milestones (instance_id, type, reached_date) VALUES ($1, $2, $3)", expectedStmt: "INSERT INTO projections.milestones2 (instance_id, type, reached_date) VALUES ($1, $2, $3)",
expectedArgs: []interface{}{ expectedArgs: []interface{}{
"instance-id", "instance-id",
milestone.InstanceCreated, milestone.InstanceCreated,
now, now,
}, },
}, },
{
expectedStmt: "INSERT INTO projections.milestones (instance_id, type) VALUES ($1, $2)",
expectedArgs: []interface{}{
"instance-id",
milestone.AuthenticationSucceededOnInstance,
},
},
{
expectedStmt: "INSERT INTO projections.milestones (instance_id, type) VALUES ($1, $2)",
expectedArgs: []interface{}{
"instance-id",
milestone.ProjectCreated,
},
},
{
expectedStmt: "INSERT INTO projections.milestones (instance_id, type) VALUES ($1, $2)",
expectedArgs: []interface{}{
"instance-id",
milestone.ApplicationCreated,
},
},
{
expectedStmt: "INSERT INTO projections.milestones (instance_id, type) VALUES ($1, $2)",
expectedArgs: []interface{}{
"instance-id",
milestone.AuthenticationSucceededOnApplication,
},
},
{
expectedStmt: "INSERT INTO projections.milestones (instance_id, type) VALUES ($1, $2)",
expectedArgs: []interface{}{
"instance-id",
milestone.InstanceDeleted,
},
},
}, },
}, },
}, },
}, },
{ {
name: "reduceInstancePrimaryDomainSet", name: "reduceInstanceAdded with reached date",
args: args{
event: getEvent(testEvent(
instance.InstanceDomainPrimarySetEventType,
instance.AggregateType,
[]byte(`{"domain": "my.domain"}`),
), instance.DomainPrimarySetEventMapper),
},
reduce: (&milestoneProjection{}).reduceInstanceDomainPrimarySet,
want: wantReduce{
aggregateType: eventstore.AggregateType("instance"),
sequence: 15,
executer: &testExecuter{
executions: []execution{
{
expectedStmt: "UPDATE projections.milestones SET primary_domain = $1 WHERE (instance_id = $2) AND (last_pushed_date IS NULL)",
expectedArgs: []interface{}{
"my.domain",
"instance-id",
},
},
},
},
},
},
{
name: "reduceProjectAdded",
args: args{ args: args{
event: getEvent(timedTestEvent( event: getEvent(timedTestEvent(
project.ProjectAddedType, milestone.ReachedEventType,
project.AggregateType, milestone.AggregateType,
[]byte(`{}`), []byte(`{"type": "instance_created", "reachedDate":"2006-01-02T15:04:05Z"}`),
now, now,
), project.ProjectAddedEventMapper), withVersion(milestone.AggregateVersion),
), milestone.ReachedEventMapper),
}, },
reduce: (&milestoneProjection{}).reduceProjectAdded, reduce: (&milestoneProjection{}).reduceReached,
want: wantReduce{ want: wantReduce{
aggregateType: eventstore.AggregateType("project"), aggregateType: eventstore.AggregateType("milestone"),
sequence: 15, sequence: 15,
executer: &testExecuter{ executer: &testExecuter{
executions: []execution{ executions: []execution{
{ {
expectedStmt: "UPDATE projections.milestones SET reached_date = $1 WHERE (instance_id = $2) AND (type = $3) AND (reached_date IS NULL)", expectedStmt: "INSERT INTO projections.milestones2 (instance_id, type, reached_date) VALUES ($1, $2, $3)",
expectedArgs: []interface{}{ expectedArgs: []interface{}{
now,
"instance-id", "instance-id",
milestone.ProjectCreated, milestone.InstanceCreated,
}, date,
},
},
},
},
},
{
name: "reduceApplicationAdded",
args: args{
event: getEvent(timedTestEvent(
project.ApplicationAddedType,
project.AggregateType,
[]byte(`{}`),
now,
), project.ApplicationAddedEventMapper),
},
reduce: (&milestoneProjection{}).reduceApplicationAdded,
want: wantReduce{
aggregateType: eventstore.AggregateType("project"),
sequence: 15,
executer: &testExecuter{
executions: []execution{
{
expectedStmt: "UPDATE projections.milestones SET reached_date = $1 WHERE (instance_id = $2) AND (type = $3) AND (reached_date IS NULL)",
expectedArgs: []interface{}{
now,
"instance-id",
milestone.ApplicationCreated,
},
},
},
},
},
},
{
name: "reduceOIDCConfigAdded user event",
args: args{
event: getEvent(testEvent(
project.OIDCConfigAddedType,
project.AggregateType,
[]byte(`{}`),
), project.OIDCConfigAddedEventMapper),
},
reduce: (&milestoneProjection{}).reduceOIDCConfigAdded,
want: wantReduce{
aggregateType: eventstore.AggregateType("project"),
sequence: 15,
executer: &testExecuter{},
},
},
{
name: "reduceOIDCConfigAdded system event",
args: args{
event: getEvent(toSystemEvent(testEvent(
project.OIDCConfigAddedType,
project.AggregateType,
[]byte(`{"clientId": "client-id"}`),
)), project.OIDCConfigAddedEventMapper),
},
reduce: (&milestoneProjection{}).reduceOIDCConfigAdded,
want: wantReduce{
aggregateType: eventstore.AggregateType("project"),
sequence: 15,
executer: &testExecuter{
executions: []execution{
{
expectedStmt: "UPDATE projections.milestones SET ignore_client_ids = array_append(ignore_client_ids, $1) WHERE (instance_id = $2) AND (type = $3) AND (reached_date IS NULL)",
expectedArgs: []interface{}{
"client-id",
"instance-id",
milestone.AuthenticationSucceededOnApplication,
},
},
},
},
},
},
{
name: "reduceAPIConfigAdded user event",
args: args{
event: getEvent(testEvent(
project.APIConfigAddedType,
project.AggregateType,
[]byte(`{}`),
), project.APIConfigAddedEventMapper),
},
reduce: (&milestoneProjection{}).reduceAPIConfigAdded,
want: wantReduce{
aggregateType: eventstore.AggregateType("project"),
sequence: 15,
executer: &testExecuter{},
},
},
{
name: "reduceAPIConfigAdded system event",
args: args{
event: getEvent(toSystemEvent(testEvent(
project.APIConfigAddedType,
project.AggregateType,
[]byte(`{"clientId": "client-id"}`),
)), project.APIConfigAddedEventMapper),
},
reduce: (&milestoneProjection{}).reduceAPIConfigAdded,
want: wantReduce{
aggregateType: eventstore.AggregateType("project"),
sequence: 15,
executer: &testExecuter{
executions: []execution{
{
expectedStmt: "UPDATE projections.milestones SET ignore_client_ids = array_append(ignore_client_ids, $1) WHERE (instance_id = $2) AND (type = $3) AND (reached_date IS NULL)",
expectedArgs: []interface{}{
"client-id",
"instance-id",
milestone.AuthenticationSucceededOnApplication,
},
},
},
},
},
},
{
name: "reduceOIDCSessionAdded",
args: args{
event: getEvent(timedTestEvent(
oidcsession.AddedType,
oidcsession.AggregateType,
[]byte(`{"clientID": "client-id"}`),
now,
), eventstore.GenericEventMapper[oidcsession.AddedEvent]),
},
reduce: (&milestoneProjection{}).reduceOIDCSessionAdded,
want: wantReduce{
aggregateType: eventstore.AggregateType("oidc_session"),
sequence: 15,
executer: &testExecuter{
executions: []execution{
{
expectedStmt: "UPDATE projections.milestones SET reached_date = $1 WHERE (instance_id = $2) AND (type = $3) AND (reached_date IS NULL)",
expectedArgs: []interface{}{
now,
"instance-id",
milestone.AuthenticationSucceededOnInstance,
},
},
{
expectedStmt: "UPDATE projections.milestones SET reached_date = $1 WHERE (instance_id = $2) AND (type = $3) AND (NOT (ignore_client_ids @> $4)) AND (reached_date IS NULL)",
expectedArgs: []interface{}{
now,
"instance-id",
milestone.AuthenticationSucceededOnApplication,
database.TextArray[string]{"client-id"},
},
},
},
},
},
},
{
name: "reduceInstanceRemoved",
args: args{
event: getEvent(timedTestEvent(
instance.InstanceRemovedEventType,
instance.AggregateType,
[]byte(`{}`),
now,
), instance.InstanceRemovedEventMapper),
},
reduce: (&milestoneProjection{}).reduceInstanceRemoved,
want: wantReduce{
aggregateType: eventstore.AggregateType("instance"),
sequence: 15,
executer: &testExecuter{
executions: []execution{
{
expectedStmt: "UPDATE projections.milestones SET reached_date = $1 WHERE (instance_id = $2) AND (type = $3) AND (reached_date IS NULL)",
expectedArgs: []interface{}{
now,
"instance-id",
milestone.InstanceDeleted,
}, },
}, },
}, },
@@ -327,18 +89,19 @@ func TestMilestonesProjection_reduces(t *testing.T) {
event: getEvent(timedTestEvent( event: getEvent(timedTestEvent(
milestone.PushedEventType, milestone.PushedEventType,
milestone.AggregateType, milestone.AggregateType,
[]byte(`{"type": "ProjectCreated"}`), []byte(`{"type": "project_created"}`),
now, now,
withVersion(milestone.AggregateVersion),
), milestone.PushedEventMapper), ), milestone.PushedEventMapper),
}, },
reduce: (&milestoneProjection{}).reduceMilestonePushed, reduce: (&milestoneProjection{}).reducePushed,
want: wantReduce{ want: wantReduce{
aggregateType: eventstore.AggregateType("milestone"), aggregateType: eventstore.AggregateType("milestone"),
sequence: 15, sequence: 15,
executer: &testExecuter{ executer: &testExecuter{
executions: []execution{ executions: []execution{
{ {
expectedStmt: "UPDATE projections.milestones SET last_pushed_date = $1 WHERE (instance_id = $2) AND (type = $3)", expectedStmt: "UPDATE projections.milestones2 SET last_pushed_date = $1 WHERE (instance_id = $2) AND (type = $3)",
expectedArgs: []interface{}{ expectedArgs: []interface{}{
now, now,
"instance-id", "instance-id",
@@ -349,23 +112,53 @@ func TestMilestonesProjection_reduces(t *testing.T) {
}, },
}, },
}, },
{
name: "reduceMilestonePushed normal milestone with pushed date",
args: args{
event: getEvent(timedTestEvent(
milestone.PushedEventType,
milestone.AggregateType,
[]byte(`{"type": "project_created", "pushedDate":"2006-01-02T15:04:05Z"}`),
now,
withVersion(milestone.AggregateVersion),
), milestone.PushedEventMapper),
},
reduce: (&milestoneProjection{}).reducePushed,
want: wantReduce{
aggregateType: eventstore.AggregateType("milestone"),
sequence: 15,
executer: &testExecuter{
executions: []execution{
{
expectedStmt: "UPDATE projections.milestones2 SET last_pushed_date = $1 WHERE (instance_id = $2) AND (type = $3)",
expectedArgs: []interface{}{
date,
"instance-id",
milestone.ProjectCreated,
},
},
},
},
},
},
{ {
name: "reduceMilestonePushed instance deleted milestone", name: "reduceMilestonePushed instance deleted milestone",
args: args{ args: args{
event: getEvent(testEvent( event: getEvent(testEvent(
milestone.PushedEventType, milestone.PushedEventType,
milestone.AggregateType, milestone.AggregateType,
[]byte(`{"type": "InstanceDeleted"}`), []byte(`{"type": "instance_deleted"}`),
withVersion(milestone.AggregateVersion),
), milestone.PushedEventMapper), ), milestone.PushedEventMapper),
}, },
reduce: (&milestoneProjection{}).reduceMilestonePushed, reduce: (&milestoneProjection{}).reducePushed,
want: wantReduce{ want: wantReduce{
aggregateType: eventstore.AggregateType("milestone"), aggregateType: eventstore.AggregateType("milestone"),
sequence: 15, sequence: 15,
executer: &testExecuter{ executer: &testExecuter{
executions: []execution{ executions: []execution{
{ {
expectedStmt: "DELETE FROM projections.milestones WHERE (instance_id = $1)", expectedStmt: "DELETE FROM projections.milestones2 WHERE (instance_id = $1)",
expectedArgs: []interface{}{ expectedArgs: []interface{}{
"instance-id", "instance-id",
}, },

View File

@@ -156,7 +156,7 @@ func Create(ctx context.Context, sqlClient *database.DB, es handler.EventStore,
DeviceAuthProjection = newDeviceAuthProjection(ctx, applyCustomConfig(projectionConfig, config.Customizations["device_auth"])) DeviceAuthProjection = newDeviceAuthProjection(ctx, applyCustomConfig(projectionConfig, config.Customizations["device_auth"]))
SessionProjection = newSessionProjection(ctx, applyCustomConfig(projectionConfig, config.Customizations["sessions"])) SessionProjection = newSessionProjection(ctx, applyCustomConfig(projectionConfig, config.Customizations["sessions"]))
AuthRequestProjection = newAuthRequestProjection(ctx, applyCustomConfig(projectionConfig, config.Customizations["auth_requests"])) AuthRequestProjection = newAuthRequestProjection(ctx, applyCustomConfig(projectionConfig, config.Customizations["auth_requests"]))
MilestoneProjection = newMilestoneProjection(ctx, applyCustomConfig(projectionConfig, config.Customizations["milestones"]), systemUsers) MilestoneProjection = newMilestoneProjection(ctx, applyCustomConfig(projectionConfig, config.Customizations["milestones"]))
QuotaProjection = newQuotaProjection(ctx, applyCustomConfig(projectionConfig, config.Customizations["quotas"])) QuotaProjection = newQuotaProjection(ctx, applyCustomConfig(projectionConfig, config.Customizations["quotas"]))
LimitsProjection = newLimitsProjection(ctx, applyCustomConfig(projectionConfig, config.Customizations["limits"])) LimitsProjection = newLimitsProjection(ctx, applyCustomConfig(projectionConfig, config.Customizations["limits"]))
RestrictionsProjection = newRestrictionsProjection(ctx, applyCustomConfig(projectionConfig, config.Customizations["restrictions"])) RestrictionsProjection = newRestrictionsProjection(ctx, applyCustomConfig(projectionConfig, config.Customizations["restrictions"]))

View File

@@ -9,20 +9,23 @@ import (
const ( const (
AggregateType = "milestone" AggregateType = "milestone"
AggregateVersion = "v1" AggregateVersion = "v2"
) )
type Aggregate struct { type Aggregate struct {
eventstore.Aggregate eventstore.Aggregate
} }
func NewAggregate(ctx context.Context, id string) *Aggregate { func NewAggregate(ctx context.Context) *Aggregate {
instanceID := authz.GetInstance(ctx).InstanceID() return NewInstanceAggregate(authz.GetInstance(ctx).InstanceID())
}
func NewInstanceAggregate(instanceID string) *Aggregate {
return &Aggregate{ return &Aggregate{
Aggregate: eventstore.Aggregate{ Aggregate: eventstore.Aggregate{
Type: AggregateType, Type: AggregateType,
Version: AggregateVersion, Version: AggregateVersion,
ID: id, ID: instanceID,
ResourceOwner: instanceID, ResourceOwner: instanceID,
InstanceID: instanceID, InstanceID: instanceID,
}, },

View File

@@ -2,23 +2,88 @@ package milestone
import ( import (
"context" "context"
"time"
"github.com/zitadel/zitadel/internal/eventstore" "github.com/zitadel/zitadel/internal/eventstore"
) )
//go:generate enumer -type Type -json -linecomment -transform=snake
type Type int
const ( const (
eventTypePrefix = eventstore.EventType("milestone.") InstanceCreated Type = iota
PushedEventType = eventTypePrefix + "pushed" AuthenticationSucceededOnInstance
ProjectCreated
ApplicationCreated
AuthenticationSucceededOnApplication
InstanceDeleted
) )
var _ eventstore.Command = (*PushedEvent)(nil) const (
eventTypePrefix = "milestone."
ReachedEventType = eventTypePrefix + "reached"
PushedEventType = eventTypePrefix + "pushed"
)
type ReachedEvent struct {
*eventstore.BaseEvent `json:"-"`
MilestoneType Type `json:"type"`
ReachedDate *time.Time `json:"reachedDate,omitempty"` // Defaults to [eventstore.BaseEvent.Creation] when empty
}
// Payload implements eventstore.Command.
func (e *ReachedEvent) Payload() any {
return e
}
func (e *ReachedEvent) UniqueConstraints() []*eventstore.UniqueConstraint {
return nil
}
func (e *ReachedEvent) SetBaseEvent(b *eventstore.BaseEvent) {
e.BaseEvent = b
}
func (e *ReachedEvent) GetReachedDate() time.Time {
if e.ReachedDate != nil {
return *e.ReachedDate
}
return e.Creation
}
func NewReachedEvent(
ctx context.Context,
aggregate *Aggregate,
typ Type,
) *ReachedEvent {
return NewReachedEventWithDate(ctx, aggregate, typ, nil)
}
// NewReachedEventWithDate creates a [ReachedEvent] with a fixed Reached Date.
func NewReachedEventWithDate(
ctx context.Context,
aggregate *Aggregate,
typ Type,
reachedDate *time.Time,
) *ReachedEvent {
return &ReachedEvent{
BaseEvent: eventstore.NewBaseEventForPush(
ctx,
&aggregate.Aggregate,
ReachedEventType,
),
MilestoneType: typ,
ReachedDate: reachedDate,
}
}
type PushedEvent struct { type PushedEvent struct {
*eventstore.BaseEvent `json:"-"` *eventstore.BaseEvent `json:"-"`
MilestoneType Type `json:"type"` MilestoneType Type `json:"type"`
ExternalDomain string `json:"externalDomain"` ExternalDomain string `json:"externalDomain"`
PrimaryDomain string `json:"primaryDomain"` PrimaryDomain string `json:"primaryDomain"`
Endpoints []string `json:"endpoints"` Endpoints []string `json:"endpoints"`
PushedDate *time.Time `json:"pushedDate,omitempty"` // Defaults to [eventstore.BaseEvent.Creation] when empty
} }
// Payload implements eventstore.Command. // Payload implements eventstore.Command.
@@ -34,14 +99,31 @@ func (p *PushedEvent) SetBaseEvent(b *eventstore.BaseEvent) {
p.BaseEvent = b p.BaseEvent = b
} }
var PushedEventMapper = eventstore.GenericEventMapper[PushedEvent] func (e *PushedEvent) GetPushedDate() time.Time {
if e.PushedDate != nil {
return *e.PushedDate
}
return e.Creation
}
func NewPushedEvent( func NewPushedEvent(
ctx context.Context, ctx context.Context,
aggregate *Aggregate, aggregate *Aggregate,
msType Type, typ Type,
endpoints []string, endpoints []string,
externalDomain, primaryDomain string, externalDomain string,
) *PushedEvent {
return NewPushedEventWithDate(ctx, aggregate, typ, endpoints, externalDomain, nil)
}
// NewPushedEventWithDate creates a [PushedEvent] with a fixed Pushed Date.
func NewPushedEventWithDate(
ctx context.Context,
aggregate *Aggregate,
typ Type,
endpoints []string,
externalDomain string,
pushedDate *time.Time,
) *PushedEvent { ) *PushedEvent {
return &PushedEvent{ return &PushedEvent{
BaseEvent: eventstore.NewBaseEventForPush( BaseEvent: eventstore.NewBaseEventForPush(
@@ -49,9 +131,9 @@ func NewPushedEvent(
&aggregate.Aggregate, &aggregate.Aggregate,
PushedEventType, PushedEventType,
), ),
MilestoneType: msType, MilestoneType: typ,
Endpoints: endpoints, Endpoints: endpoints,
ExternalDomain: externalDomain, ExternalDomain: externalDomain,
PrimaryDomain: primaryDomain, PushedDate: pushedDate,
} }
} }

View File

@@ -4,6 +4,12 @@ import (
"github.com/zitadel/zitadel/internal/eventstore" "github.com/zitadel/zitadel/internal/eventstore"
) )
var (
ReachedEventMapper = eventstore.GenericEventMapper[ReachedEvent]
PushedEventMapper = eventstore.GenericEventMapper[PushedEvent]
)
func init() { func init() {
eventstore.RegisterFilterEventMapper(AggregateType, ReachedEventType, ReachedEventMapper)
eventstore.RegisterFilterEventMapper(AggregateType, PushedEventType, PushedEventMapper) eventstore.RegisterFilterEventMapper(AggregateType, PushedEventType, PushedEventMapper)
} }

View File

@@ -1,59 +0,0 @@
//go:generate stringer -type Type
package milestone
import (
"fmt"
"strings"
)
type Type int
const (
unknown Type = iota
InstanceCreated
AuthenticationSucceededOnInstance
ProjectCreated
ApplicationCreated
AuthenticationSucceededOnApplication
InstanceDeleted
typesCount
)
func AllTypes() []Type {
types := make([]Type, typesCount-1)
for i := Type(1); i < typesCount; i++ {
types[i-1] = i
}
return types
}
func (t *Type) MarshalJSON() ([]byte, error) {
return []byte(fmt.Sprintf(`"%s"`, t.String())), nil
}
func (t *Type) UnmarshalJSON(data []byte) error {
*t = typeFromString(strings.Trim(string(data), `"`))
return nil
}
func typeFromString(t string) Type {
switch t {
case InstanceCreated.String():
return InstanceCreated
case AuthenticationSucceededOnInstance.String():
return AuthenticationSucceededOnInstance
case ProjectCreated.String():
return ProjectCreated
case ApplicationCreated.String():
return ApplicationCreated
case AuthenticationSucceededOnApplication.String():
return AuthenticationSucceededOnApplication
case InstanceDeleted.String():
return InstanceDeleted
default:
return unknown
}
}

View File

@@ -0,0 +1,112 @@
// Code generated by "enumer -type Type -json -linecomment -transform=snake"; DO NOT EDIT.
package milestone
import (
"encoding/json"
"fmt"
"strings"
)
const _TypeName = "instance_createdauthentication_succeeded_on_instanceproject_createdapplication_createdauthentication_succeeded_on_applicationinstance_deleted"
var _TypeIndex = [...]uint8{0, 16, 52, 67, 86, 125, 141}
const _TypeLowerName = "instance_createdauthentication_succeeded_on_instanceproject_createdapplication_createdauthentication_succeeded_on_applicationinstance_deleted"
func (i Type) String() string {
if i < 0 || i >= Type(len(_TypeIndex)-1) {
return fmt.Sprintf("Type(%d)", i)
}
return _TypeName[_TypeIndex[i]:_TypeIndex[i+1]]
}
// An "invalid array index" compiler error signifies that the constant values have changed.
// Re-run the stringer command to generate them again.
func _TypeNoOp() {
var x [1]struct{}
_ = x[InstanceCreated-(0)]
_ = x[AuthenticationSucceededOnInstance-(1)]
_ = x[ProjectCreated-(2)]
_ = x[ApplicationCreated-(3)]
_ = x[AuthenticationSucceededOnApplication-(4)]
_ = x[InstanceDeleted-(5)]
}
var _TypeValues = []Type{InstanceCreated, AuthenticationSucceededOnInstance, ProjectCreated, ApplicationCreated, AuthenticationSucceededOnApplication, InstanceDeleted}
var _TypeNameToValueMap = map[string]Type{
_TypeName[0:16]: InstanceCreated,
_TypeLowerName[0:16]: InstanceCreated,
_TypeName[16:52]: AuthenticationSucceededOnInstance,
_TypeLowerName[16:52]: AuthenticationSucceededOnInstance,
_TypeName[52:67]: ProjectCreated,
_TypeLowerName[52:67]: ProjectCreated,
_TypeName[67:86]: ApplicationCreated,
_TypeLowerName[67:86]: ApplicationCreated,
_TypeName[86:125]: AuthenticationSucceededOnApplication,
_TypeLowerName[86:125]: AuthenticationSucceededOnApplication,
_TypeName[125:141]: InstanceDeleted,
_TypeLowerName[125:141]: InstanceDeleted,
}
var _TypeNames = []string{
_TypeName[0:16],
_TypeName[16:52],
_TypeName[52:67],
_TypeName[67:86],
_TypeName[86:125],
_TypeName[125:141],
}
// TypeString retrieves an enum value from the enum constants string name.
// Throws an error if the param is not part of the enum.
func TypeString(s string) (Type, error) {
if val, ok := _TypeNameToValueMap[s]; ok {
return val, nil
}
if val, ok := _TypeNameToValueMap[strings.ToLower(s)]; ok {
return val, nil
}
return 0, fmt.Errorf("%s does not belong to Type values", s)
}
// TypeValues returns all values of the enum
func TypeValues() []Type {
return _TypeValues
}
// TypeStrings returns a slice of all String values of the enum
func TypeStrings() []string {
strs := make([]string, len(_TypeNames))
copy(strs, _TypeNames)
return strs
}
// IsAType returns "true" if the value is listed in the enum definition. "false" otherwise
func (i Type) IsAType() bool {
for _, v := range _TypeValues {
if i == v {
return true
}
}
return false
}
// MarshalJSON implements the json.Marshaler interface for Type
func (i Type) MarshalJSON() ([]byte, error) {
return json.Marshal(i.String())
}
// UnmarshalJSON implements the json.Unmarshaler interface for Type
func (i *Type) UnmarshalJSON(data []byte) error {
var s string
if err := json.Unmarshal(data, &s); err != nil {
return fmt.Errorf("Type should be a string, got %s", data)
}
var err error
*i, err = TypeString(s)
return err
}

View File

@@ -1,30 +0,0 @@
// Code generated by "stringer -type Type"; DO NOT EDIT.
package milestone
import "strconv"
func _() {
// An "invalid array index" compiler error signifies that the constant values have changed.
// Re-run the stringer command to generate them again.
var x [1]struct{}
_ = x[unknown-0]
_ = x[InstanceCreated-1]
_ = x[AuthenticationSucceededOnInstance-2]
_ = x[ProjectCreated-3]
_ = x[ApplicationCreated-4]
_ = x[AuthenticationSucceededOnApplication-5]
_ = x[InstanceDeleted-6]
_ = x[typesCount-7]
}
const _Type_name = "unknownInstanceCreatedAuthenticationSucceededOnInstanceProjectCreatedApplicationCreatedAuthenticationSucceededOnApplicationInstanceDeletedtypesCount"
var _Type_index = [...]uint8{0, 7, 22, 55, 69, 87, 123, 138, 148}
func (i Type) String() string {
if i < 0 || i >= Type(len(_Type_index)-1) {
return "Type(" + strconv.FormatInt(int64(i), 10) + ")"
}
return _Type_name[_Type_index[i]:_Type_index[i+1]]
}