perf(milestones): refactor (#8788)
Some checks are pending
ZITADEL CI/CD / core (push) Waiting to run
ZITADEL CI/CD / console (push) Waiting to run
ZITADEL CI/CD / version (push) Waiting to run
ZITADEL CI/CD / compile (push) Blocked by required conditions
ZITADEL CI/CD / core-unit-test (push) Blocked by required conditions
ZITADEL CI/CD / core-integration-test (push) Blocked by required conditions
ZITADEL CI/CD / lint (push) Blocked by required conditions
ZITADEL CI/CD / container (push) Blocked by required conditions
ZITADEL CI/CD / e2e (push) Blocked by required conditions
ZITADEL CI/CD / release (push) Blocked by required conditions
Code Scanning / CodeQL-Build (go) (push) Waiting to run
Code Scanning / CodeQL-Build (javascript) (push) Waiting to run

# Which Problems Are Solved

Milestones used existing events from a number of aggregates. OIDC
session is one of them. We noticed in load-tests that the reduction of
the oidc_session.added event into the milestone projection is a costly
business with payload based conditionals. A milestone is reached once,
but even then we remain subscribed to the OIDC events. This requires the
projections.current_states to be updated continuously.


# How the Problems Are Solved

The milestone creation is refactored to use dedicated events instead.
The command side decides when a milestone is reached and creates the
reached event once for each milestone when required.

# Additional Changes

In order to prevent reached milestones being created twice, a migration
script is provided. When the old `projections.milestones` table exist,
the state is read from there and `v2` milestone aggregate events are
created, with the original reached and pushed dates.

# Additional Context

- Closes https://github.com/zitadel/zitadel/issues/8800
This commit is contained in:
Tim Möhlmann 2024-10-28 09:29:34 +01:00 committed by GitHub
parent 54f1c0bc50
commit 32bad3feb3
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
46 changed files with 1612 additions and 756 deletions

View File

@ -163,6 +163,7 @@ func projections(
}
commands, err := command.StartCommands(
es,
config.Caches,
config.SystemDefaults,
config.InternalAuthZ.RolePermissionMappings,
staticStorage,

View File

@ -65,6 +65,7 @@ func (mig *FirstInstance) Execute(ctx context.Context, _ eventstore.Event) error
}
cmd, err := command.StartCommands(mig.es,
nil,
mig.defaults,
mig.zitadelRoles,
nil,

118
cmd/setup/36.go Normal file
View File

@ -0,0 +1,118 @@
package setup
import (
"context"
_ "embed"
"errors"
"fmt"
"slices"
"time"
"github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgconn"
"github.com/zitadel/logging"
"github.com/zitadel/zitadel/internal/command"
"github.com/zitadel/zitadel/internal/database"
"github.com/zitadel/zitadel/internal/eventstore"
"github.com/zitadel/zitadel/internal/repository/milestone"
)
var (
//go:embed 36.sql
getProjectedMilestones string
)
type FillV2Milestones struct {
dbClient *database.DB
eventstore *eventstore.Eventstore
}
type instanceMilestone struct {
Type milestone.Type
Reached time.Time
Pushed *time.Time
}
func (mig *FillV2Milestones) Execute(ctx context.Context, _ eventstore.Event) error {
im, err := mig.getProjectedMilestones(ctx)
if err != nil {
return err
}
return mig.pushEventsByInstance(ctx, im)
}
func (mig *FillV2Milestones) getProjectedMilestones(ctx context.Context) (map[string][]instanceMilestone, error) {
type row struct {
InstanceID string
Type milestone.Type
Reached time.Time
Pushed *time.Time
}
rows, _ := mig.dbClient.Pool.Query(ctx, getProjectedMilestones)
scanned, err := pgx.CollectRows(rows, pgx.RowToStructByPos[row])
var pgError *pgconn.PgError
// catch ERROR: relation "projections.milestones" does not exist
if errors.As(err, &pgError) && pgError.SQLState() == "42P01" {
return nil, nil
}
if err != nil {
return nil, fmt.Errorf("milestones get: %w", err)
}
milestoneMap := make(map[string][]instanceMilestone)
for _, s := range scanned {
milestoneMap[s.InstanceID] = append(milestoneMap[s.InstanceID], instanceMilestone{
Type: s.Type,
Reached: s.Reached,
Pushed: s.Pushed,
})
}
return milestoneMap, nil
}
// pushEventsByInstance creates the v2 milestone events by instance.
// This prevents we will try to push 6*N(instance) events in one push.
func (mig *FillV2Milestones) pushEventsByInstance(ctx context.Context, milestoneMap map[string][]instanceMilestone) error {
// keep a deterministic order by instance ID.
order := make([]string, 0, len(milestoneMap))
for k := range milestoneMap {
order = append(order, k)
}
slices.Sort(order)
for _, instanceID := range order {
logging.WithFields("instance_id", instanceID, "migration", mig.String()).Info("filter existing milestone events")
// because each Push runs in a separate TX, we need to make sure that events
// from a partially executed migration are pushed again.
model := command.NewMilestonesReachedWriteModel(instanceID)
if err := mig.eventstore.FilterToQueryReducer(ctx, model); err != nil {
return fmt.Errorf("milestones filter: %w", err)
}
if model.InstanceCreated {
logging.WithFields("instance_id", instanceID, "migration", mig.String()).Info("milestone events already migrated")
continue // This instance was migrated, skip
}
logging.WithFields("instance_id", instanceID, "migration", mig.String()).Info("push milestone events")
aggregate := milestone.NewInstanceAggregate(instanceID)
cmds := make([]eventstore.Command, 0, len(milestoneMap[instanceID])*2)
for _, m := range milestoneMap[instanceID] {
cmds = append(cmds, milestone.NewReachedEventWithDate(ctx, aggregate, m.Type, &m.Reached))
if m.Pushed != nil {
cmds = append(cmds, milestone.NewPushedEventWithDate(ctx, aggregate, m.Type, nil, "", m.Pushed))
}
}
if _, err := mig.eventstore.Push(ctx, cmds...); err != nil {
return fmt.Errorf("milestones push: %w", err)
}
}
return nil
}
func (mig *FillV2Milestones) String() string {
return "36_fill_v2_milestones"
}

4
cmd/setup/36.sql Normal file
View File

@ -0,0 +1,4 @@
SELECT instance_id, type, reached_date, last_pushed_date
FROM projections.milestones
WHERE reached_date IS NOT NULL
ORDER BY instance_id, reached_date;

View File

@ -122,6 +122,7 @@ type Steps struct {
s33SMSConfigs3TwilioAddVerifyServiceSid *SMSConfigs3TwilioAddVerifyServiceSid
s34AddCacheSchema *AddCacheSchema
s35AddPositionToIndexEsWm *AddPositionToIndexEsWm
s36FillV2Milestones *FillV2Milestones
}
func MustNewSteps(v *viper.Viper) *Steps {

View File

@ -33,6 +33,7 @@ func (mig *externalConfigChange) Check(lastRun map[string]interface{}) bool {
func (mig *externalConfigChange) Execute(ctx context.Context, _ eventstore.Event) error {
cmd, err := command.StartCommands(
mig.es,
nil,
mig.defaults,
nil,
nil,

View File

@ -165,6 +165,7 @@ func Setup(ctx context.Context, config *Config, steps *Steps, masterKey string)
steps.s33SMSConfigs3TwilioAddVerifyServiceSid = &SMSConfigs3TwilioAddVerifyServiceSid{dbClient: esPusherDBClient}
steps.s34AddCacheSchema = &AddCacheSchema{dbClient: queryDBClient}
steps.s35AddPositionToIndexEsWm = &AddPositionToIndexEsWm{dbClient: esPusherDBClient}
steps.s36FillV2Milestones = &FillV2Milestones{dbClient: queryDBClient, eventstore: eventstoreClient}
err = projection.Create(ctx, projectionDBClient, eventstoreClient, config.Projections, nil, nil, nil)
logging.OnError(err).Fatal("unable to start projections")
@ -209,6 +210,7 @@ func Setup(ctx context.Context, config *Config, steps *Steps, masterKey string)
steps.s30FillFieldsForOrgDomainVerified,
steps.s34AddCacheSchema,
steps.s35AddPositionToIndexEsWm,
steps.s36FillV2Milestones,
} {
mustExecuteMigration(ctx, eventstoreClient, step, "migration failed")
}
@ -390,6 +392,7 @@ func initProjections(
}
commands, err := command.StartCommands(
eventstoreClient,
config.Caches,
config.SystemDefaults,
config.InternalAuthZ.RolePermissionMappings,
staticStorage,

View File

@ -224,6 +224,7 @@ func startZitadel(ctx context.Context, config *Config, masterKey string, server
}
commands, err := command.StartCommands(
eventstoreClient,
config.Caches,
config.SystemDefaults,
config.InternalAuthZ.RolePermissionMappings,
storage,

View File

@ -114,7 +114,15 @@ func WithConsole(ctx context.Context, projectID, appID string) context.Context {
i.projectID = projectID
i.appID = appID
//i.clientID = clientID
return context.WithValue(ctx, instanceKey, i)
}
func WithConsoleClientID(ctx context.Context, clientID string) context.Context {
i, ok := ctx.Value(instanceKey).(*instance)
if !ok {
i = new(instance)
}
i.clientID = clientID
return context.WithValue(ctx, instanceKey, i)
}

View File

@ -6,6 +6,7 @@ import (
"time"
"github.com/zitadel/logging"
"github.com/zitadel/zitadel/internal/database/postgres"
)
@ -77,7 +78,8 @@ type CachesConfig struct {
Postgres PostgresConnectorConfig
// Redis redis.Config?
}
Instance *CacheConfig
Instance *CacheConfig
Milestones *CacheConfig
}
type CacheConfig struct {

82
internal/command/cache.go Normal file
View File

@ -0,0 +1,82 @@
package command
import (
"context"
"fmt"
"strings"
"github.com/zitadel/zitadel/internal/cache"
"github.com/zitadel/zitadel/internal/cache/gomap"
"github.com/zitadel/zitadel/internal/cache/noop"
"github.com/zitadel/zitadel/internal/cache/pg"
"github.com/zitadel/zitadel/internal/database"
)
type Caches struct {
connectors *cacheConnectors
milestones cache.Cache[milestoneIndex, string, *MilestonesReached]
}
func startCaches(background context.Context, conf *cache.CachesConfig, client *database.DB) (_ *Caches, err error) {
caches := &Caches{
milestones: noop.NewCache[milestoneIndex, string, *MilestonesReached](),
}
if conf == nil {
return caches, nil
}
caches.connectors, err = startCacheConnectors(background, conf, client)
if err != nil {
return nil, err
}
caches.milestones, err = startCache[milestoneIndex, string, *MilestonesReached](background, []milestoneIndex{milestoneIndexInstanceID}, "milestones", conf.Instance, caches.connectors)
if err != nil {
return nil, err
}
return caches, nil
}
type cacheConnectors struct {
memory *cache.AutoPruneConfig
postgres *pgxPoolCacheConnector
}
type pgxPoolCacheConnector struct {
*cache.AutoPruneConfig
client *database.DB
}
func startCacheConnectors(_ context.Context, conf *cache.CachesConfig, client *database.DB) (_ *cacheConnectors, err error) {
connectors := new(cacheConnectors)
if conf.Connectors.Memory.Enabled {
connectors.memory = &conf.Connectors.Memory.AutoPrune
}
if conf.Connectors.Postgres.Enabled {
connectors.postgres = &pgxPoolCacheConnector{
AutoPruneConfig: &conf.Connectors.Postgres.AutoPrune,
client: client,
}
}
return connectors, nil
}
func startCache[I ~int, K ~string, V cache.Entry[I, K]](background context.Context, indices []I, name string, conf *cache.CacheConfig, connectors *cacheConnectors) (cache.Cache[I, K, V], error) {
if conf == nil || conf.Connector == "" {
return noop.NewCache[I, K, V](), nil
}
if strings.EqualFold(conf.Connector, "memory") && connectors.memory != nil {
c := gomap.NewCache[I, K, V](background, indices, *conf)
connectors.memory.StartAutoPrune(background, c, name)
return c, nil
}
if strings.EqualFold(conf.Connector, "postgres") && connectors.postgres != nil {
client := connectors.postgres.client
c, err := pg.NewCache[I, K, V](background, name, *conf, indices, client.Pool, client.Type())
if err != nil {
return nil, fmt.Errorf("query start cache: %w", err)
}
connectors.postgres.StartAutoPrune(background, c, name)
return c, nil
}
return nil, fmt.Errorf("cache connector %q not enabled", conf.Connector)
}

View File

@ -18,6 +18,7 @@ import (
"github.com/zitadel/zitadel/internal/api/authz"
api_http "github.com/zitadel/zitadel/internal/api/http"
"github.com/zitadel/zitadel/internal/cache"
"github.com/zitadel/zitadel/internal/command/preparation"
sd "github.com/zitadel/zitadel/internal/config/systemdefaults"
"github.com/zitadel/zitadel/internal/crypto"
@ -88,10 +89,17 @@ type Commands struct {
EventGroupExisting func(group string) bool
GenerateDomain func(instanceName, domain string) (string, error)
caches *Caches
// Store instance IDs where all milestones are reached (except InstanceDeleted).
// These instance's milestones never need to be invalidated,
// so the query and cache overhead can completely eliminated.
milestonesCompleted sync.Map
}
func StartCommands(
es *eventstore.Eventstore,
cachesConfig *cache.CachesConfig,
defaults sd.SystemDefaults,
zitadelRoles []authz.RoleMapping,
staticStore static.Storage,
@ -123,6 +131,10 @@ func StartCommands(
if err != nil {
return nil, fmt.Errorf("password hasher: %w", err)
}
caches, err := startCaches(context.TODO(), cachesConfig, es.Client())
if err != nil {
return nil, fmt.Errorf("caches: %w", err)
}
repo = &Commands{
eventstore: es,
static: staticStore,
@ -176,6 +188,7 @@ func StartCommands(
},
},
GenerateDomain: domain.NewGeneratedInstanceDomain,
caches: caches,
}
if defaultSecretGenerators != nil && defaultSecretGenerators.ClientSecret != nil {

View File

@ -6,6 +6,7 @@ import (
"golang.org/x/text/language"
"github.com/zitadel/logging"
"github.com/zitadel/zitadel/internal/api/authz"
"github.com/zitadel/zitadel/internal/api/http"
"github.com/zitadel/zitadel/internal/command/preparation"
@ -17,6 +18,7 @@ import (
"github.com/zitadel/zitadel/internal/notification/channels/smtp"
"github.com/zitadel/zitadel/internal/repository/instance"
"github.com/zitadel/zitadel/internal/repository/limits"
"github.com/zitadel/zitadel/internal/repository/milestone"
"github.com/zitadel/zitadel/internal/repository/org"
"github.com/zitadel/zitadel/internal/repository/project"
"github.com/zitadel/zitadel/internal/repository/quota"
@ -292,7 +294,7 @@ func setUpInstance(ctx context.Context, c *Commands, setup *InstanceSetup) (vali
setupFeatures(&validations, setup.Features, setup.zitadel.instanceID)
setupLimits(c, &validations, limits.NewAggregate(setup.zitadel.limitsID, setup.zitadel.instanceID), setup.Limits)
setupRestrictions(c, &validations, restrictions.NewAggregate(setup.zitadel.restrictionsID, setup.zitadel.instanceID, setup.zitadel.instanceID), setup.Restrictions)
setupInstanceCreatedMilestone(&validations, setup.zitadel.instanceID)
return validations, pat, machineKey, nil
}
@ -890,7 +892,8 @@ func (c *Commands) RemoveInstance(ctx context.Context, id string) (*domain.Objec
if err != nil {
return nil, err
}
err = c.caches.milestones.Invalidate(ctx, milestoneIndexInstanceID, id)
logging.OnError(err).Error("milestone invalidate")
return &domain.ObjectDetails{
Sequence: events[len(events)-1].Sequence(),
EventDate: events[len(events)-1].CreatedAt(),
@ -908,10 +911,16 @@ func (c *Commands) prepareRemoveInstance(a *instance.Aggregate) preparation.Vali
if !writeModel.State.Exists() {
return nil, zerrors.ThrowNotFound(err, "COMMA-AE3GS", "Errors.Instance.NotFound")
}
return []eventstore.Command{instance.NewInstanceRemovedEvent(ctx,
&a.Aggregate,
writeModel.Name,
writeModel.Domains)},
milestoneAggregate := milestone.NewInstanceAggregate(a.ID)
return []eventstore.Command{
instance.NewInstanceRemovedEvent(ctx,
&a.Aggregate,
writeModel.Name,
writeModel.Domains),
milestone.NewReachedEvent(ctx,
milestoneAggregate,
milestone.InstanceDeleted),
},
nil
}, nil
}

View File

@ -13,6 +13,7 @@ import (
"golang.org/x/text/language"
"github.com/zitadel/zitadel/internal/api/authz"
"github.com/zitadel/zitadel/internal/cache/noop"
"github.com/zitadel/zitadel/internal/command/preparation"
"github.com/zitadel/zitadel/internal/crypto"
"github.com/zitadel/zitadel/internal/domain"
@ -20,6 +21,7 @@ import (
"github.com/zitadel/zitadel/internal/id"
id_mock "github.com/zitadel/zitadel/internal/id/mock"
"github.com/zitadel/zitadel/internal/repository/instance"
"github.com/zitadel/zitadel/internal/repository/milestone"
"github.com/zitadel/zitadel/internal/repository/org"
"github.com/zitadel/zitadel/internal/repository/project"
"github.com/zitadel/zitadel/internal/repository/user"
@ -372,6 +374,7 @@ func setupInstanceEvents(ctx context.Context, instanceID, orgID, projectID, appI
setupInstanceElementsEvents(ctx, instanceID, instanceName, defaultLanguage),
orgEvents(ctx, instanceID, orgID, orgName, projectID, domain, externalSecure, true, true),
generatedDomainEvents(ctx, instanceID, orgID, projectID, appID, domain),
instanceCreatedMilestoneEvent(ctx, instanceID),
)
}
@ -401,6 +404,12 @@ func generatedDomainEvents(ctx context.Context, instanceID, orgID, projectID, ap
}
}
func instanceCreatedMilestoneEvent(ctx context.Context, instanceID string) []eventstore.Command {
return []eventstore.Command{
milestone.NewReachedEvent(ctx, milestone.NewInstanceAggregate(instanceID), milestone.InstanceCreated),
}
}
func generatedDomainFilters(instanceID, orgID, projectID, appID, generatedDomain string) []expect {
return []expect{
expectFilter(),
@ -1378,7 +1387,7 @@ func TestCommandSide_UpdateInstance(t *testing.T) {
func TestCommandSide_RemoveInstance(t *testing.T) {
type fields struct {
eventstore *eventstore.Eventstore
eventstore func(t *testing.T) *eventstore.Eventstore
}
type args struct {
ctx context.Context
@ -1397,8 +1406,7 @@ func TestCommandSide_RemoveInstance(t *testing.T) {
{
name: "instance not existing, not found error",
fields: fields{
eventstore: eventstoreExpect(
t,
eventstore: expectEventstore(
expectFilter(),
),
},
@ -1413,8 +1421,7 @@ func TestCommandSide_RemoveInstance(t *testing.T) {
{
name: "instance removed, not found error",
fields: fields{
eventstore: eventstoreExpect(
t,
eventstore: expectEventstore(
expectFilter(
eventFromEventPusher(
instance.NewInstanceAddedEvent(
@ -1444,8 +1451,7 @@ func TestCommandSide_RemoveInstance(t *testing.T) {
{
name: "instance remove, ok",
fields: fields{
eventstore: eventstoreExpect(
t,
eventstore: expectEventstore(
expectFilter(
eventFromEventPusherWithInstanceID(
"INSTANCE",
@ -1480,6 +1486,10 @@ func TestCommandSide_RemoveInstance(t *testing.T) {
"custom.domain",
},
),
milestone.NewReachedEvent(context.Background(),
milestone.NewInstanceAggregate("INSTANCE"),
milestone.InstanceDeleted,
),
),
),
},
@ -1497,7 +1507,10 @@ func TestCommandSide_RemoveInstance(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
r := &Commands{
eventstore: tt.fields.eventstore,
eventstore: tt.fields.eventstore(t),
caches: &Caches{
milestones: noop.NewCache[milestoneIndex, string, *MilestonesReached](),
},
}
got, err := r.RemoveInstance(tt.args.ctx, tt.args.instanceID)
if tt.res.err == nil {

View File

@ -3,20 +3,176 @@ package command
import (
"context"
"github.com/zitadel/logging"
"github.com/zitadel/zitadel/internal/api/authz"
"github.com/zitadel/zitadel/internal/command/preparation"
"github.com/zitadel/zitadel/internal/eventstore"
"github.com/zitadel/zitadel/internal/repository/milestone"
)
// MilestonePushed writes a new milestone.PushedEvent with a new milestone.Aggregate to the eventstore
type milestoneIndex int
const (
milestoneIndexInstanceID milestoneIndex = iota
)
type MilestonesReached struct {
InstanceID string
InstanceCreated bool
AuthenticationSucceededOnInstance bool
ProjectCreated bool
ApplicationCreated bool
AuthenticationSucceededOnApplication bool
InstanceDeleted bool
}
// complete returns true if all milestones except InstanceDeleted are reached.
func (m *MilestonesReached) complete() bool {
return m.InstanceCreated &&
m.AuthenticationSucceededOnInstance &&
m.ProjectCreated &&
m.ApplicationCreated &&
m.AuthenticationSucceededOnApplication
}
// GetMilestonesReached finds the milestone state for the current instance.
func (c *Commands) GetMilestonesReached(ctx context.Context) (*MilestonesReached, error) {
milestones, ok := c.getCachedMilestonesReached(ctx)
if ok {
return milestones, nil
}
model := NewMilestonesReachedWriteModel(authz.GetInstance(ctx).InstanceID())
if err := c.eventstore.FilterToQueryReducer(ctx, model); err != nil {
return nil, err
}
milestones = &model.MilestonesReached
c.setCachedMilestonesReached(ctx, milestones)
return milestones, nil
}
// getCachedMilestonesReached checks for milestone completeness on an instance and returns a filled
// [MilestonesReached] object.
// Otherwise it looks for the object in the milestone cache.
func (c *Commands) getCachedMilestonesReached(ctx context.Context) (*MilestonesReached, bool) {
instanceID := authz.GetInstance(ctx).InstanceID()
if _, ok := c.milestonesCompleted.Load(instanceID); ok {
return &MilestonesReached{
InstanceID: instanceID,
InstanceCreated: true,
AuthenticationSucceededOnInstance: true,
ProjectCreated: true,
ApplicationCreated: true,
AuthenticationSucceededOnApplication: true,
InstanceDeleted: false,
}, ok
}
return c.caches.milestones.Get(ctx, milestoneIndexInstanceID, instanceID)
}
// setCachedMilestonesReached stores the current milestones state in the milestones cache.
// If the milestones are complete, the instance ID is stored in milestonesCompleted instead.
func (c *Commands) setCachedMilestonesReached(ctx context.Context, milestones *MilestonesReached) {
if milestones.complete() {
c.milestonesCompleted.Store(milestones.InstanceID, struct{}{})
return
}
c.caches.milestones.Set(ctx, milestones)
}
// Keys implements cache.Entry
func (c *MilestonesReached) Keys(i milestoneIndex) []string {
if i == milestoneIndexInstanceID {
return []string{c.InstanceID}
}
return nil
}
// MilestonePushed writes a new milestone.PushedEvent with the milestone.Aggregate to the eventstore
func (c *Commands) MilestonePushed(
ctx context.Context,
instanceID string,
msType milestone.Type,
endpoints []string,
primaryDomain string,
) error {
id, err := c.idGenerator.Next()
if err != nil {
return err
}
_, err = c.eventstore.Push(ctx, milestone.NewPushedEvent(ctx, milestone.NewAggregate(ctx, id), msType, endpoints, primaryDomain, c.externalDomain))
_, err := c.eventstore.Push(ctx, milestone.NewPushedEvent(ctx, milestone.NewInstanceAggregate(instanceID), msType, endpoints, c.externalDomain))
return err
}
func setupInstanceCreatedMilestone(validations *[]preparation.Validation, instanceID string) {
*validations = append(*validations, func() (preparation.CreateCommands, error) {
return func(ctx context.Context, _ preparation.FilterToQueryReducer) ([]eventstore.Command, error) {
return []eventstore.Command{
milestone.NewReachedEvent(ctx, milestone.NewInstanceAggregate(instanceID), milestone.InstanceCreated),
}, nil
}, nil
})
}
func (s *OIDCSessionEvents) SetMilestones(ctx context.Context, clientID string, isHuman bool) (postCommit func(ctx context.Context), err error) {
postCommit = func(ctx context.Context) {}
milestones, err := s.commands.GetMilestonesReached(ctx)
if err != nil {
return postCommit, err
}
instance := authz.GetInstance(ctx)
aggregate := milestone.NewAggregate(ctx)
var invalidate bool
if !milestones.AuthenticationSucceededOnInstance {
s.events = append(s.events, milestone.NewReachedEvent(ctx, aggregate, milestone.AuthenticationSucceededOnInstance))
invalidate = true
}
if !milestones.AuthenticationSucceededOnApplication && isHuman && clientID != instance.ConsoleClientID() {
s.events = append(s.events, milestone.NewReachedEvent(ctx, aggregate, milestone.AuthenticationSucceededOnApplication))
invalidate = true
}
if invalidate {
postCommit = s.commands.invalidateMilestoneCachePostCommit(instance.InstanceID())
}
return postCommit, nil
}
func (c *Commands) projectCreatedMilestone(ctx context.Context, cmds *[]eventstore.Command) (postCommit func(ctx context.Context), err error) {
postCommit = func(ctx context.Context) {}
if isSystemUser(ctx) {
return postCommit, nil
}
milestones, err := c.GetMilestonesReached(ctx)
if err != nil {
return postCommit, err
}
if milestones.ProjectCreated {
return postCommit, nil
}
aggregate := milestone.NewAggregate(ctx)
*cmds = append(*cmds, milestone.NewReachedEvent(ctx, aggregate, milestone.ProjectCreated))
return c.invalidateMilestoneCachePostCommit(aggregate.InstanceID), nil
}
func (c *Commands) applicationCreatedMilestone(ctx context.Context, cmds *[]eventstore.Command) (postCommit func(ctx context.Context), err error) {
postCommit = func(ctx context.Context) {}
if isSystemUser(ctx) {
return postCommit, nil
}
milestones, err := c.GetMilestonesReached(ctx)
if err != nil {
return postCommit, err
}
if milestones.ApplicationCreated {
return postCommit, nil
}
aggregate := milestone.NewAggregate(ctx)
*cmds = append(*cmds, milestone.NewReachedEvent(ctx, aggregate, milestone.ApplicationCreated))
return c.invalidateMilestoneCachePostCommit(aggregate.InstanceID), nil
}
func (c *Commands) invalidateMilestoneCachePostCommit(instanceID string) func(ctx context.Context) {
return func(ctx context.Context) {
err := c.caches.milestones.Invalidate(ctx, milestoneIndexInstanceID, instanceID)
logging.WithFields("instance_id", instanceID).OnError(err).Error("failed to invalidate milestone cache")
}
}
func isSystemUser(ctx context.Context) bool {
return authz.GetCtxData(ctx).SystemMemberships != nil
}

View File

@ -0,0 +1,58 @@
package command
import (
"github.com/zitadel/zitadel/internal/eventstore"
"github.com/zitadel/zitadel/internal/repository/milestone"
)
type MilestonesReachedWriteModel struct {
eventstore.WriteModel
MilestonesReached
}
func NewMilestonesReachedWriteModel(instanceID string) *MilestonesReachedWriteModel {
return &MilestonesReachedWriteModel{
WriteModel: eventstore.WriteModel{
AggregateID: instanceID,
InstanceID: instanceID,
},
MilestonesReached: MilestonesReached{
InstanceID: instanceID,
},
}
}
func (m *MilestonesReachedWriteModel) Query() *eventstore.SearchQueryBuilder {
return eventstore.NewSearchQueryBuilder(eventstore.ColumnsEvent).
AddQuery().
AggregateTypes(milestone.AggregateType).
AggregateIDs(m.AggregateID).
EventTypes(milestone.ReachedEventType, milestone.PushedEventType).
Builder()
}
func (m *MilestonesReachedWriteModel) Reduce() error {
for _, event := range m.Events {
if e, ok := event.(*milestone.ReachedEvent); ok {
m.reduceReachedEvent(e)
}
}
return m.WriteModel.Reduce()
}
func (m *MilestonesReachedWriteModel) reduceReachedEvent(e *milestone.ReachedEvent) {
switch e.MilestoneType {
case milestone.InstanceCreated:
m.InstanceCreated = true
case milestone.AuthenticationSucceededOnInstance:
m.AuthenticationSucceededOnInstance = true
case milestone.ProjectCreated:
m.ProjectCreated = true
case milestone.ApplicationCreated:
m.ApplicationCreated = true
case milestone.AuthenticationSucceededOnApplication:
m.AuthenticationSucceededOnApplication = true
case milestone.InstanceDeleted:
m.InstanceDeleted = true
}
}

View File

@ -0,0 +1,629 @@
package command
import (
"context"
"io"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/zitadel/zitadel/internal/api/authz"
"github.com/zitadel/zitadel/internal/cache"
"github.com/zitadel/zitadel/internal/cache/gomap"
"github.com/zitadel/zitadel/internal/cache/noop"
"github.com/zitadel/zitadel/internal/eventstore"
"github.com/zitadel/zitadel/internal/repository/milestone"
)
func TestCommands_GetMilestonesReached(t *testing.T) {
cached := &MilestonesReached{
InstanceID: "cached-id",
InstanceCreated: true,
AuthenticationSucceededOnInstance: true,
}
ctx := authz.WithInstanceID(context.Background(), "instanceID")
aggregate := milestone.NewAggregate(ctx)
type fields struct {
eventstore func(*testing.T) *eventstore.Eventstore
}
type args struct {
ctx context.Context
}
tests := []struct {
name string
fields fields
args args
want *MilestonesReached
wantErr error
}{
{
name: "cached",
fields: fields{
eventstore: expectEventstore(),
},
args: args{
ctx: authz.WithInstanceID(context.Background(), "cached-id"),
},
want: cached,
},
{
name: "filter error",
fields: fields{
eventstore: expectEventstore(
expectFilterError(io.ErrClosedPipe),
),
},
args: args{
ctx: ctx,
},
wantErr: io.ErrClosedPipe,
},
{
name: "no events, all false",
fields: fields{
eventstore: expectEventstore(
expectFilter(),
),
},
args: args{
ctx: ctx,
},
want: &MilestonesReached{
InstanceID: "instanceID",
},
},
{
name: "instance created",
fields: fields{
eventstore: expectEventstore(
expectFilter(
eventFromEventPusher(milestone.NewReachedEvent(ctx, aggregate, milestone.InstanceCreated)),
),
),
},
args: args{
ctx: ctx,
},
want: &MilestonesReached{
InstanceID: "instanceID",
InstanceCreated: true,
},
},
{
name: "instance auth",
fields: fields{
eventstore: expectEventstore(
expectFilter(
eventFromEventPusher(milestone.NewReachedEvent(ctx, aggregate, milestone.AuthenticationSucceededOnInstance)),
),
),
},
args: args{
ctx: ctx,
},
want: &MilestonesReached{
InstanceID: "instanceID",
AuthenticationSucceededOnInstance: true,
},
},
{
name: "project created",
fields: fields{
eventstore: expectEventstore(
expectFilter(
eventFromEventPusher(milestone.NewReachedEvent(ctx, aggregate, milestone.ProjectCreated)),
),
),
},
args: args{
ctx: ctx,
},
want: &MilestonesReached{
InstanceID: "instanceID",
ProjectCreated: true,
},
},
{
name: "app created",
fields: fields{
eventstore: expectEventstore(
expectFilter(
eventFromEventPusher(milestone.NewReachedEvent(ctx, aggregate, milestone.ApplicationCreated)),
),
),
},
args: args{
ctx: ctx,
},
want: &MilestonesReached{
InstanceID: "instanceID",
ApplicationCreated: true,
},
},
{
name: "app auth",
fields: fields{
eventstore: expectEventstore(
expectFilter(
eventFromEventPusher(milestone.NewReachedEvent(ctx, aggregate, milestone.AuthenticationSucceededOnApplication)),
),
),
},
args: args{
ctx: ctx,
},
want: &MilestonesReached{
InstanceID: "instanceID",
AuthenticationSucceededOnApplication: true,
},
},
{
name: "instance deleted",
fields: fields{
eventstore: expectEventstore(
expectFilter(
eventFromEventPusher(milestone.NewReachedEvent(ctx, aggregate, milestone.InstanceDeleted)),
),
),
},
args: args{
ctx: ctx,
},
want: &MilestonesReached{
InstanceID: "instanceID",
InstanceDeleted: true,
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
cache := gomap.NewCache[milestoneIndex, string, *MilestonesReached](
context.Background(),
[]milestoneIndex{milestoneIndexInstanceID},
cache.CacheConfig{Connector: "memory"},
)
cache.Set(context.Background(), cached)
c := &Commands{
eventstore: tt.fields.eventstore(t),
caches: &Caches{
milestones: cache,
},
}
got, err := c.GetMilestonesReached(tt.args.ctx)
require.ErrorIs(t, err, tt.wantErr)
assert.Equal(t, tt.want, got)
})
}
}
func TestCommands_milestonesCompleted(t *testing.T) {
c := &Commands{
caches: &Caches{
milestones: noop.NewCache[milestoneIndex, string, *MilestonesReached](),
},
}
ctx := authz.WithInstanceID(context.Background(), "instanceID")
arg := &MilestonesReached{
InstanceID: "instanceID",
InstanceCreated: true,
AuthenticationSucceededOnInstance: true,
ProjectCreated: true,
ApplicationCreated: true,
AuthenticationSucceededOnApplication: true,
InstanceDeleted: false,
}
c.setCachedMilestonesReached(ctx, arg)
got, ok := c.getCachedMilestonesReached(ctx)
assert.True(t, ok)
assert.Equal(t, arg, got)
}
func TestCommands_MilestonePushed(t *testing.T) {
aggregate := milestone.NewInstanceAggregate("instanceID")
type fields struct {
eventstore func(*testing.T) *eventstore.Eventstore
}
type args struct {
ctx context.Context
instanceID string
msType milestone.Type
endpoints []string
}
tests := []struct {
name string
fields fields
args args
wantErr error
}{
{
name: "milestone pushed",
fields: fields{
eventstore: expectEventstore(
expectPush(
milestone.NewPushedEvent(
context.Background(),
aggregate,
milestone.ApplicationCreated,
[]string{"foo.com", "bar.com"},
"example.com",
),
),
),
},
args: args{
ctx: context.Background(),
instanceID: "instanceID",
msType: milestone.ApplicationCreated,
endpoints: []string{"foo.com", "bar.com"},
},
wantErr: nil,
},
{
name: "pusher error",
fields: fields{
eventstore: expectEventstore(
expectPushFailed(
io.ErrClosedPipe,
milestone.NewPushedEvent(
context.Background(),
aggregate,
milestone.ApplicationCreated,
[]string{"foo.com", "bar.com"},
"example.com",
),
),
),
},
args: args{
ctx: context.Background(),
instanceID: "instanceID",
msType: milestone.ApplicationCreated,
endpoints: []string{"foo.com", "bar.com"},
},
wantErr: io.ErrClosedPipe,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
c := &Commands{
eventstore: tt.fields.eventstore(t),
externalDomain: "example.com",
}
err := c.MilestonePushed(tt.args.ctx, tt.args.instanceID, tt.args.msType, tt.args.endpoints)
assert.ErrorIs(t, err, tt.wantErr)
})
}
}
func TestOIDCSessionEvents_SetMilestones(t *testing.T) {
ctx := authz.WithInstanceID(context.Background(), "instanceID")
ctx = authz.WithConsoleClientID(ctx, "console")
aggregate := milestone.NewAggregate(ctx)
type fields struct {
eventstore func(*testing.T) *eventstore.Eventstore
}
type args struct {
ctx context.Context
clientID string
isHuman bool
}
tests := []struct {
name string
fields fields
args args
wantEvents []eventstore.Command
wantErr error
}{
{
name: "get error",
fields: fields{
eventstore: expectEventstore(
expectFilterError(io.ErrClosedPipe),
),
},
args: args{
ctx: ctx,
clientID: "client",
isHuman: true,
},
wantErr: io.ErrClosedPipe,
},
{
name: "milestones already reached",
fields: fields{
eventstore: expectEventstore(
expectFilter(
eventFromEventPusher(milestone.NewReachedEvent(ctx, aggregate, milestone.AuthenticationSucceededOnInstance)),
eventFromEventPusher(milestone.NewReachedEvent(ctx, aggregate, milestone.AuthenticationSucceededOnApplication)),
),
),
},
args: args{
ctx: ctx,
clientID: "client",
isHuman: true,
},
wantErr: nil,
},
{
name: "auth on instance",
fields: fields{
eventstore: expectEventstore(
expectFilter(),
),
},
args: args{
ctx: ctx,
clientID: "console",
isHuman: true,
},
wantEvents: []eventstore.Command{
milestone.NewReachedEvent(ctx, aggregate, milestone.AuthenticationSucceededOnInstance),
},
wantErr: nil,
},
{
name: "subsequent console login, no milestone",
fields: fields{
eventstore: expectEventstore(
expectFilter(
eventFromEventPusher(milestone.NewReachedEvent(ctx, aggregate, milestone.AuthenticationSucceededOnInstance)),
),
),
},
args: args{
ctx: ctx,
clientID: "console",
isHuman: true,
},
wantErr: nil,
},
{
name: "subsequent machine login, no milestone",
fields: fields{
eventstore: expectEventstore(
expectFilter(
eventFromEventPusher(milestone.NewReachedEvent(ctx, aggregate, milestone.AuthenticationSucceededOnInstance)),
),
),
},
args: args{
ctx: ctx,
clientID: "client",
isHuman: false,
},
wantErr: nil,
},
{
name: "auth on app",
fields: fields{
eventstore: expectEventstore(
expectFilter(
eventFromEventPusher(milestone.NewReachedEvent(ctx, aggregate, milestone.AuthenticationSucceededOnInstance)),
),
),
},
args: args{
ctx: ctx,
clientID: "client",
isHuman: true,
},
wantEvents: []eventstore.Command{
milestone.NewReachedEvent(ctx, aggregate, milestone.AuthenticationSucceededOnApplication),
},
wantErr: nil,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
c := &Commands{
eventstore: tt.fields.eventstore(t),
caches: &Caches{
milestones: noop.NewCache[milestoneIndex, string, *MilestonesReached](),
},
}
s := &OIDCSessionEvents{
commands: c,
}
postCommit, err := s.SetMilestones(tt.args.ctx, tt.args.clientID, tt.args.isHuman)
postCommit(tt.args.ctx)
require.ErrorIs(t, err, tt.wantErr)
assert.Equal(t, tt.wantEvents, s.events)
})
}
}
func TestCommands_projectCreatedMilestone(t *testing.T) {
ctx := authz.WithInstanceID(context.Background(), "instanceID")
systemCtx := authz.SetCtxData(ctx, authz.CtxData{
SystemMemberships: authz.Memberships{
&authz.Membership{
MemberType: authz.MemberTypeSystem,
},
},
})
aggregate := milestone.NewAggregate(ctx)
type fields struct {
eventstore func(*testing.T) *eventstore.Eventstore
}
type args struct {
ctx context.Context
}
tests := []struct {
name string
fields fields
args args
wantEvents []eventstore.Command
wantErr error
}{
{
name: "system user",
fields: fields{
eventstore: expectEventstore(),
},
args: args{
ctx: systemCtx,
},
wantErr: nil,
},
{
name: "get error",
fields: fields{
eventstore: expectEventstore(
expectFilterError(io.ErrClosedPipe),
),
},
args: args{
ctx: ctx,
},
wantErr: io.ErrClosedPipe,
},
{
name: "milestone already reached",
fields: fields{
eventstore: expectEventstore(
expectFilter(
eventFromEventPusher(milestone.NewReachedEvent(ctx, aggregate, milestone.ProjectCreated)),
),
),
},
args: args{
ctx: ctx,
},
wantErr: nil,
},
{
name: "milestone reached event",
fields: fields{
eventstore: expectEventstore(
expectFilter(),
),
},
args: args{
ctx: ctx,
},
wantEvents: []eventstore.Command{
milestone.NewReachedEvent(ctx, aggregate, milestone.ProjectCreated),
},
wantErr: nil,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
c := &Commands{
eventstore: tt.fields.eventstore(t),
caches: &Caches{
milestones: noop.NewCache[milestoneIndex, string, *MilestonesReached](),
},
}
var cmds []eventstore.Command
postCommit, err := c.projectCreatedMilestone(tt.args.ctx, &cmds)
postCommit(tt.args.ctx)
require.ErrorIs(t, err, tt.wantErr)
assert.Equal(t, tt.wantEvents, cmds)
})
}
}
func TestCommands_applicationCreatedMilestone(t *testing.T) {
ctx := authz.WithInstanceID(context.Background(), "instanceID")
systemCtx := authz.SetCtxData(ctx, authz.CtxData{
SystemMemberships: authz.Memberships{
&authz.Membership{
MemberType: authz.MemberTypeSystem,
},
},
})
aggregate := milestone.NewAggregate(ctx)
type fields struct {
eventstore func(*testing.T) *eventstore.Eventstore
}
type args struct {
ctx context.Context
}
tests := []struct {
name string
fields fields
args args
wantEvents []eventstore.Command
wantErr error
}{
{
name: "system user",
fields: fields{
eventstore: expectEventstore(),
},
args: args{
ctx: systemCtx,
},
wantErr: nil,
},
{
name: "get error",
fields: fields{
eventstore: expectEventstore(
expectFilterError(io.ErrClosedPipe),
),
},
args: args{
ctx: ctx,
},
wantErr: io.ErrClosedPipe,
},
{
name: "milestone already reached",
fields: fields{
eventstore: expectEventstore(
expectFilter(
eventFromEventPusher(milestone.NewReachedEvent(ctx, aggregate, milestone.ApplicationCreated)),
),
),
},
args: args{
ctx: ctx,
},
wantErr: nil,
},
{
name: "milestone reached event",
fields: fields{
eventstore: expectEventstore(
expectFilter(),
),
},
args: args{
ctx: ctx,
},
wantEvents: []eventstore.Command{
milestone.NewReachedEvent(ctx, aggregate, milestone.ApplicationCreated),
},
wantErr: nil,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
c := &Commands{
eventstore: tt.fields.eventstore(t),
caches: &Caches{
milestones: noop.NewCache[milestoneIndex, string, *MilestonesReached](),
},
}
var cmds []eventstore.Command
postCommit, err := c.applicationCreatedMilestone(tt.args.ctx, &cmds)
postCommit(tt.args.ctx)
require.ErrorIs(t, err, tt.wantErr)
assert.Equal(t, tt.wantEvents, cmds)
})
}
}
func (c *Commands) setMilestonesCompletedForTest(instanceID string) {
c.milestonesCompleted.Store(instanceID, struct{}{})
}

View File

@ -71,7 +71,8 @@ func (c *Commands) CreateOIDCSessionFromAuthRequest(ctx context.Context, authReq
return nil, "", zerrors.ThrowPreconditionFailed(nil, "COMMAND-Iung5", "Errors.AuthRequest.NoCode")
}
sessionModel := NewSessionWriteModel(authReqModel.SessionID, authz.GetInstance(ctx).InstanceID())
instanceID := authz.GetInstance(ctx).InstanceID()
sessionModel := NewSessionWriteModel(authReqModel.SessionID, instanceID)
err = c.eventstore.FilterToQueryReducer(ctx, sessionModel)
if err != nil {
return nil, "", err
@ -118,8 +119,15 @@ func (c *Commands) CreateOIDCSessionFromAuthRequest(ctx context.Context, authReq
}
}
cmd.SetAuthRequestSuccessful(ctx, authReqModel.aggregate)
session, err = cmd.PushEvents(ctx)
return session, authReqModel.State, err
postCommit, err := cmd.SetMilestones(ctx, authReqModel.ClientID, true)
if err != nil {
return nil, "", err
}
if session, err = cmd.PushEvents(ctx); err != nil {
return nil, "", err
}
postCommit(ctx)
return session, authReqModel.State, nil
}
func (c *Commands) CreateOIDCSession(ctx context.Context,
@ -161,7 +169,15 @@ func (c *Commands) CreateOIDCSession(ctx context.Context,
return nil, err
}
}
return cmd.PushEvents(ctx)
postCommit, err := cmd.SetMilestones(ctx, clientID, sessionID != "")
if err != nil {
return nil, err
}
if session, err = cmd.PushEvents(ctx); err != nil {
return nil, err
}
postCommit(ctx)
return session, nil
}
type RefreshTokenComplianceChecker func(ctx context.Context, wm *OIDCSessionWriteModel, requestedScope []string) (scope []string, err error)
@ -283,7 +299,7 @@ func (c *Commands) newOIDCSessionAddEvents(ctx context.Context, userID, resource
}
sessionID = IDPrefixV2 + sessionID
return &OIDCSessionEvents{
eventstore: c.eventstore,
commands: c,
idGenerator: c.idGenerator,
encryptionAlg: c.keyAlgorithm,
events: pending,
@ -341,7 +357,7 @@ func (c *Commands) newOIDCSessionUpdateEvents(ctx context.Context, refreshToken
return nil, err
}
return &OIDCSessionEvents{
eventstore: c.eventstore,
commands: c,
idGenerator: c.idGenerator,
encryptionAlg: c.keyAlgorithm,
oidcSessionWriteModel: sessionWriteModel,
@ -352,7 +368,7 @@ func (c *Commands) newOIDCSessionUpdateEvents(ctx context.Context, refreshToken
}
type OIDCSessionEvents struct {
eventstore *eventstore.Eventstore
commands *Commands
idGenerator id.Generator
encryptionAlg crypto.EncryptionAlgorithm
events []eventstore.Command
@ -467,7 +483,7 @@ func (c *OIDCSessionEvents) generateRefreshToken(userID string) (refreshTokenID,
}
func (c *OIDCSessionEvents) PushEvents(ctx context.Context) (*OIDCSession, error) {
pushedEvents, err := c.eventstore.Push(ctx, c.events...)
pushedEvents, err := c.commands.eventstore.Push(ctx, c.events...)
if err != nil {
return nil, err
}
@ -496,7 +512,7 @@ func (c *OIDCSessionEvents) PushEvents(ctx context.Context) (*OIDCSession, error
// we need to use `-` as a delimiter because the OIDC library uses `:` and will check for a length of 2 parts
session.TokenID = c.oidcSessionWriteModel.AggregateID + TokenDelimiter + c.accessTokenID
}
activity.Trigger(ctx, c.oidcSessionWriteModel.UserResourceOwner, c.oidcSessionWriteModel.UserID, tokenReasonToActivityMethodType(c.oidcSessionWriteModel.AccessTokenReason), c.eventstore.FilterToQueryReducer)
activity.Trigger(ctx, c.oidcSessionWriteModel.UserResourceOwner, c.oidcSessionWriteModel.UserID, tokenReasonToActivityMethodType(c.oidcSessionWriteModel.AccessTokenReason), c.commands.eventstore.FilterToQueryReducer)
return session, nil
}

View File

@ -70,7 +70,7 @@ func TestCommands_CreateOIDCSessionFromAuthRequest(t *testing.T) {
eventstore: expectEventstore(),
},
args{
ctx: context.Background(),
ctx: authz.WithInstanceID(context.Background(), "instanceID"),
authRequestID: "",
complianceCheck: mockAuthRequestComplianceChecker(nil),
},
@ -86,7 +86,7 @@ func TestCommands_CreateOIDCSessionFromAuthRequest(t *testing.T) {
),
},
args{
ctx: context.Background(),
ctx: authz.WithInstanceID(context.Background(), "instanceID"),
authRequestID: "V2_authRequestID",
complianceCheck: mockAuthRequestComplianceChecker(nil),
},
@ -102,7 +102,7 @@ func TestCommands_CreateOIDCSessionFromAuthRequest(t *testing.T) {
),
},
args{
ctx: context.Background(),
ctx: authz.WithInstanceID(context.Background(), "instanceID"),
authRequestID: "V2_authRequestID",
complianceCheck: mockAuthRequestComplianceChecker(nil),
},
@ -706,6 +706,7 @@ func TestCommands_CreateOIDCSessionFromAuthRequest(t *testing.T) {
defaultRefreshTokenIdleLifetime: tt.fields.defaultRefreshTokenIdleLifetime,
keyAlgorithm: tt.fields.keyAlgorithm,
}
c.setMilestonesCompletedForTest("instanceID")
gotSession, gotState, err := c.CreateOIDCSessionFromAuthRequest(tt.args.ctx, tt.args.authRequestID, tt.args.complianceCheck, tt.args.needRefreshToken)
require.ErrorIs(t, err, tt.res.err)
@ -762,7 +763,7 @@ func TestCommands_CreateOIDCSession(t *testing.T) {
),
},
args: args{
ctx: context.Background(),
ctx: authz.WithInstanceID(context.Background(), "instanceID"),
userID: "userID",
resourceOwner: "orgID",
clientID: "clientID",
@ -818,7 +819,7 @@ func TestCommands_CreateOIDCSession(t *testing.T) {
keyAlgorithm: crypto.CreateMockEncryptionAlg(gomock.NewController(t)),
},
args: args{
ctx: context.Background(),
ctx: authz.WithInstanceID(context.Background(), "instanceID"),
userID: "userID",
resourceOwner: "org1",
clientID: "clientID",
@ -892,7 +893,7 @@ func TestCommands_CreateOIDCSession(t *testing.T) {
keyAlgorithm: crypto.CreateMockEncryptionAlg(gomock.NewController(t)),
},
args: args{
ctx: context.Background(),
ctx: authz.WithInstanceID(context.Background(), "instanceID"),
userID: "userID",
resourceOwner: "org1",
clientID: "clientID",
@ -1089,7 +1090,7 @@ func TestCommands_CreateOIDCSession(t *testing.T) {
keyAlgorithm: crypto.CreateMockEncryptionAlg(gomock.NewController(t)),
},
args: args{
ctx: context.Background(),
ctx: authz.WithInstanceID(context.Background(), "instanceID"),
userID: "userID",
resourceOwner: "org1",
clientID: "clientID",
@ -1186,7 +1187,7 @@ func TestCommands_CreateOIDCSession(t *testing.T) {
keyAlgorithm: crypto.CreateMockEncryptionAlg(gomock.NewController(t)),
},
args: args{
ctx: context.Background(),
ctx: authz.WithInstanceID(context.Background(), "instanceID"),
userID: "userID",
resourceOwner: "org1",
clientID: "clientID",
@ -1266,7 +1267,7 @@ func TestCommands_CreateOIDCSession(t *testing.T) {
}),
},
args: args{
ctx: context.Background(),
ctx: authz.WithInstanceID(context.Background(), "instanceID"),
userID: "userID",
resourceOwner: "org1",
clientID: "clientID",
@ -1347,7 +1348,7 @@ func TestCommands_CreateOIDCSession(t *testing.T) {
}),
},
args: args{
ctx: context.Background(),
ctx: authz.WithInstanceID(context.Background(), "instanceID"),
userID: "userID",
resourceOwner: "org1",
clientID: "clientID",
@ -1406,6 +1407,7 @@ func TestCommands_CreateOIDCSession(t *testing.T) {
keyAlgorithm: tt.fields.keyAlgorithm,
checkPermission: tt.fields.checkPermission,
}
c.setMilestonesCompletedForTest("instanceID")
got, err := c.CreateOIDCSession(tt.args.ctx,
tt.args.userID,
tt.args.resourceOwner,

View File

@ -34,7 +34,11 @@ func (c *Commands) AddProjectWithID(ctx context.Context, project *domain.Project
if existingProject.State != domain.ProjectStateUnspecified {
return nil, zerrors.ThrowInvalidArgument(nil, "COMMAND-opamwu", "Errors.Project.AlreadyExisting")
}
return c.addProjectWithID(ctx, project, resourceOwner, projectID)
project, err = c.addProjectWithID(ctx, project, resourceOwner, projectID)
if err != nil {
return nil, err
}
return project, nil
}
func (c *Commands) AddProject(ctx context.Context, project *domain.Project, resourceOwner, ownerUserID string) (_ *domain.Project, err error) {
@ -53,7 +57,11 @@ func (c *Commands) AddProject(ctx context.Context, project *domain.Project, reso
return nil, err
}
return c.addProjectWithIDWithOwner(ctx, project, resourceOwner, ownerUserID, projectID)
project, err = c.addProjectWithIDWithOwner(ctx, project, resourceOwner, ownerUserID, projectID)
if err != nil {
return nil, err
}
return project, nil
}
func (c *Commands) addProjectWithID(ctx context.Context, projectAdd *domain.Project, resourceOwner, projectID string) (_ *domain.Project, err error) {
@ -71,11 +79,15 @@ func (c *Commands) addProjectWithID(ctx context.Context, projectAdd *domain.Proj
projectAdd.HasProjectCheck,
projectAdd.PrivateLabelingSetting),
}
postCommit, err := c.projectCreatedMilestone(ctx, &events)
if err != nil {
return nil, err
}
pushedEvents, err := c.eventstore.Push(ctx, events...)
if err != nil {
return nil, err
}
postCommit(ctx)
err = AppendAndReduce(addedProject, pushedEvents...)
if err != nil {
return nil, err
@ -103,11 +115,15 @@ func (c *Commands) addProjectWithIDWithOwner(ctx context.Context, projectAdd *do
projectAdd.PrivateLabelingSetting),
project.NewProjectMemberAddedEvent(ctx, projectAgg, ownerUserID, projectRole),
}
postCommit, err := c.projectCreatedMilestone(ctx, &events)
if err != nil {
return nil, err
}
pushedEvents, err := c.eventstore.Push(ctx, events...)
if err != nil {
return nil, err
}
postCommit(ctx)
err = AppendAndReduce(addedProject, pushedEvents...)
if err != nil {
return nil, err

View File

@ -202,10 +202,15 @@ func (c *Commands) addOIDCApplicationWithID(ctx context.Context, oidcApp *domain
))
addedApplication.AppID = oidcApp.AppID
postCommit, err := c.applicationCreatedMilestone(ctx, &events)
if err != nil {
return nil, err
}
pushedEvents, err := c.eventstore.Push(ctx, events...)
if err != nil {
return nil, err
}
postCommit(ctx)
err = AppendAndReduce(addedApplication, pushedEvents...)
if err != nil {
return nil, err

View File

@ -11,6 +11,7 @@ import (
"github.com/zitadel/passwap"
"github.com/zitadel/passwap/bcrypt"
"github.com/zitadel/zitadel/internal/api/authz"
"github.com/zitadel/zitadel/internal/command/preparation"
"github.com/zitadel/zitadel/internal/crypto"
"github.com/zitadel/zitadel/internal/domain"
@ -418,7 +419,7 @@ func TestCommandSide_AddOIDCApplication(t *testing.T) {
eventstore: expectEventstore(),
},
args: args{
ctx: context.Background(),
ctx: authz.WithInstanceID(context.Background(), "instanceID"),
oidcApp: &domain.OIDCApp{},
resourceOwner: "org1",
},
@ -434,7 +435,7 @@ func TestCommandSide_AddOIDCApplication(t *testing.T) {
),
},
args: args{
ctx: context.Background(),
ctx: authz.WithInstanceID(context.Background(), "instanceID"),
oidcApp: &domain.OIDCApp{
ObjectRoot: models.ObjectRoot{
AggregateID: "project1",
@ -463,7 +464,7 @@ func TestCommandSide_AddOIDCApplication(t *testing.T) {
),
},
args: args{
ctx: context.Background(),
ctx: authz.WithInstanceID(context.Background(), "instanceID"),
oidcApp: &domain.OIDCApp{
ObjectRoot: models.ObjectRoot{
AggregateID: "project1",
@ -521,7 +522,7 @@ func TestCommandSide_AddOIDCApplication(t *testing.T) {
idGenerator: id_mock.NewIDGeneratorExpectIDs(t, "app1", "client1"),
},
args: args{
ctx: context.Background(),
ctx: authz.WithInstanceID(context.Background(), "instanceID"),
oidcApp: &domain.OIDCApp{
ObjectRoot: models.ObjectRoot{
AggregateID: "project1",
@ -619,7 +620,7 @@ func TestCommandSide_AddOIDCApplication(t *testing.T) {
idGenerator: id_mock.NewIDGeneratorExpectIDs(t, "app1", "client1"),
},
args: args{
ctx: context.Background(),
ctx: authz.WithInstanceID(context.Background(), "instanceID"),
oidcApp: &domain.OIDCApp{
ObjectRoot: models.ObjectRoot{
AggregateID: "project1",
@ -676,7 +677,7 @@ func TestCommandSide_AddOIDCApplication(t *testing.T) {
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
r := &Commands{
c := &Commands{
eventstore: tt.fields.eventstore(t),
idGenerator: tt.fields.idGenerator,
newHashedSecret: mockHashedSecret("secret"),
@ -684,7 +685,8 @@ func TestCommandSide_AddOIDCApplication(t *testing.T) {
ClientSecret: emptyConfig,
},
}
got, err := r.AddOIDCApplication(tt.args.ctx, tt.args.oidcApp, tt.args.resourceOwner)
c.setMilestonesCompletedForTest("instanceID")
got, err := c.AddOIDCApplication(tt.args.ctx, tt.args.oidcApp, tt.args.resourceOwner)
if tt.res.err == nil {
assert.NoError(t, err)
}

View File

@ -28,10 +28,15 @@ func (c *Commands) AddSAMLApplication(ctx context.Context, application *domain.S
return nil, err
}
addedApplication.AppID = application.AppID
postCommit, err := c.applicationCreatedMilestone(ctx, &events)
if err != nil {
return nil, err
}
pushedEvents, err := c.eventstore.Push(ctx, events...)
if err != nil {
return nil, err
}
postCommit(ctx)
err = AppendAndReduce(addedApplication, pushedEvents...)
if err != nil {
return nil, err

View File

@ -9,6 +9,7 @@ import (
"github.com/stretchr/testify/assert"
"github.com/zitadel/zitadel/internal/api/authz"
"github.com/zitadel/zitadel/internal/domain"
"github.com/zitadel/zitadel/internal/eventstore"
"github.com/zitadel/zitadel/internal/eventstore/v1/models"
@ -76,7 +77,7 @@ func TestCommandSide_AddSAMLApplication(t *testing.T) {
),
},
args: args{
ctx: context.Background(),
ctx: authz.WithInstanceID(context.Background(), "instanceID"),
samlApp: &domain.SAMLApp{},
resourceOwner: "org1",
},
@ -93,7 +94,7 @@ func TestCommandSide_AddSAMLApplication(t *testing.T) {
),
},
args: args{
ctx: context.Background(),
ctx: authz.WithInstanceID(context.Background(), "instanceID"),
samlApp: &domain.SAMLApp{
ObjectRoot: models.ObjectRoot{
AggregateID: "project1",
@ -123,7 +124,7 @@ func TestCommandSide_AddSAMLApplication(t *testing.T) {
),
},
args: args{
ctx: context.Background(),
ctx: authz.WithInstanceID(context.Background(), "instanceID"),
samlApp: &domain.SAMLApp{
ObjectRoot: models.ObjectRoot{
AggregateID: "project1",
@ -154,7 +155,7 @@ func TestCommandSide_AddSAMLApplication(t *testing.T) {
idGenerator: id_mock.NewIDGeneratorExpectIDs(t),
},
args: args{
ctx: context.Background(),
ctx: authz.WithInstanceID(context.Background(), "instanceID"),
samlApp: &domain.SAMLApp{
ObjectRoot: models.ObjectRoot{
AggregateID: "project1",
@ -201,7 +202,7 @@ func TestCommandSide_AddSAMLApplication(t *testing.T) {
idGenerator: id_mock.NewIDGeneratorExpectIDs(t, "app1"),
},
args: args{
ctx: context.Background(),
ctx: authz.WithInstanceID(context.Background(), "instanceID"),
samlApp: &domain.SAMLApp{
ObjectRoot: models.ObjectRoot{
AggregateID: "project1",
@ -260,7 +261,7 @@ func TestCommandSide_AddSAMLApplication(t *testing.T) {
httpClient: newTestClient(200, testMetadata),
},
args: args{
ctx: context.Background(),
ctx: authz.WithInstanceID(context.Background(), "instanceID"),
samlApp: &domain.SAMLApp{
ObjectRoot: models.ObjectRoot{
AggregateID: "project1",
@ -305,7 +306,7 @@ func TestCommandSide_AddSAMLApplication(t *testing.T) {
httpClient: newTestClient(http.StatusNotFound, nil),
},
args: args{
ctx: context.Background(),
ctx: authz.WithInstanceID(context.Background(), "instanceID"),
samlApp: &domain.SAMLApp{
ObjectRoot: models.ObjectRoot{
AggregateID: "project1",
@ -325,13 +326,13 @@ func TestCommandSide_AddSAMLApplication(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
r := &Commands{
c := &Commands{
eventstore: tt.fields.eventstore,
idGenerator: tt.fields.idGenerator,
httpClient: tt.fields.httpClient,
}
got, err := r.AddSAMLApplication(tt.args.ctx, tt.args.samlApp, tt.args.resourceOwner)
c.setMilestonesCompletedForTest("instanceID")
got, err := c.AddSAMLApplication(tt.args.ctx, tt.args.samlApp, tt.args.resourceOwner)
if tt.res.err == nil {
assert.NoError(t, err)
}

View File

@ -6,6 +6,7 @@ import (
"github.com/stretchr/testify/assert"
"github.com/zitadel/zitadel/internal/api/authz"
"github.com/zitadel/zitadel/internal/domain"
"github.com/zitadel/zitadel/internal/eventstore"
"github.com/zitadel/zitadel/internal/eventstore/v1/models"
@ -44,7 +45,7 @@ func TestCommandSide_AddProject(t *testing.T) {
),
},
args: args{
ctx: context.Background(),
ctx: authz.WithInstanceID(context.Background(), "instanceID"),
project: &domain.Project{},
resourceOwner: "org1",
},
@ -60,7 +61,7 @@ func TestCommandSide_AddProject(t *testing.T) {
),
},
args: args{
ctx: context.Background(),
ctx: authz.WithInstanceID(context.Background(), "instanceID"),
project: &domain.Project{
Name: "project",
ProjectRoleAssertion: true,
@ -121,7 +122,7 @@ func TestCommandSide_AddProject(t *testing.T) {
idGenerator: id_mock.NewIDGeneratorExpectIDs(t, "project1"),
},
args: args{
ctx: context.Background(),
ctx: authz.WithInstanceID(context.Background(), "instanceID"),
project: &domain.Project{
Name: "project",
ProjectRoleAssertion: true,
@ -159,7 +160,7 @@ func TestCommandSide_AddProject(t *testing.T) {
idGenerator: id_mock.NewIDGeneratorExpectIDs(t, "project1"),
},
args: args{
ctx: context.Background(),
ctx: authz.WithInstanceID(context.Background(), "instanceID"),
project: &domain.Project{
Name: "project",
ProjectRoleAssertion: true,
@ -187,11 +188,12 @@ func TestCommandSide_AddProject(t *testing.T) {
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
r := &Commands{
c := &Commands{
eventstore: tt.fields.eventstore,
idGenerator: tt.fields.idGenerator,
}
got, err := r.AddProject(tt.args.ctx, tt.args.project, tt.args.resourceOwner, tt.args.ownerID)
c.setMilestonesCompletedForTest("instanceID")
got, err := c.AddProject(tt.args.ctx, tt.args.project, tt.args.resourceOwner, tt.args.ownerID)
if tt.res.err == nil {
assert.NoError(t, err)
}

View File

@ -11,6 +11,7 @@ import (
"github.com/zitadel/logging"
"github.com/zitadel/zitadel/internal/api/authz"
"github.com/zitadel/zitadel/internal/database"
"github.com/zitadel/zitadel/internal/zerrors"
)
@ -247,6 +248,10 @@ func (es *Eventstore) InstanceIDs(ctx context.Context, maxAge time.Duration, for
return instances, nil
}
func (es *Eventstore) Client() *database.DB {
return es.querier.Client()
}
type QueryReducer interface {
reducer
//Query returns the SearchQueryFactory for the events needed in reducer
@ -270,6 +275,8 @@ type Querier interface {
LatestSequence(ctx context.Context, queryFactory *SearchQueryBuilder) (float64, error)
// InstanceIDs returns the instance ids found by the search query
InstanceIDs(ctx context.Context, queryFactory *SearchQueryBuilder) ([]string, error)
// Client returns the underlying database connection
Client() *database.DB
}
type Pusher interface {

View File

@ -12,6 +12,7 @@ import (
"github.com/zitadel/zitadel/internal/api/authz"
"github.com/zitadel/zitadel/internal/api/service"
"github.com/zitadel/zitadel/internal/database"
"github.com/zitadel/zitadel/internal/zerrors"
)
@ -437,6 +438,10 @@ func (repo *testQuerier) InstanceIDs(ctx context.Context, queryFactory *SearchQu
return repo.instances, nil
}
func (*testQuerier) Client() *database.DB {
return nil
}
func TestEventstore_Push(t *testing.T) {
type args struct {
events []Command

View File

@ -13,6 +13,7 @@ import (
context "context"
reflect "reflect"
database "github.com/zitadel/zitadel/internal/database"
eventstore "github.com/zitadel/zitadel/internal/eventstore"
gomock "go.uber.org/mock/gomock"
)
@ -40,6 +41,20 @@ func (m *MockQuerier) EXPECT() *MockQuerierMockRecorder {
return m.recorder
}
// Client mocks base method.
func (m *MockQuerier) Client() *database.DB {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Client")
ret0, _ := ret[0].(*database.DB)
return ret0
}
// Client indicates an expected call of Client.
func (mr *MockQuerierMockRecorder) Client() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Client", reflect.TypeOf((*MockQuerier)(nil).Client))
}
// FilterToReducer mocks base method.
func (m *MockQuerier) FilterToReducer(arg0 context.Context, arg1 *eventstore.SearchQueryBuilder, arg2 eventstore.Reducer) error {
m.ctrl.T.Helper()

View File

@ -282,7 +282,7 @@ func (db *CRDB) InstanceIDs(ctx context.Context, searchQuery *eventstore.SearchQ
return ids, nil
}
func (db *CRDB) db() *database.DB {
func (db *CRDB) Client() *database.DB {
return db.DB
}

View File

@ -27,7 +27,7 @@ type querier interface {
eventQuery(useV1 bool) string
maxSequenceQuery(useV1 bool) string
instanceIDsQuery(useV1 bool) string
db() *database.DB
Client() *database.DB
orderByEventSequence(desc, shouldOrderBySequence, useV1 bool) string
dialect.Database
}
@ -110,7 +110,7 @@ func query(ctx context.Context, criteria querier, searchQuery *eventstore.Search
var contextQuerier interface {
QueryContext(context.Context, func(rows *sql.Rows) error, string, ...interface{}) error
}
contextQuerier = criteria.db()
contextQuerier = criteria.Client()
if q.Tx != nil {
contextQuerier = &tx{Tx: q.Tx}
}

View File

@ -22,5 +22,5 @@ type Commands interface {
HumanPhoneVerificationCodeSent(ctx context.Context, orgID, userID string, generatorInfo *senders.CodeGeneratorInfo) error
InviteCodeSent(ctx context.Context, orgID, userID string) error
UsageNotificationSent(ctx context.Context, dueEvent *quota.NotificationDueEvent) error
MilestonePushed(ctx context.Context, msType milestone.Type, endpoints []string, primaryDomain string) error
MilestonePushed(ctx context.Context, instanceID string, msType milestone.Type, endpoints []string) error
}

View File

@ -16,6 +16,7 @@ import (
"github.com/zitadel/zitadel/internal/integration"
"github.com/zitadel/zitadel/internal/integration/sink"
"github.com/zitadel/zitadel/internal/repository/milestone"
"github.com/zitadel/zitadel/pkg/grpc/app"
"github.com/zitadel/zitadel/pkg/grpc/management"
"github.com/zitadel/zitadel/pkg/grpc/object"
@ -32,12 +33,12 @@ func TestServer_TelemetryPushMilestones(t *testing.T) {
instance := integration.NewInstance(CTX)
iamOwnerCtx := instance.WithAuthorization(CTX, integration.UserTypeIAMOwner)
t.Log("testing against instance with primary domain", instance.Domain)
awaitMilestone(t, sub, instance.Domain, "InstanceCreated")
t.Log("testing against instance", instance.ID())
awaitMilestone(t, sub, instance.ID(), milestone.InstanceCreated)
projectAdded, err := instance.Client.Mgmt.AddProject(iamOwnerCtx, &management.AddProjectRequest{Name: "integration"})
require.NoError(t, err)
awaitMilestone(t, sub, instance.Domain, "ProjectCreated")
awaitMilestone(t, sub, instance.ID(), milestone.ProjectCreated)
redirectURI := "http://localhost:8888"
application, err := instance.Client.Mgmt.AddOIDCApp(iamOwnerCtx, &management.AddOIDCAppRequest{
@ -52,14 +53,14 @@ func TestServer_TelemetryPushMilestones(t *testing.T) {
AccessTokenType: app.OIDCTokenType_OIDC_TOKEN_TYPE_JWT,
})
require.NoError(t, err)
awaitMilestone(t, sub, instance.Domain, "ApplicationCreated")
awaitMilestone(t, sub, instance.ID(), milestone.ApplicationCreated)
// create the session to be used for the authN of the clients
sessionID, sessionToken, _, _ := instance.CreatePasswordSession(t, iamOwnerCtx, instance.AdminUserID, "Password1!")
console := consoleOIDCConfig(t, instance)
loginToClient(t, instance, console.GetClientId(), console.GetRedirectUris()[0], sessionID, sessionToken)
awaitMilestone(t, sub, instance.Domain, "AuthenticationSucceededOnInstance")
awaitMilestone(t, sub, instance.ID(), milestone.AuthenticationSucceededOnInstance)
// make sure the client has been projected
require.EventuallyWithT(t, func(collectT *assert.CollectT) {
@ -70,11 +71,11 @@ func TestServer_TelemetryPushMilestones(t *testing.T) {
assert.NoError(collectT, err)
}, time.Minute, time.Second, "app not found")
loginToClient(t, instance, application.GetClientId(), redirectURI, sessionID, sessionToken)
awaitMilestone(t, sub, instance.Domain, "AuthenticationSucceededOnApplication")
awaitMilestone(t, sub, instance.ID(), milestone.AuthenticationSucceededOnApplication)
_, err = integration.SystemClient().RemoveInstance(CTX, &system.RemoveInstanceRequest{InstanceId: instance.ID()})
require.NoError(t, err)
awaitMilestone(t, sub, instance.Domain, "InstanceDeleted")
awaitMilestone(t, sub, instance.ID(), milestone.InstanceDeleted)
}
func loginToClient(t *testing.T, instance *integration.Instance, clientID, redirectURI, sessionID, sessionToken string) {
@ -134,7 +135,7 @@ func consoleOIDCConfig(t *testing.T, instance *integration.Instance) *app.OIDCCo
return apps.GetResult()[0].GetOidcConfig()
}
func awaitMilestone(t *testing.T, sub *sink.Subscription, primaryDomain, expectMilestoneType string) {
func awaitMilestone(t *testing.T, sub *sink.Subscription, instanceID string, expectMilestoneType milestone.Type) {
for {
select {
case req := <-sub.Recv():
@ -144,17 +145,17 @@ func awaitMilestone(t *testing.T, sub *sink.Subscription, primaryDomain, expectM
}
t.Log("received milestone", plain.String())
milestone := struct {
Type string `json:"type"`
PrimaryDomain string `json:"primaryDomain"`
InstanceID string `json:"instanceId"`
Type milestone.Type `json:"type"`
}{}
if err := json.Unmarshal(req.Body, &milestone); err != nil {
t.Error(err)
}
if milestone.Type == expectMilestoneType && milestone.PrimaryDomain == primaryDomain {
if milestone.Type == expectMilestoneType && milestone.InstanceID == instanceID {
return
}
case <-time.After(2 * time.Minute): // why does it take so long to get a milestone !?
t.Fatalf("timed out waiting for milestone %s in domain %s", expectMilestoneType, primaryDomain)
case <-time.After(20 * time.Second):
t.Fatalf("timed out waiting for milestone %s for instance %s", expectMilestoneType, instanceID)
}
}
}

View File

@ -141,7 +141,7 @@ func (mr *MockCommandsMockRecorder) InviteCodeSent(arg0, arg1, arg2 any) *gomock
}
// MilestonePushed mocks base method.
func (m *MockCommands) MilestonePushed(arg0 context.Context, arg1 milestone.Type, arg2 []string, arg3 string) error {
func (m *MockCommands) MilestonePushed(arg0 context.Context, arg1 string, arg2 milestone.Type, arg3 []string) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "MilestonePushed", arg0, arg1, arg2, arg3)
ret0, _ := ret[0].(error)

View File

@ -2,13 +2,9 @@ package handlers
import (
"context"
"fmt"
"net/http"
"time"
"github.com/zitadel/logging"
"github.com/zitadel/zitadel/internal/api/authz"
"github.com/zitadel/zitadel/internal/api/call"
"github.com/zitadel/zitadel/internal/command"
"github.com/zitadel/zitadel/internal/eventstore"
@ -16,9 +12,7 @@ import (
"github.com/zitadel/zitadel/internal/notification/channels/webhook"
_ "github.com/zitadel/zitadel/internal/notification/statik"
"github.com/zitadel/zitadel/internal/notification/types"
"github.com/zitadel/zitadel/internal/query"
"github.com/zitadel/zitadel/internal/repository/milestone"
"github.com/zitadel/zitadel/internal/repository/pseudo"
"github.com/zitadel/zitadel/internal/zerrors"
)
@ -30,7 +24,6 @@ type TelemetryPusherConfig struct {
Enabled bool
Endpoints []string
Headers http.Header
Limit uint64
}
type telemetryPusher struct {
@ -54,7 +47,6 @@ func NewTelemetryPusher(
queries: queries,
channels: channels,
}
handlerCfg.TriggerWithoutEvents = pusher.pushMilestones
return handler.NewHandler(
ctx,
&handlerCfg,
@ -68,9 +60,9 @@ func (u *telemetryPusher) Name() string {
func (t *telemetryPusher) Reducers() []handler.AggregateReducer {
return []handler.AggregateReducer{{
Aggregate: pseudo.AggregateType,
Aggregate: milestone.AggregateType,
EventReducers: []handler.EventReducer{{
Event: pseudo.ScheduledEventType,
Event: milestone.ReachedEventType,
Reduce: t.pushMilestones,
}},
}}
@ -78,51 +70,20 @@ func (t *telemetryPusher) Reducers() []handler.AggregateReducer {
func (t *telemetryPusher) pushMilestones(event eventstore.Event) (*handler.Statement, error) {
ctx := call.WithTimestamp(context.Background())
scheduledEvent, ok := event.(*pseudo.ScheduledEvent)
e, ok := event.(*milestone.ReachedEvent)
if !ok {
return nil, zerrors.ThrowInvalidArgumentf(nil, "HANDL-lDTs5", "reduce.wrong.event.type %s", event.Type())
}
return handler.NewStatement(event, func(ex handler.Executer, projectionName string) error {
isReached, err := query.NewNotNullQuery(query.MilestoneReachedDateColID)
if err != nil {
return err
return handler.NewStatement(event, func(handler.Executer, string) error {
// Do not push the milestone again if this was a migration event.
if e.ReachedDate != nil {
return nil
}
isNotPushed, err := query.NewIsNullQuery(query.MilestonePushedDateColID)
if err != nil {
return err
}
hasPrimaryDomain, err := query.NewNotNullQuery(query.MilestonePrimaryDomainColID)
if err != nil {
return err
}
unpushedMilestones, err := t.queries.Queries.SearchMilestones(ctx, scheduledEvent.InstanceIDs, &query.MilestonesSearchQueries{
SearchRequest: query.SearchRequest{
Limit: t.cfg.Limit,
SortingColumn: query.MilestoneReachedDateColID,
Asc: true,
},
Queries: []query.SearchQuery{isReached, isNotPushed, hasPrimaryDomain},
})
if err != nil {
return err
}
var errs int
for _, ms := range unpushedMilestones.Milestones {
if err = t.pushMilestone(ctx, scheduledEvent, ms); err != nil {
errs++
logging.Warnf("pushing milestone %+v failed: %s", *ms, err.Error())
}
}
if errs > 0 {
return fmt.Errorf("pushing %d of %d milestones failed", errs, unpushedMilestones.Count)
}
return nil
return t.pushMilestone(ctx, e)
}), nil
}
func (t *telemetryPusher) pushMilestone(ctx context.Context, event *pseudo.ScheduledEvent, ms *query.Milestone) error {
ctx = authz.WithInstanceID(ctx, ms.InstanceID)
func (t *telemetryPusher) pushMilestone(ctx context.Context, e *milestone.ReachedEvent) error {
for _, endpoint := range t.cfg.Endpoints {
if err := types.SendJSON(
ctx,
@ -135,20 +96,18 @@ func (t *telemetryPusher) pushMilestone(ctx context.Context, event *pseudo.Sched
&struct {
InstanceID string `json:"instanceId"`
ExternalDomain string `json:"externalDomain"`
PrimaryDomain string `json:"primaryDomain"`
Type milestone.Type `json:"type"`
ReachedDate time.Time `json:"reached"`
}{
InstanceID: ms.InstanceID,
InstanceID: e.Agg.InstanceID,
ExternalDomain: t.queries.externalDomain,
PrimaryDomain: ms.PrimaryDomain,
Type: ms.Type,
ReachedDate: ms.ReachedDate,
Type: e.MilestoneType,
ReachedDate: e.GetReachedDate(),
},
event,
e,
).WithoutTemplate(); err != nil {
return err
}
}
return t.commands.MilestonePushed(ctx, ms.Type, t.cfg.Endpoints, ms.PrimaryDomain)
return t.commands.MilestonePushed(ctx, e.Agg.InstanceID, e.MilestoneType, t.cfg.Endpoints)
}

View File

@ -54,10 +54,6 @@ var (
name: projection.MilestoneColumnType,
table: milestonesTable,
}
MilestonePrimaryDomainColID = Column{
name: projection.MilestoneColumnPrimaryDomain,
table: milestonesTable,
}
MilestoneReachedDateColID = Column{
name: projection.MilestoneColumnReachedDate,
table: milestonesTable,
@ -76,7 +72,10 @@ func (q *Queries) SearchMilestones(ctx context.Context, instanceIDs []string, qu
if len(instanceIDs) == 0 {
instanceIDs = []string{authz.GetInstance(ctx).InstanceID()}
}
stmt, args, err := queries.toQuery(query).Where(sq.Eq{MilestoneInstanceIDColID.identifier(): instanceIDs}).ToSql()
stmt, args, err := queries.toQuery(query).Where(
sq.Eq{MilestoneInstanceIDColID.identifier(): instanceIDs},
sq.Eq{InstanceDomainIsPrimaryCol.identifier(): true},
).ToSql()
if err != nil {
return nil, zerrors.ThrowInternal(err, "QUERY-A9i5k", "Errors.Query.SQLStatement")
}
@ -96,13 +95,14 @@ func (q *Queries) SearchMilestones(ctx context.Context, instanceIDs []string, qu
func prepareMilestonesQuery(ctx context.Context, db prepareDatabase) (sq.SelectBuilder, func(*sql.Rows) (*Milestones, error)) {
return sq.Select(
MilestoneInstanceIDColID.identifier(),
MilestonePrimaryDomainColID.identifier(),
InstanceDomainDomainCol.identifier(),
MilestoneReachedDateColID.identifier(),
MilestonePushedDateColID.identifier(),
MilestoneTypeColID.identifier(),
countColumn.identifier(),
).
From(milestonesTable.identifier() + db.Timetravel(call.Took(ctx))).
LeftJoin(join(InstanceDomainInstanceIDCol, MilestoneInstanceIDColID)).
PlaceholderFormat(sq.Dollar),
func(rows *sql.Rows) (*Milestones, error) {
milestones := make([]*Milestone, 0)

View File

@ -11,13 +11,14 @@ import (
var (
expectedMilestoneQuery = regexp.QuoteMeta(`
SELECT projections.milestones.instance_id,
projections.milestones.primary_domain,
projections.milestones.reached_date,
projections.milestones.last_pushed_date,
projections.milestones.type,
SELECT projections.milestones2.instance_id,
projections.instance_domains.domain,
projections.milestones2.reached_date,
projections.milestones2.last_pushed_date,
projections.milestones2.type,
COUNT(*) OVER ()
FROM projections.milestones AS OF SYSTEM TIME '-1 ms'
FROM projections.milestones2 AS OF SYSTEM TIME '-1 ms'
LEFT JOIN projections.instance_domains ON projections.milestones2.instance_id = projections.instance_domains.instance_id
`)
milestoneCols = []string{

View File

@ -14,8 +14,9 @@ func testEvent(
eventType eventstore.EventType,
aggregateType eventstore.AggregateType,
data []byte,
opts ...eventOption,
) *repository.Event {
return timedTestEvent(eventType, aggregateType, data, time.Now())
return timedTestEvent(eventType, aggregateType, data, time.Now(), opts...)
}
func toSystemEvent(event *repository.Event) *repository.Event {
@ -28,8 +29,9 @@ func timedTestEvent(
aggregateType eventstore.AggregateType,
data []byte,
creationDate time.Time,
opts ...eventOption,
) *repository.Event {
return &repository.Event{
e := &repository.Event{
Seq: 15,
CreationDate: creationDate,
Typ: eventType,
@ -42,6 +44,18 @@ func timedTestEvent(
ID: "event-id",
EditorUser: "editor-user",
}
for _, opt := range opts {
opt(e)
}
return e
}
type eventOption func(e *repository.Event)
func withVersion(v eventstore.Version) eventOption {
return func(e *repository.Event) {
e.Version = v
}
}
func baseEvent(*testing.T) eventstore.Event {

View File

@ -2,35 +2,26 @@ package projection
import (
"context"
"strconv"
internal_authz "github.com/zitadel/zitadel/internal/api/authz"
"github.com/zitadel/zitadel/internal/eventstore"
old_handler "github.com/zitadel/zitadel/internal/eventstore/handler"
"github.com/zitadel/zitadel/internal/eventstore/handler/v2"
"github.com/zitadel/zitadel/internal/repository/instance"
"github.com/zitadel/zitadel/internal/repository/milestone"
"github.com/zitadel/zitadel/internal/repository/oidcsession"
"github.com/zitadel/zitadel/internal/repository/project"
)
const (
MilestonesProjectionTable = "projections.milestones"
MilestonesProjectionTable = "projections.milestones2"
MilestoneColumnInstanceID = "instance_id"
MilestoneColumnType = "type"
MilestoneColumnPrimaryDomain = "primary_domain"
MilestoneColumnReachedDate = "reached_date"
MilestoneColumnPushedDate = "last_pushed_date"
MilestoneColumnIgnoreClientIDs = "ignore_client_ids"
MilestoneColumnInstanceID = "instance_id"
MilestoneColumnType = "type"
MilestoneColumnReachedDate = "reached_date"
MilestoneColumnPushedDate = "last_pushed_date"
)
type milestoneProjection struct {
systemUsers map[string]*internal_authz.SystemAPIUser
}
type milestoneProjection struct{}
func newMilestoneProjection(ctx context.Context, config handler.Config, systemUsers map[string]*internal_authz.SystemAPIUser) *handler.Handler {
return handler.NewHandler(ctx, &config, &milestoneProjection{systemUsers: systemUsers})
func newMilestoneProjection(ctx context.Context, config handler.Config) *handler.Handler {
return handler.NewHandler(ctx, &config, &milestoneProjection{})
}
func (*milestoneProjection) Name() string {
@ -44,8 +35,6 @@ func (*milestoneProjection) Init() *old_handler.Check {
handler.NewColumn(MilestoneColumnType, handler.ColumnTypeEnum),
handler.NewColumn(MilestoneColumnReachedDate, handler.ColumnTypeTimestamp, handler.Nullable()),
handler.NewColumn(MilestoneColumnPushedDate, handler.ColumnTypeTimestamp, handler.Nullable()),
handler.NewColumn(MilestoneColumnPrimaryDomain, handler.ColumnTypeText, handler.Nullable()),
handler.NewColumn(MilestoneColumnIgnoreClientIDs, handler.ColumnTypeTextArray, handler.Nullable()),
},
handler.NewPrimaryKey(MilestoneColumnInstanceID, MilestoneColumnType),
),
@ -55,183 +44,47 @@ func (*milestoneProjection) Init() *old_handler.Check {
// Reducers implements handler.Projection.
func (p *milestoneProjection) Reducers() []handler.AggregateReducer {
return []handler.AggregateReducer{
{
Aggregate: instance.AggregateType,
EventReducers: []handler.EventReducer{
{
Event: instance.InstanceAddedEventType,
Reduce: p.reduceInstanceAdded,
},
{
Event: instance.InstanceDomainPrimarySetEventType,
Reduce: p.reduceInstanceDomainPrimarySet,
},
{
Event: instance.InstanceRemovedEventType,
Reduce: p.reduceInstanceRemoved,
},
},
},
{
Aggregate: project.AggregateType,
EventReducers: []handler.EventReducer{
{
Event: project.ProjectAddedType,
Reduce: p.reduceProjectAdded,
},
{
Event: project.ApplicationAddedType,
Reduce: p.reduceApplicationAdded,
},
{
Event: project.OIDCConfigAddedType,
Reduce: p.reduceOIDCConfigAdded,
},
{
Event: project.APIConfigAddedType,
Reduce: p.reduceAPIConfigAdded,
},
},
},
{
Aggregate: oidcsession.AggregateType,
EventReducers: []handler.EventReducer{
{
Event: oidcsession.AddedType,
Reduce: p.reduceOIDCSessionAdded,
},
},
},
{
Aggregate: milestone.AggregateType,
EventReducers: []handler.EventReducer{
{
Event: milestone.ReachedEventType,
Reduce: p.reduceReached,
},
{
Event: milestone.PushedEventType,
Reduce: p.reduceMilestonePushed,
Reduce: p.reducePushed,
},
},
},
}
}
func (p *milestoneProjection) reduceInstanceAdded(event eventstore.Event) (*handler.Statement, error) {
e, err := assertEvent[*instance.InstanceAddedEvent](event)
func (p *milestoneProjection) reduceReached(event eventstore.Event) (*handler.Statement, error) {
e, err := assertEvent[*milestone.ReachedEvent](event)
if err != nil {
return nil, err
}
allTypes := milestone.AllTypes()
statements := make([]func(eventstore.Event) handler.Exec, 0, len(allTypes))
for _, msType := range allTypes {
createColumns := []handler.Column{
handler.NewCol(MilestoneColumnInstanceID, e.Aggregate().InstanceID),
handler.NewCol(MilestoneColumnType, msType),
}
if msType == milestone.InstanceCreated {
createColumns = append(createColumns, handler.NewCol(MilestoneColumnReachedDate, event.CreatedAt()))
}
statements = append(statements, handler.AddCreateStatement(createColumns))
}
return handler.NewMultiStatement(e, statements...), nil
return handler.NewCreateStatement(event, []handler.Column{
handler.NewCol(MilestoneColumnInstanceID, e.Agg.InstanceID),
handler.NewCol(MilestoneColumnType, e.MilestoneType),
handler.NewCol(MilestoneColumnReachedDate, e.GetReachedDate()),
}), nil
}
func (p *milestoneProjection) reduceInstanceDomainPrimarySet(event eventstore.Event) (*handler.Statement, error) {
e, err := assertEvent[*instance.DomainPrimarySetEvent](event)
if err != nil {
return nil, err
}
return handler.NewUpdateStatement(
e,
[]handler.Column{
handler.NewCol(MilestoneColumnPrimaryDomain, e.Domain),
},
[]handler.Condition{
handler.NewCond(MilestoneColumnInstanceID, e.Aggregate().InstanceID),
handler.NewIsNullCond(MilestoneColumnPushedDate),
},
), nil
}
func (p *milestoneProjection) reduceProjectAdded(event eventstore.Event) (*handler.Statement, error) {
if _, err := assertEvent[*project.ProjectAddedEvent](event); err != nil {
return nil, err
}
return p.reduceReachedIfUserEventFunc(milestone.ProjectCreated)(event)
}
func (p *milestoneProjection) reduceApplicationAdded(event eventstore.Event) (*handler.Statement, error) {
if _, err := assertEvent[*project.ApplicationAddedEvent](event); err != nil {
return nil, err
}
return p.reduceReachedIfUserEventFunc(milestone.ApplicationCreated)(event)
}
func (p *milestoneProjection) reduceOIDCConfigAdded(event eventstore.Event) (*handler.Statement, error) {
e, err := assertEvent[*project.OIDCConfigAddedEvent](event)
if err != nil {
return nil, err
}
return p.reduceAppConfigAdded(e, e.ClientID)
}
func (p *milestoneProjection) reduceAPIConfigAdded(event eventstore.Event) (*handler.Statement, error) {
e, err := assertEvent[*project.APIConfigAddedEvent](event)
if err != nil {
return nil, err
}
return p.reduceAppConfigAdded(e, e.ClientID)
}
func (p *milestoneProjection) reduceOIDCSessionAdded(event eventstore.Event) (*handler.Statement, error) {
e, err := assertEvent[*oidcsession.AddedEvent](event)
if err != nil {
return nil, err
}
statements := []func(eventstore.Event) handler.Exec{
handler.AddUpdateStatement(
[]handler.Column{
handler.NewCol(MilestoneColumnReachedDate, event.CreatedAt()),
},
[]handler.Condition{
handler.NewCond(MilestoneColumnInstanceID, event.Aggregate().InstanceID),
handler.NewCond(MilestoneColumnType, milestone.AuthenticationSucceededOnInstance),
handler.NewIsNullCond(MilestoneColumnReachedDate),
},
),
}
// We ignore authentications without app, for example JWT profile or PAT
if e.ClientID != "" {
statements = append(statements, handler.AddUpdateStatement(
[]handler.Column{
handler.NewCol(MilestoneColumnReachedDate, event.CreatedAt()),
},
[]handler.Condition{
handler.NewCond(MilestoneColumnInstanceID, event.Aggregate().InstanceID),
handler.NewCond(MilestoneColumnType, milestone.AuthenticationSucceededOnApplication),
handler.Not(handler.NewTextArrayContainsCond(MilestoneColumnIgnoreClientIDs, e.ClientID)),
handler.NewIsNullCond(MilestoneColumnReachedDate),
},
))
}
return handler.NewMultiStatement(e, statements...), nil
}
func (p *milestoneProjection) reduceInstanceRemoved(event eventstore.Event) (*handler.Statement, error) {
if _, err := assertEvent[*instance.InstanceRemovedEvent](event); err != nil {
return nil, err
}
return p.reduceReachedFunc(milestone.InstanceDeleted)(event)
}
func (p *milestoneProjection) reduceMilestonePushed(event eventstore.Event) (*handler.Statement, error) {
func (p *milestoneProjection) reducePushed(event eventstore.Event) (*handler.Statement, error) {
e, err := assertEvent[*milestone.PushedEvent](event)
if err != nil {
return nil, err
}
if e.Agg.Version != milestone.AggregateVersion {
return handler.NewNoOpStatement(event), nil // Skip v1 events.
}
if e.MilestoneType != milestone.InstanceDeleted {
return handler.NewUpdateStatement(
event,
[]handler.Column{
handler.NewCol(MilestoneColumnPushedDate, event.CreatedAt()),
handler.NewCol(MilestoneColumnPushedDate, e.GetPushedDate()),
},
[]handler.Condition{
handler.NewCond(MilestoneColumnInstanceID, event.Aggregate().InstanceID),
@ -246,58 +99,3 @@ func (p *milestoneProjection) reduceMilestonePushed(event eventstore.Event) (*ha
},
), nil
}
func (p *milestoneProjection) reduceReachedIfUserEventFunc(msType milestone.Type) func(event eventstore.Event) (*handler.Statement, error) {
return func(event eventstore.Event) (*handler.Statement, error) {
if p.isSystemEvent(event) {
return handler.NewNoOpStatement(event), nil
}
return p.reduceReachedFunc(msType)(event)
}
}
func (p *milestoneProjection) reduceReachedFunc(msType milestone.Type) func(event eventstore.Event) (*handler.Statement, error) {
return func(event eventstore.Event) (*handler.Statement, error) {
return handler.NewUpdateStatement(event, []handler.Column{
handler.NewCol(MilestoneColumnReachedDate, event.CreatedAt()),
},
[]handler.Condition{
handler.NewCond(MilestoneColumnInstanceID, event.Aggregate().InstanceID),
handler.NewCond(MilestoneColumnType, msType),
handler.NewIsNullCond(MilestoneColumnReachedDate),
}), nil
}
}
func (p *milestoneProjection) reduceAppConfigAdded(event eventstore.Event, clientID string) (*handler.Statement, error) {
if !p.isSystemEvent(event) {
return handler.NewNoOpStatement(event), nil
}
return handler.NewUpdateStatement(
event,
[]handler.Column{
handler.NewArrayAppendCol(MilestoneColumnIgnoreClientIDs, clientID),
},
[]handler.Condition{
handler.NewCond(MilestoneColumnInstanceID, event.Aggregate().InstanceID),
handler.NewCond(MilestoneColumnType, milestone.AuthenticationSucceededOnApplication),
handler.NewIsNullCond(MilestoneColumnReachedDate),
},
), nil
}
func (p *milestoneProjection) isSystemEvent(event eventstore.Event) bool {
if userId, err := strconv.Atoi(event.Creator()); err == nil && userId > 0 {
return false
}
// check if it is a hard coded event creator
for _, creator := range []string{"", "system", "OIDC", "LOGIN", "SYSTEM"} {
if creator == event.Creator() {
return true
}
}
_, ok := p.systemUsers[event.Creator()]
return ok
}

View File

@ -4,13 +4,11 @@ import (
"testing"
"time"
"github.com/zitadel/zitadel/internal/database"
"github.com/stretchr/testify/require"
"github.com/zitadel/zitadel/internal/eventstore"
"github.com/zitadel/zitadel/internal/eventstore/handler/v2"
"github.com/zitadel/zitadel/internal/repository/instance"
"github.com/zitadel/zitadel/internal/repository/milestone"
"github.com/zitadel/zitadel/internal/repository/oidcsession"
"github.com/zitadel/zitadel/internal/repository/project"
"github.com/zitadel/zitadel/internal/zerrors"
)
@ -19,6 +17,8 @@ func TestMilestonesProjection_reduces(t *testing.T) {
event func(t *testing.T) eventstore.Event
}
now := time.Now()
date, err := time.Parse(time.RFC3339, "2006-01-02T15:04:05Z")
require.NoError(t, err)
tests := []struct {
name string
args args
@ -29,292 +29,54 @@ func TestMilestonesProjection_reduces(t *testing.T) {
name: "reduceInstanceAdded",
args: args{
event: getEvent(timedTestEvent(
instance.InstanceAddedEventType,
instance.AggregateType,
[]byte(`{}`),
milestone.ReachedEventType,
milestone.AggregateType,
[]byte(`{"type": "instance_created"}`),
now,
), instance.InstanceAddedEventMapper),
withVersion(milestone.AggregateVersion),
), milestone.ReachedEventMapper),
},
reduce: (&milestoneProjection{}).reduceInstanceAdded,
reduce: (&milestoneProjection{}).reduceReached,
want: wantReduce{
aggregateType: eventstore.AggregateType("instance"),
aggregateType: eventstore.AggregateType("milestone"),
sequence: 15,
executer: &testExecuter{
executions: []execution{
{
expectedStmt: "INSERT INTO projections.milestones (instance_id, type, reached_date) VALUES ($1, $2, $3)",
expectedStmt: "INSERT INTO projections.milestones2 (instance_id, type, reached_date) VALUES ($1, $2, $3)",
expectedArgs: []interface{}{
"instance-id",
milestone.InstanceCreated,
now,
},
},
{
expectedStmt: "INSERT INTO projections.milestones (instance_id, type) VALUES ($1, $2)",
expectedArgs: []interface{}{
"instance-id",
milestone.AuthenticationSucceededOnInstance,
},
},
{
expectedStmt: "INSERT INTO projections.milestones (instance_id, type) VALUES ($1, $2)",
expectedArgs: []interface{}{
"instance-id",
milestone.ProjectCreated,
},
},
{
expectedStmt: "INSERT INTO projections.milestones (instance_id, type) VALUES ($1, $2)",
expectedArgs: []interface{}{
"instance-id",
milestone.ApplicationCreated,
},
},
{
expectedStmt: "INSERT INTO projections.milestones (instance_id, type) VALUES ($1, $2)",
expectedArgs: []interface{}{
"instance-id",
milestone.AuthenticationSucceededOnApplication,
},
},
{
expectedStmt: "INSERT INTO projections.milestones (instance_id, type) VALUES ($1, $2)",
expectedArgs: []interface{}{
"instance-id",
milestone.InstanceDeleted,
},
},
},
},
},
},
{
name: "reduceInstancePrimaryDomainSet",
args: args{
event: getEvent(testEvent(
instance.InstanceDomainPrimarySetEventType,
instance.AggregateType,
[]byte(`{"domain": "my.domain"}`),
), instance.DomainPrimarySetEventMapper),
},
reduce: (&milestoneProjection{}).reduceInstanceDomainPrimarySet,
want: wantReduce{
aggregateType: eventstore.AggregateType("instance"),
sequence: 15,
executer: &testExecuter{
executions: []execution{
{
expectedStmt: "UPDATE projections.milestones SET primary_domain = $1 WHERE (instance_id = $2) AND (last_pushed_date IS NULL)",
expectedArgs: []interface{}{
"my.domain",
"instance-id",
},
},
},
},
},
},
{
name: "reduceProjectAdded",
name: "reduceInstanceAdded with reached date",
args: args{
event: getEvent(timedTestEvent(
project.ProjectAddedType,
project.AggregateType,
[]byte(`{}`),
milestone.ReachedEventType,
milestone.AggregateType,
[]byte(`{"type": "instance_created", "reachedDate":"2006-01-02T15:04:05Z"}`),
now,
), project.ProjectAddedEventMapper),
withVersion(milestone.AggregateVersion),
), milestone.ReachedEventMapper),
},
reduce: (&milestoneProjection{}).reduceProjectAdded,
reduce: (&milestoneProjection{}).reduceReached,
want: wantReduce{
aggregateType: eventstore.AggregateType("project"),
aggregateType: eventstore.AggregateType("milestone"),
sequence: 15,
executer: &testExecuter{
executions: []execution{
{
expectedStmt: "UPDATE projections.milestones SET reached_date = $1 WHERE (instance_id = $2) AND (type = $3) AND (reached_date IS NULL)",
expectedStmt: "INSERT INTO projections.milestones2 (instance_id, type, reached_date) VALUES ($1, $2, $3)",
expectedArgs: []interface{}{
now,
"instance-id",
milestone.ProjectCreated,
},
},
},
},
},
},
{
name: "reduceApplicationAdded",
args: args{
event: getEvent(timedTestEvent(
project.ApplicationAddedType,
project.AggregateType,
[]byte(`{}`),
now,
), project.ApplicationAddedEventMapper),
},
reduce: (&milestoneProjection{}).reduceApplicationAdded,
want: wantReduce{
aggregateType: eventstore.AggregateType("project"),
sequence: 15,
executer: &testExecuter{
executions: []execution{
{
expectedStmt: "UPDATE projections.milestones SET reached_date = $1 WHERE (instance_id = $2) AND (type = $3) AND (reached_date IS NULL)",
expectedArgs: []interface{}{
now,
"instance-id",
milestone.ApplicationCreated,
},
},
},
},
},
},
{
name: "reduceOIDCConfigAdded user event",
args: args{
event: getEvent(testEvent(
project.OIDCConfigAddedType,
project.AggregateType,
[]byte(`{}`),
), project.OIDCConfigAddedEventMapper),
},
reduce: (&milestoneProjection{}).reduceOIDCConfigAdded,
want: wantReduce{
aggregateType: eventstore.AggregateType("project"),
sequence: 15,
executer: &testExecuter{},
},
},
{
name: "reduceOIDCConfigAdded system event",
args: args{
event: getEvent(toSystemEvent(testEvent(
project.OIDCConfigAddedType,
project.AggregateType,
[]byte(`{"clientId": "client-id"}`),
)), project.OIDCConfigAddedEventMapper),
},
reduce: (&milestoneProjection{}).reduceOIDCConfigAdded,
want: wantReduce{
aggregateType: eventstore.AggregateType("project"),
sequence: 15,
executer: &testExecuter{
executions: []execution{
{
expectedStmt: "UPDATE projections.milestones SET ignore_client_ids = array_append(ignore_client_ids, $1) WHERE (instance_id = $2) AND (type = $3) AND (reached_date IS NULL)",
expectedArgs: []interface{}{
"client-id",
"instance-id",
milestone.AuthenticationSucceededOnApplication,
},
},
},
},
},
},
{
name: "reduceAPIConfigAdded user event",
args: args{
event: getEvent(testEvent(
project.APIConfigAddedType,
project.AggregateType,
[]byte(`{}`),
), project.APIConfigAddedEventMapper),
},
reduce: (&milestoneProjection{}).reduceAPIConfigAdded,
want: wantReduce{
aggregateType: eventstore.AggregateType("project"),
sequence: 15,
executer: &testExecuter{},
},
},
{
name: "reduceAPIConfigAdded system event",
args: args{
event: getEvent(toSystemEvent(testEvent(
project.APIConfigAddedType,
project.AggregateType,
[]byte(`{"clientId": "client-id"}`),
)), project.APIConfigAddedEventMapper),
},
reduce: (&milestoneProjection{}).reduceAPIConfigAdded,
want: wantReduce{
aggregateType: eventstore.AggregateType("project"),
sequence: 15,
executer: &testExecuter{
executions: []execution{
{
expectedStmt: "UPDATE projections.milestones SET ignore_client_ids = array_append(ignore_client_ids, $1) WHERE (instance_id = $2) AND (type = $3) AND (reached_date IS NULL)",
expectedArgs: []interface{}{
"client-id",
"instance-id",
milestone.AuthenticationSucceededOnApplication,
},
},
},
},
},
},
{
name: "reduceOIDCSessionAdded",
args: args{
event: getEvent(timedTestEvent(
oidcsession.AddedType,
oidcsession.AggregateType,
[]byte(`{"clientID": "client-id"}`),
now,
), eventstore.GenericEventMapper[oidcsession.AddedEvent]),
},
reduce: (&milestoneProjection{}).reduceOIDCSessionAdded,
want: wantReduce{
aggregateType: eventstore.AggregateType("oidc_session"),
sequence: 15,
executer: &testExecuter{
executions: []execution{
{
expectedStmt: "UPDATE projections.milestones SET reached_date = $1 WHERE (instance_id = $2) AND (type = $3) AND (reached_date IS NULL)",
expectedArgs: []interface{}{
now,
"instance-id",
milestone.AuthenticationSucceededOnInstance,
},
},
{
expectedStmt: "UPDATE projections.milestones SET reached_date = $1 WHERE (instance_id = $2) AND (type = $3) AND (NOT (ignore_client_ids @> $4)) AND (reached_date IS NULL)",
expectedArgs: []interface{}{
now,
"instance-id",
milestone.AuthenticationSucceededOnApplication,
database.TextArray[string]{"client-id"},
},
},
},
},
},
},
{
name: "reduceInstanceRemoved",
args: args{
event: getEvent(timedTestEvent(
instance.InstanceRemovedEventType,
instance.AggregateType,
[]byte(`{}`),
now,
), instance.InstanceRemovedEventMapper),
},
reduce: (&milestoneProjection{}).reduceInstanceRemoved,
want: wantReduce{
aggregateType: eventstore.AggregateType("instance"),
sequence: 15,
executer: &testExecuter{
executions: []execution{
{
expectedStmt: "UPDATE projections.milestones SET reached_date = $1 WHERE (instance_id = $2) AND (type = $3) AND (reached_date IS NULL)",
expectedArgs: []interface{}{
now,
"instance-id",
milestone.InstanceDeleted,
milestone.InstanceCreated,
date,
},
},
},
@ -327,18 +89,19 @@ func TestMilestonesProjection_reduces(t *testing.T) {
event: getEvent(timedTestEvent(
milestone.PushedEventType,
milestone.AggregateType,
[]byte(`{"type": "ProjectCreated"}`),
[]byte(`{"type": "project_created"}`),
now,
withVersion(milestone.AggregateVersion),
), milestone.PushedEventMapper),
},
reduce: (&milestoneProjection{}).reduceMilestonePushed,
reduce: (&milestoneProjection{}).reducePushed,
want: wantReduce{
aggregateType: eventstore.AggregateType("milestone"),
sequence: 15,
executer: &testExecuter{
executions: []execution{
{
expectedStmt: "UPDATE projections.milestones SET last_pushed_date = $1 WHERE (instance_id = $2) AND (type = $3)",
expectedStmt: "UPDATE projections.milestones2 SET last_pushed_date = $1 WHERE (instance_id = $2) AND (type = $3)",
expectedArgs: []interface{}{
now,
"instance-id",
@ -349,23 +112,53 @@ func TestMilestonesProjection_reduces(t *testing.T) {
},
},
},
{
name: "reduceMilestonePushed normal milestone with pushed date",
args: args{
event: getEvent(timedTestEvent(
milestone.PushedEventType,
milestone.AggregateType,
[]byte(`{"type": "project_created", "pushedDate":"2006-01-02T15:04:05Z"}`),
now,
withVersion(milestone.AggregateVersion),
), milestone.PushedEventMapper),
},
reduce: (&milestoneProjection{}).reducePushed,
want: wantReduce{
aggregateType: eventstore.AggregateType("milestone"),
sequence: 15,
executer: &testExecuter{
executions: []execution{
{
expectedStmt: "UPDATE projections.milestones2 SET last_pushed_date = $1 WHERE (instance_id = $2) AND (type = $3)",
expectedArgs: []interface{}{
date,
"instance-id",
milestone.ProjectCreated,
},
},
},
},
},
},
{
name: "reduceMilestonePushed instance deleted milestone",
args: args{
event: getEvent(testEvent(
milestone.PushedEventType,
milestone.AggregateType,
[]byte(`{"type": "InstanceDeleted"}`),
[]byte(`{"type": "instance_deleted"}`),
withVersion(milestone.AggregateVersion),
), milestone.PushedEventMapper),
},
reduce: (&milestoneProjection{}).reduceMilestonePushed,
reduce: (&milestoneProjection{}).reducePushed,
want: wantReduce{
aggregateType: eventstore.AggregateType("milestone"),
sequence: 15,
executer: &testExecuter{
executions: []execution{
{
expectedStmt: "DELETE FROM projections.milestones WHERE (instance_id = $1)",
expectedStmt: "DELETE FROM projections.milestones2 WHERE (instance_id = $1)",
expectedArgs: []interface{}{
"instance-id",
},

View File

@ -156,7 +156,7 @@ func Create(ctx context.Context, sqlClient *database.DB, es handler.EventStore,
DeviceAuthProjection = newDeviceAuthProjection(ctx, applyCustomConfig(projectionConfig, config.Customizations["device_auth"]))
SessionProjection = newSessionProjection(ctx, applyCustomConfig(projectionConfig, config.Customizations["sessions"]))
AuthRequestProjection = newAuthRequestProjection(ctx, applyCustomConfig(projectionConfig, config.Customizations["auth_requests"]))
MilestoneProjection = newMilestoneProjection(ctx, applyCustomConfig(projectionConfig, config.Customizations["milestones"]), systemUsers)
MilestoneProjection = newMilestoneProjection(ctx, applyCustomConfig(projectionConfig, config.Customizations["milestones"]))
QuotaProjection = newQuotaProjection(ctx, applyCustomConfig(projectionConfig, config.Customizations["quotas"]))
LimitsProjection = newLimitsProjection(ctx, applyCustomConfig(projectionConfig, config.Customizations["limits"]))
RestrictionsProjection = newRestrictionsProjection(ctx, applyCustomConfig(projectionConfig, config.Customizations["restrictions"]))

View File

@ -9,20 +9,23 @@ import (
const (
AggregateType = "milestone"
AggregateVersion = "v1"
AggregateVersion = "v2"
)
type Aggregate struct {
eventstore.Aggregate
}
func NewAggregate(ctx context.Context, id string) *Aggregate {
instanceID := authz.GetInstance(ctx).InstanceID()
func NewAggregate(ctx context.Context) *Aggregate {
return NewInstanceAggregate(authz.GetInstance(ctx).InstanceID())
}
func NewInstanceAggregate(instanceID string) *Aggregate {
return &Aggregate{
Aggregate: eventstore.Aggregate{
Type: AggregateType,
Version: AggregateVersion,
ID: id,
ID: instanceID,
ResourceOwner: instanceID,
InstanceID: instanceID,
},

View File

@ -2,23 +2,88 @@ package milestone
import (
"context"
"time"
"github.com/zitadel/zitadel/internal/eventstore"
)
//go:generate enumer -type Type -json -linecomment -transform=snake
type Type int
const (
eventTypePrefix = eventstore.EventType("milestone.")
PushedEventType = eventTypePrefix + "pushed"
InstanceCreated Type = iota
AuthenticationSucceededOnInstance
ProjectCreated
ApplicationCreated
AuthenticationSucceededOnApplication
InstanceDeleted
)
var _ eventstore.Command = (*PushedEvent)(nil)
const (
eventTypePrefix = "milestone."
ReachedEventType = eventTypePrefix + "reached"
PushedEventType = eventTypePrefix + "pushed"
)
type ReachedEvent struct {
*eventstore.BaseEvent `json:"-"`
MilestoneType Type `json:"type"`
ReachedDate *time.Time `json:"reachedDate,omitempty"` // Defaults to [eventstore.BaseEvent.Creation] when empty
}
// Payload implements eventstore.Command.
func (e *ReachedEvent) Payload() any {
return e
}
func (e *ReachedEvent) UniqueConstraints() []*eventstore.UniqueConstraint {
return nil
}
func (e *ReachedEvent) SetBaseEvent(b *eventstore.BaseEvent) {
e.BaseEvent = b
}
func (e *ReachedEvent) GetReachedDate() time.Time {
if e.ReachedDate != nil {
return *e.ReachedDate
}
return e.Creation
}
func NewReachedEvent(
ctx context.Context,
aggregate *Aggregate,
typ Type,
) *ReachedEvent {
return NewReachedEventWithDate(ctx, aggregate, typ, nil)
}
// NewReachedEventWithDate creates a [ReachedEvent] with a fixed Reached Date.
func NewReachedEventWithDate(
ctx context.Context,
aggregate *Aggregate,
typ Type,
reachedDate *time.Time,
) *ReachedEvent {
return &ReachedEvent{
BaseEvent: eventstore.NewBaseEventForPush(
ctx,
&aggregate.Aggregate,
ReachedEventType,
),
MilestoneType: typ,
ReachedDate: reachedDate,
}
}
type PushedEvent struct {
*eventstore.BaseEvent `json:"-"`
MilestoneType Type `json:"type"`
ExternalDomain string `json:"externalDomain"`
PrimaryDomain string `json:"primaryDomain"`
Endpoints []string `json:"endpoints"`
MilestoneType Type `json:"type"`
ExternalDomain string `json:"externalDomain"`
PrimaryDomain string `json:"primaryDomain"`
Endpoints []string `json:"endpoints"`
PushedDate *time.Time `json:"pushedDate,omitempty"` // Defaults to [eventstore.BaseEvent.Creation] when empty
}
// Payload implements eventstore.Command.
@ -34,14 +99,31 @@ func (p *PushedEvent) SetBaseEvent(b *eventstore.BaseEvent) {
p.BaseEvent = b
}
var PushedEventMapper = eventstore.GenericEventMapper[PushedEvent]
func (e *PushedEvent) GetPushedDate() time.Time {
if e.PushedDate != nil {
return *e.PushedDate
}
return e.Creation
}
func NewPushedEvent(
ctx context.Context,
aggregate *Aggregate,
msType Type,
typ Type,
endpoints []string,
externalDomain, primaryDomain string,
externalDomain string,
) *PushedEvent {
return NewPushedEventWithDate(ctx, aggregate, typ, endpoints, externalDomain, nil)
}
// NewPushedEventWithDate creates a [PushedEvent] with a fixed Pushed Date.
func NewPushedEventWithDate(
ctx context.Context,
aggregate *Aggregate,
typ Type,
endpoints []string,
externalDomain string,
pushedDate *time.Time,
) *PushedEvent {
return &PushedEvent{
BaseEvent: eventstore.NewBaseEventForPush(
@ -49,9 +131,9 @@ func NewPushedEvent(
&aggregate.Aggregate,
PushedEventType,
),
MilestoneType: msType,
MilestoneType: typ,
Endpoints: endpoints,
ExternalDomain: externalDomain,
PrimaryDomain: primaryDomain,
PushedDate: pushedDate,
}
}

View File

@ -4,6 +4,12 @@ import (
"github.com/zitadel/zitadel/internal/eventstore"
)
var (
ReachedEventMapper = eventstore.GenericEventMapper[ReachedEvent]
PushedEventMapper = eventstore.GenericEventMapper[PushedEvent]
)
func init() {
eventstore.RegisterFilterEventMapper(AggregateType, ReachedEventType, ReachedEventMapper)
eventstore.RegisterFilterEventMapper(AggregateType, PushedEventType, PushedEventMapper)
}

View File

@ -1,59 +0,0 @@
//go:generate stringer -type Type
package milestone
import (
"fmt"
"strings"
)
type Type int
const (
unknown Type = iota
InstanceCreated
AuthenticationSucceededOnInstance
ProjectCreated
ApplicationCreated
AuthenticationSucceededOnApplication
InstanceDeleted
typesCount
)
func AllTypes() []Type {
types := make([]Type, typesCount-1)
for i := Type(1); i < typesCount; i++ {
types[i-1] = i
}
return types
}
func (t *Type) MarshalJSON() ([]byte, error) {
return []byte(fmt.Sprintf(`"%s"`, t.String())), nil
}
func (t *Type) UnmarshalJSON(data []byte) error {
*t = typeFromString(strings.Trim(string(data), `"`))
return nil
}
func typeFromString(t string) Type {
switch t {
case InstanceCreated.String():
return InstanceCreated
case AuthenticationSucceededOnInstance.String():
return AuthenticationSucceededOnInstance
case ProjectCreated.String():
return ProjectCreated
case ApplicationCreated.String():
return ApplicationCreated
case AuthenticationSucceededOnApplication.String():
return AuthenticationSucceededOnApplication
case InstanceDeleted.String():
return InstanceDeleted
default:
return unknown
}
}

View File

@ -0,0 +1,112 @@
// Code generated by "enumer -type Type -json -linecomment -transform=snake"; DO NOT EDIT.
package milestone
import (
"encoding/json"
"fmt"
"strings"
)
const _TypeName = "instance_createdauthentication_succeeded_on_instanceproject_createdapplication_createdauthentication_succeeded_on_applicationinstance_deleted"
var _TypeIndex = [...]uint8{0, 16, 52, 67, 86, 125, 141}
const _TypeLowerName = "instance_createdauthentication_succeeded_on_instanceproject_createdapplication_createdauthentication_succeeded_on_applicationinstance_deleted"
func (i Type) String() string {
if i < 0 || i >= Type(len(_TypeIndex)-1) {
return fmt.Sprintf("Type(%d)", i)
}
return _TypeName[_TypeIndex[i]:_TypeIndex[i+1]]
}
// An "invalid array index" compiler error signifies that the constant values have changed.
// Re-run the stringer command to generate them again.
func _TypeNoOp() {
var x [1]struct{}
_ = x[InstanceCreated-(0)]
_ = x[AuthenticationSucceededOnInstance-(1)]
_ = x[ProjectCreated-(2)]
_ = x[ApplicationCreated-(3)]
_ = x[AuthenticationSucceededOnApplication-(4)]
_ = x[InstanceDeleted-(5)]
}
var _TypeValues = []Type{InstanceCreated, AuthenticationSucceededOnInstance, ProjectCreated, ApplicationCreated, AuthenticationSucceededOnApplication, InstanceDeleted}
var _TypeNameToValueMap = map[string]Type{
_TypeName[0:16]: InstanceCreated,
_TypeLowerName[0:16]: InstanceCreated,
_TypeName[16:52]: AuthenticationSucceededOnInstance,
_TypeLowerName[16:52]: AuthenticationSucceededOnInstance,
_TypeName[52:67]: ProjectCreated,
_TypeLowerName[52:67]: ProjectCreated,
_TypeName[67:86]: ApplicationCreated,
_TypeLowerName[67:86]: ApplicationCreated,
_TypeName[86:125]: AuthenticationSucceededOnApplication,
_TypeLowerName[86:125]: AuthenticationSucceededOnApplication,
_TypeName[125:141]: InstanceDeleted,
_TypeLowerName[125:141]: InstanceDeleted,
}
var _TypeNames = []string{
_TypeName[0:16],
_TypeName[16:52],
_TypeName[52:67],
_TypeName[67:86],
_TypeName[86:125],
_TypeName[125:141],
}
// TypeString retrieves an enum value from the enum constants string name.
// Throws an error if the param is not part of the enum.
func TypeString(s string) (Type, error) {
if val, ok := _TypeNameToValueMap[s]; ok {
return val, nil
}
if val, ok := _TypeNameToValueMap[strings.ToLower(s)]; ok {
return val, nil
}
return 0, fmt.Errorf("%s does not belong to Type values", s)
}
// TypeValues returns all values of the enum
func TypeValues() []Type {
return _TypeValues
}
// TypeStrings returns a slice of all String values of the enum
func TypeStrings() []string {
strs := make([]string, len(_TypeNames))
copy(strs, _TypeNames)
return strs
}
// IsAType returns "true" if the value is listed in the enum definition. "false" otherwise
func (i Type) IsAType() bool {
for _, v := range _TypeValues {
if i == v {
return true
}
}
return false
}
// MarshalJSON implements the json.Marshaler interface for Type
func (i Type) MarshalJSON() ([]byte, error) {
return json.Marshal(i.String())
}
// UnmarshalJSON implements the json.Unmarshaler interface for Type
func (i *Type) UnmarshalJSON(data []byte) error {
var s string
if err := json.Unmarshal(data, &s); err != nil {
return fmt.Errorf("Type should be a string, got %s", data)
}
var err error
*i, err = TypeString(s)
return err
}

View File

@ -1,30 +0,0 @@
// Code generated by "stringer -type Type"; DO NOT EDIT.
package milestone
import "strconv"
func _() {
// An "invalid array index" compiler error signifies that the constant values have changed.
// Re-run the stringer command to generate them again.
var x [1]struct{}
_ = x[unknown-0]
_ = x[InstanceCreated-1]
_ = x[AuthenticationSucceededOnInstance-2]
_ = x[ProjectCreated-3]
_ = x[ApplicationCreated-4]
_ = x[AuthenticationSucceededOnApplication-5]
_ = x[InstanceDeleted-6]
_ = x[typesCount-7]
}
const _Type_name = "unknownInstanceCreatedAuthenticationSucceededOnInstanceProjectCreatedApplicationCreatedAuthenticationSucceededOnApplicationInstanceDeletedtypesCount"
var _Type_index = [...]uint8{0, 7, 22, 55, 69, 87, 123, 138, 148}
func (i Type) String() string {
if i < 0 || i >= Type(len(_Type_index)-1) {
return "Type(" + strconv.FormatInt(int64(i), 10) + ")"
}
return _Type_name[_Type_index[i]:_Type_index[i+1]]
}