mirror of
https://github.com/zitadel/zitadel.git
synced 2025-08-11 21:27:42 +00:00
perf(milestones): refactor (#8788)
Some checks are pending
ZITADEL CI/CD / core (push) Waiting to run
ZITADEL CI/CD / console (push) Waiting to run
ZITADEL CI/CD / version (push) Waiting to run
ZITADEL CI/CD / compile (push) Blocked by required conditions
ZITADEL CI/CD / core-unit-test (push) Blocked by required conditions
ZITADEL CI/CD / core-integration-test (push) Blocked by required conditions
ZITADEL CI/CD / lint (push) Blocked by required conditions
ZITADEL CI/CD / container (push) Blocked by required conditions
ZITADEL CI/CD / e2e (push) Blocked by required conditions
ZITADEL CI/CD / release (push) Blocked by required conditions
Code Scanning / CodeQL-Build (go) (push) Waiting to run
Code Scanning / CodeQL-Build (javascript) (push) Waiting to run
Some checks are pending
ZITADEL CI/CD / core (push) Waiting to run
ZITADEL CI/CD / console (push) Waiting to run
ZITADEL CI/CD / version (push) Waiting to run
ZITADEL CI/CD / compile (push) Blocked by required conditions
ZITADEL CI/CD / core-unit-test (push) Blocked by required conditions
ZITADEL CI/CD / core-integration-test (push) Blocked by required conditions
ZITADEL CI/CD / lint (push) Blocked by required conditions
ZITADEL CI/CD / container (push) Blocked by required conditions
ZITADEL CI/CD / e2e (push) Blocked by required conditions
ZITADEL CI/CD / release (push) Blocked by required conditions
Code Scanning / CodeQL-Build (go) (push) Waiting to run
Code Scanning / CodeQL-Build (javascript) (push) Waiting to run
# Which Problems Are Solved Milestones used existing events from a number of aggregates. OIDC session is one of them. We noticed in load-tests that the reduction of the oidc_session.added event into the milestone projection is a costly business with payload based conditionals. A milestone is reached once, but even then we remain subscribed to the OIDC events. This requires the projections.current_states to be updated continuously. # How the Problems Are Solved The milestone creation is refactored to use dedicated events instead. The command side decides when a milestone is reached and creates the reached event once for each milestone when required. # Additional Changes In order to prevent reached milestones being created twice, a migration script is provided. When the old `projections.milestones` table exist, the state is read from there and `v2` milestone aggregate events are created, with the original reached and pushed dates. # Additional Context - Closes https://github.com/zitadel/zitadel/issues/8800
This commit is contained in:
82
internal/command/cache.go
Normal file
82
internal/command/cache.go
Normal file
@@ -0,0 +1,82 @@
|
||||
package command
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/cache"
|
||||
"github.com/zitadel/zitadel/internal/cache/gomap"
|
||||
"github.com/zitadel/zitadel/internal/cache/noop"
|
||||
"github.com/zitadel/zitadel/internal/cache/pg"
|
||||
"github.com/zitadel/zitadel/internal/database"
|
||||
)
|
||||
|
||||
type Caches struct {
|
||||
connectors *cacheConnectors
|
||||
milestones cache.Cache[milestoneIndex, string, *MilestonesReached]
|
||||
}
|
||||
|
||||
func startCaches(background context.Context, conf *cache.CachesConfig, client *database.DB) (_ *Caches, err error) {
|
||||
caches := &Caches{
|
||||
milestones: noop.NewCache[milestoneIndex, string, *MilestonesReached](),
|
||||
}
|
||||
if conf == nil {
|
||||
return caches, nil
|
||||
}
|
||||
caches.connectors, err = startCacheConnectors(background, conf, client)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
caches.milestones, err = startCache[milestoneIndex, string, *MilestonesReached](background, []milestoneIndex{milestoneIndexInstanceID}, "milestones", conf.Instance, caches.connectors)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return caches, nil
|
||||
}
|
||||
|
||||
type cacheConnectors struct {
|
||||
memory *cache.AutoPruneConfig
|
||||
postgres *pgxPoolCacheConnector
|
||||
}
|
||||
|
||||
type pgxPoolCacheConnector struct {
|
||||
*cache.AutoPruneConfig
|
||||
client *database.DB
|
||||
}
|
||||
|
||||
func startCacheConnectors(_ context.Context, conf *cache.CachesConfig, client *database.DB) (_ *cacheConnectors, err error) {
|
||||
connectors := new(cacheConnectors)
|
||||
if conf.Connectors.Memory.Enabled {
|
||||
connectors.memory = &conf.Connectors.Memory.AutoPrune
|
||||
}
|
||||
if conf.Connectors.Postgres.Enabled {
|
||||
connectors.postgres = &pgxPoolCacheConnector{
|
||||
AutoPruneConfig: &conf.Connectors.Postgres.AutoPrune,
|
||||
client: client,
|
||||
}
|
||||
}
|
||||
return connectors, nil
|
||||
}
|
||||
|
||||
func startCache[I ~int, K ~string, V cache.Entry[I, K]](background context.Context, indices []I, name string, conf *cache.CacheConfig, connectors *cacheConnectors) (cache.Cache[I, K, V], error) {
|
||||
if conf == nil || conf.Connector == "" {
|
||||
return noop.NewCache[I, K, V](), nil
|
||||
}
|
||||
if strings.EqualFold(conf.Connector, "memory") && connectors.memory != nil {
|
||||
c := gomap.NewCache[I, K, V](background, indices, *conf)
|
||||
connectors.memory.StartAutoPrune(background, c, name)
|
||||
return c, nil
|
||||
}
|
||||
if strings.EqualFold(conf.Connector, "postgres") && connectors.postgres != nil {
|
||||
client := connectors.postgres.client
|
||||
c, err := pg.NewCache[I, K, V](background, name, *conf, indices, client.Pool, client.Type())
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("query start cache: %w", err)
|
||||
}
|
||||
connectors.postgres.StartAutoPrune(background, c, name)
|
||||
return c, nil
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("cache connector %q not enabled", conf.Connector)
|
||||
}
|
@@ -18,6 +18,7 @@ import (
|
||||
|
||||
"github.com/zitadel/zitadel/internal/api/authz"
|
||||
api_http "github.com/zitadel/zitadel/internal/api/http"
|
||||
"github.com/zitadel/zitadel/internal/cache"
|
||||
"github.com/zitadel/zitadel/internal/command/preparation"
|
||||
sd "github.com/zitadel/zitadel/internal/config/systemdefaults"
|
||||
"github.com/zitadel/zitadel/internal/crypto"
|
||||
@@ -88,10 +89,17 @@ type Commands struct {
|
||||
EventGroupExisting func(group string) bool
|
||||
|
||||
GenerateDomain func(instanceName, domain string) (string, error)
|
||||
|
||||
caches *Caches
|
||||
// Store instance IDs where all milestones are reached (except InstanceDeleted).
|
||||
// These instance's milestones never need to be invalidated,
|
||||
// so the query and cache overhead can completely eliminated.
|
||||
milestonesCompleted sync.Map
|
||||
}
|
||||
|
||||
func StartCommands(
|
||||
es *eventstore.Eventstore,
|
||||
cachesConfig *cache.CachesConfig,
|
||||
defaults sd.SystemDefaults,
|
||||
zitadelRoles []authz.RoleMapping,
|
||||
staticStore static.Storage,
|
||||
@@ -123,6 +131,10 @@ func StartCommands(
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("password hasher: %w", err)
|
||||
}
|
||||
caches, err := startCaches(context.TODO(), cachesConfig, es.Client())
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("caches: %w", err)
|
||||
}
|
||||
repo = &Commands{
|
||||
eventstore: es,
|
||||
static: staticStore,
|
||||
@@ -176,6 +188,7 @@ func StartCommands(
|
||||
},
|
||||
},
|
||||
GenerateDomain: domain.NewGeneratedInstanceDomain,
|
||||
caches: caches,
|
||||
}
|
||||
|
||||
if defaultSecretGenerators != nil && defaultSecretGenerators.ClientSecret != nil {
|
||||
|
@@ -6,6 +6,7 @@ import (
|
||||
|
||||
"golang.org/x/text/language"
|
||||
|
||||
"github.com/zitadel/logging"
|
||||
"github.com/zitadel/zitadel/internal/api/authz"
|
||||
"github.com/zitadel/zitadel/internal/api/http"
|
||||
"github.com/zitadel/zitadel/internal/command/preparation"
|
||||
@@ -17,6 +18,7 @@ import (
|
||||
"github.com/zitadel/zitadel/internal/notification/channels/smtp"
|
||||
"github.com/zitadel/zitadel/internal/repository/instance"
|
||||
"github.com/zitadel/zitadel/internal/repository/limits"
|
||||
"github.com/zitadel/zitadel/internal/repository/milestone"
|
||||
"github.com/zitadel/zitadel/internal/repository/org"
|
||||
"github.com/zitadel/zitadel/internal/repository/project"
|
||||
"github.com/zitadel/zitadel/internal/repository/quota"
|
||||
@@ -292,7 +294,7 @@ func setUpInstance(ctx context.Context, c *Commands, setup *InstanceSetup) (vali
|
||||
setupFeatures(&validations, setup.Features, setup.zitadel.instanceID)
|
||||
setupLimits(c, &validations, limits.NewAggregate(setup.zitadel.limitsID, setup.zitadel.instanceID), setup.Limits)
|
||||
setupRestrictions(c, &validations, restrictions.NewAggregate(setup.zitadel.restrictionsID, setup.zitadel.instanceID, setup.zitadel.instanceID), setup.Restrictions)
|
||||
|
||||
setupInstanceCreatedMilestone(&validations, setup.zitadel.instanceID)
|
||||
return validations, pat, machineKey, nil
|
||||
}
|
||||
|
||||
@@ -890,7 +892,8 @@ func (c *Commands) RemoveInstance(ctx context.Context, id string) (*domain.Objec
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
err = c.caches.milestones.Invalidate(ctx, milestoneIndexInstanceID, id)
|
||||
logging.OnError(err).Error("milestone invalidate")
|
||||
return &domain.ObjectDetails{
|
||||
Sequence: events[len(events)-1].Sequence(),
|
||||
EventDate: events[len(events)-1].CreatedAt(),
|
||||
@@ -908,10 +911,16 @@ func (c *Commands) prepareRemoveInstance(a *instance.Aggregate) preparation.Vali
|
||||
if !writeModel.State.Exists() {
|
||||
return nil, zerrors.ThrowNotFound(err, "COMMA-AE3GS", "Errors.Instance.NotFound")
|
||||
}
|
||||
return []eventstore.Command{instance.NewInstanceRemovedEvent(ctx,
|
||||
&a.Aggregate,
|
||||
writeModel.Name,
|
||||
writeModel.Domains)},
|
||||
milestoneAggregate := milestone.NewInstanceAggregate(a.ID)
|
||||
return []eventstore.Command{
|
||||
instance.NewInstanceRemovedEvent(ctx,
|
||||
&a.Aggregate,
|
||||
writeModel.Name,
|
||||
writeModel.Domains),
|
||||
milestone.NewReachedEvent(ctx,
|
||||
milestoneAggregate,
|
||||
milestone.InstanceDeleted),
|
||||
},
|
||||
nil
|
||||
}, nil
|
||||
}
|
||||
|
@@ -13,6 +13,7 @@ import (
|
||||
"golang.org/x/text/language"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/api/authz"
|
||||
"github.com/zitadel/zitadel/internal/cache/noop"
|
||||
"github.com/zitadel/zitadel/internal/command/preparation"
|
||||
"github.com/zitadel/zitadel/internal/crypto"
|
||||
"github.com/zitadel/zitadel/internal/domain"
|
||||
@@ -20,6 +21,7 @@ import (
|
||||
"github.com/zitadel/zitadel/internal/id"
|
||||
id_mock "github.com/zitadel/zitadel/internal/id/mock"
|
||||
"github.com/zitadel/zitadel/internal/repository/instance"
|
||||
"github.com/zitadel/zitadel/internal/repository/milestone"
|
||||
"github.com/zitadel/zitadel/internal/repository/org"
|
||||
"github.com/zitadel/zitadel/internal/repository/project"
|
||||
"github.com/zitadel/zitadel/internal/repository/user"
|
||||
@@ -372,6 +374,7 @@ func setupInstanceEvents(ctx context.Context, instanceID, orgID, projectID, appI
|
||||
setupInstanceElementsEvents(ctx, instanceID, instanceName, defaultLanguage),
|
||||
orgEvents(ctx, instanceID, orgID, orgName, projectID, domain, externalSecure, true, true),
|
||||
generatedDomainEvents(ctx, instanceID, orgID, projectID, appID, domain),
|
||||
instanceCreatedMilestoneEvent(ctx, instanceID),
|
||||
)
|
||||
}
|
||||
|
||||
@@ -401,6 +404,12 @@ func generatedDomainEvents(ctx context.Context, instanceID, orgID, projectID, ap
|
||||
}
|
||||
}
|
||||
|
||||
func instanceCreatedMilestoneEvent(ctx context.Context, instanceID string) []eventstore.Command {
|
||||
return []eventstore.Command{
|
||||
milestone.NewReachedEvent(ctx, milestone.NewInstanceAggregate(instanceID), milestone.InstanceCreated),
|
||||
}
|
||||
}
|
||||
|
||||
func generatedDomainFilters(instanceID, orgID, projectID, appID, generatedDomain string) []expect {
|
||||
return []expect{
|
||||
expectFilter(),
|
||||
@@ -1378,7 +1387,7 @@ func TestCommandSide_UpdateInstance(t *testing.T) {
|
||||
|
||||
func TestCommandSide_RemoveInstance(t *testing.T) {
|
||||
type fields struct {
|
||||
eventstore *eventstore.Eventstore
|
||||
eventstore func(t *testing.T) *eventstore.Eventstore
|
||||
}
|
||||
type args struct {
|
||||
ctx context.Context
|
||||
@@ -1397,8 +1406,7 @@ func TestCommandSide_RemoveInstance(t *testing.T) {
|
||||
{
|
||||
name: "instance not existing, not found error",
|
||||
fields: fields{
|
||||
eventstore: eventstoreExpect(
|
||||
t,
|
||||
eventstore: expectEventstore(
|
||||
expectFilter(),
|
||||
),
|
||||
},
|
||||
@@ -1413,8 +1421,7 @@ func TestCommandSide_RemoveInstance(t *testing.T) {
|
||||
{
|
||||
name: "instance removed, not found error",
|
||||
fields: fields{
|
||||
eventstore: eventstoreExpect(
|
||||
t,
|
||||
eventstore: expectEventstore(
|
||||
expectFilter(
|
||||
eventFromEventPusher(
|
||||
instance.NewInstanceAddedEvent(
|
||||
@@ -1444,8 +1451,7 @@ func TestCommandSide_RemoveInstance(t *testing.T) {
|
||||
{
|
||||
name: "instance remove, ok",
|
||||
fields: fields{
|
||||
eventstore: eventstoreExpect(
|
||||
t,
|
||||
eventstore: expectEventstore(
|
||||
expectFilter(
|
||||
eventFromEventPusherWithInstanceID(
|
||||
"INSTANCE",
|
||||
@@ -1480,6 +1486,10 @@ func TestCommandSide_RemoveInstance(t *testing.T) {
|
||||
"custom.domain",
|
||||
},
|
||||
),
|
||||
milestone.NewReachedEvent(context.Background(),
|
||||
milestone.NewInstanceAggregate("INSTANCE"),
|
||||
milestone.InstanceDeleted,
|
||||
),
|
||||
),
|
||||
),
|
||||
},
|
||||
@@ -1497,7 +1507,10 @@ func TestCommandSide_RemoveInstance(t *testing.T) {
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
r := &Commands{
|
||||
eventstore: tt.fields.eventstore,
|
||||
eventstore: tt.fields.eventstore(t),
|
||||
caches: &Caches{
|
||||
milestones: noop.NewCache[milestoneIndex, string, *MilestonesReached](),
|
||||
},
|
||||
}
|
||||
got, err := r.RemoveInstance(tt.args.ctx, tt.args.instanceID)
|
||||
if tt.res.err == nil {
|
||||
|
@@ -3,20 +3,176 @@ package command
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/zitadel/logging"
|
||||
"github.com/zitadel/zitadel/internal/api/authz"
|
||||
"github.com/zitadel/zitadel/internal/command/preparation"
|
||||
"github.com/zitadel/zitadel/internal/eventstore"
|
||||
"github.com/zitadel/zitadel/internal/repository/milestone"
|
||||
)
|
||||
|
||||
// MilestonePushed writes a new milestone.PushedEvent with a new milestone.Aggregate to the eventstore
|
||||
type milestoneIndex int
|
||||
|
||||
const (
|
||||
milestoneIndexInstanceID milestoneIndex = iota
|
||||
)
|
||||
|
||||
type MilestonesReached struct {
|
||||
InstanceID string
|
||||
InstanceCreated bool
|
||||
AuthenticationSucceededOnInstance bool
|
||||
ProjectCreated bool
|
||||
ApplicationCreated bool
|
||||
AuthenticationSucceededOnApplication bool
|
||||
InstanceDeleted bool
|
||||
}
|
||||
|
||||
// complete returns true if all milestones except InstanceDeleted are reached.
|
||||
func (m *MilestonesReached) complete() bool {
|
||||
return m.InstanceCreated &&
|
||||
m.AuthenticationSucceededOnInstance &&
|
||||
m.ProjectCreated &&
|
||||
m.ApplicationCreated &&
|
||||
m.AuthenticationSucceededOnApplication
|
||||
}
|
||||
|
||||
// GetMilestonesReached finds the milestone state for the current instance.
|
||||
func (c *Commands) GetMilestonesReached(ctx context.Context) (*MilestonesReached, error) {
|
||||
milestones, ok := c.getCachedMilestonesReached(ctx)
|
||||
if ok {
|
||||
return milestones, nil
|
||||
}
|
||||
model := NewMilestonesReachedWriteModel(authz.GetInstance(ctx).InstanceID())
|
||||
if err := c.eventstore.FilterToQueryReducer(ctx, model); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
milestones = &model.MilestonesReached
|
||||
c.setCachedMilestonesReached(ctx, milestones)
|
||||
return milestones, nil
|
||||
}
|
||||
|
||||
// getCachedMilestonesReached checks for milestone completeness on an instance and returns a filled
|
||||
// [MilestonesReached] object.
|
||||
// Otherwise it looks for the object in the milestone cache.
|
||||
func (c *Commands) getCachedMilestonesReached(ctx context.Context) (*MilestonesReached, bool) {
|
||||
instanceID := authz.GetInstance(ctx).InstanceID()
|
||||
if _, ok := c.milestonesCompleted.Load(instanceID); ok {
|
||||
return &MilestonesReached{
|
||||
InstanceID: instanceID,
|
||||
InstanceCreated: true,
|
||||
AuthenticationSucceededOnInstance: true,
|
||||
ProjectCreated: true,
|
||||
ApplicationCreated: true,
|
||||
AuthenticationSucceededOnApplication: true,
|
||||
InstanceDeleted: false,
|
||||
}, ok
|
||||
}
|
||||
return c.caches.milestones.Get(ctx, milestoneIndexInstanceID, instanceID)
|
||||
}
|
||||
|
||||
// setCachedMilestonesReached stores the current milestones state in the milestones cache.
|
||||
// If the milestones are complete, the instance ID is stored in milestonesCompleted instead.
|
||||
func (c *Commands) setCachedMilestonesReached(ctx context.Context, milestones *MilestonesReached) {
|
||||
if milestones.complete() {
|
||||
c.milestonesCompleted.Store(milestones.InstanceID, struct{}{})
|
||||
return
|
||||
}
|
||||
c.caches.milestones.Set(ctx, milestones)
|
||||
}
|
||||
|
||||
// Keys implements cache.Entry
|
||||
func (c *MilestonesReached) Keys(i milestoneIndex) []string {
|
||||
if i == milestoneIndexInstanceID {
|
||||
return []string{c.InstanceID}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// MilestonePushed writes a new milestone.PushedEvent with the milestone.Aggregate to the eventstore
|
||||
func (c *Commands) MilestonePushed(
|
||||
ctx context.Context,
|
||||
instanceID string,
|
||||
msType milestone.Type,
|
||||
endpoints []string,
|
||||
primaryDomain string,
|
||||
) error {
|
||||
id, err := c.idGenerator.Next()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = c.eventstore.Push(ctx, milestone.NewPushedEvent(ctx, milestone.NewAggregate(ctx, id), msType, endpoints, primaryDomain, c.externalDomain))
|
||||
_, err := c.eventstore.Push(ctx, milestone.NewPushedEvent(ctx, milestone.NewInstanceAggregate(instanceID), msType, endpoints, c.externalDomain))
|
||||
return err
|
||||
}
|
||||
|
||||
func setupInstanceCreatedMilestone(validations *[]preparation.Validation, instanceID string) {
|
||||
*validations = append(*validations, func() (preparation.CreateCommands, error) {
|
||||
return func(ctx context.Context, _ preparation.FilterToQueryReducer) ([]eventstore.Command, error) {
|
||||
return []eventstore.Command{
|
||||
milestone.NewReachedEvent(ctx, milestone.NewInstanceAggregate(instanceID), milestone.InstanceCreated),
|
||||
}, nil
|
||||
}, nil
|
||||
})
|
||||
}
|
||||
|
||||
func (s *OIDCSessionEvents) SetMilestones(ctx context.Context, clientID string, isHuman bool) (postCommit func(ctx context.Context), err error) {
|
||||
postCommit = func(ctx context.Context) {}
|
||||
milestones, err := s.commands.GetMilestonesReached(ctx)
|
||||
if err != nil {
|
||||
return postCommit, err
|
||||
}
|
||||
|
||||
instance := authz.GetInstance(ctx)
|
||||
aggregate := milestone.NewAggregate(ctx)
|
||||
var invalidate bool
|
||||
if !milestones.AuthenticationSucceededOnInstance {
|
||||
s.events = append(s.events, milestone.NewReachedEvent(ctx, aggregate, milestone.AuthenticationSucceededOnInstance))
|
||||
invalidate = true
|
||||
}
|
||||
if !milestones.AuthenticationSucceededOnApplication && isHuman && clientID != instance.ConsoleClientID() {
|
||||
s.events = append(s.events, milestone.NewReachedEvent(ctx, aggregate, milestone.AuthenticationSucceededOnApplication))
|
||||
invalidate = true
|
||||
}
|
||||
if invalidate {
|
||||
postCommit = s.commands.invalidateMilestoneCachePostCommit(instance.InstanceID())
|
||||
}
|
||||
return postCommit, nil
|
||||
}
|
||||
|
||||
func (c *Commands) projectCreatedMilestone(ctx context.Context, cmds *[]eventstore.Command) (postCommit func(ctx context.Context), err error) {
|
||||
postCommit = func(ctx context.Context) {}
|
||||
if isSystemUser(ctx) {
|
||||
return postCommit, nil
|
||||
}
|
||||
milestones, err := c.GetMilestonesReached(ctx)
|
||||
if err != nil {
|
||||
return postCommit, err
|
||||
}
|
||||
if milestones.ProjectCreated {
|
||||
return postCommit, nil
|
||||
}
|
||||
aggregate := milestone.NewAggregate(ctx)
|
||||
*cmds = append(*cmds, milestone.NewReachedEvent(ctx, aggregate, milestone.ProjectCreated))
|
||||
return c.invalidateMilestoneCachePostCommit(aggregate.InstanceID), nil
|
||||
}
|
||||
|
||||
func (c *Commands) applicationCreatedMilestone(ctx context.Context, cmds *[]eventstore.Command) (postCommit func(ctx context.Context), err error) {
|
||||
postCommit = func(ctx context.Context) {}
|
||||
if isSystemUser(ctx) {
|
||||
return postCommit, nil
|
||||
}
|
||||
milestones, err := c.GetMilestonesReached(ctx)
|
||||
if err != nil {
|
||||
return postCommit, err
|
||||
}
|
||||
if milestones.ApplicationCreated {
|
||||
return postCommit, nil
|
||||
}
|
||||
aggregate := milestone.NewAggregate(ctx)
|
||||
*cmds = append(*cmds, milestone.NewReachedEvent(ctx, aggregate, milestone.ApplicationCreated))
|
||||
return c.invalidateMilestoneCachePostCommit(aggregate.InstanceID), nil
|
||||
}
|
||||
|
||||
func (c *Commands) invalidateMilestoneCachePostCommit(instanceID string) func(ctx context.Context) {
|
||||
return func(ctx context.Context) {
|
||||
err := c.caches.milestones.Invalidate(ctx, milestoneIndexInstanceID, instanceID)
|
||||
logging.WithFields("instance_id", instanceID).OnError(err).Error("failed to invalidate milestone cache")
|
||||
}
|
||||
}
|
||||
|
||||
func isSystemUser(ctx context.Context) bool {
|
||||
return authz.GetCtxData(ctx).SystemMemberships != nil
|
||||
}
|
||||
|
58
internal/command/milestone_model.go
Normal file
58
internal/command/milestone_model.go
Normal file
@@ -0,0 +1,58 @@
|
||||
package command
|
||||
|
||||
import (
|
||||
"github.com/zitadel/zitadel/internal/eventstore"
|
||||
"github.com/zitadel/zitadel/internal/repository/milestone"
|
||||
)
|
||||
|
||||
type MilestonesReachedWriteModel struct {
|
||||
eventstore.WriteModel
|
||||
MilestonesReached
|
||||
}
|
||||
|
||||
func NewMilestonesReachedWriteModel(instanceID string) *MilestonesReachedWriteModel {
|
||||
return &MilestonesReachedWriteModel{
|
||||
WriteModel: eventstore.WriteModel{
|
||||
AggregateID: instanceID,
|
||||
InstanceID: instanceID,
|
||||
},
|
||||
MilestonesReached: MilestonesReached{
|
||||
InstanceID: instanceID,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (m *MilestonesReachedWriteModel) Query() *eventstore.SearchQueryBuilder {
|
||||
return eventstore.NewSearchQueryBuilder(eventstore.ColumnsEvent).
|
||||
AddQuery().
|
||||
AggregateTypes(milestone.AggregateType).
|
||||
AggregateIDs(m.AggregateID).
|
||||
EventTypes(milestone.ReachedEventType, milestone.PushedEventType).
|
||||
Builder()
|
||||
}
|
||||
|
||||
func (m *MilestonesReachedWriteModel) Reduce() error {
|
||||
for _, event := range m.Events {
|
||||
if e, ok := event.(*milestone.ReachedEvent); ok {
|
||||
m.reduceReachedEvent(e)
|
||||
}
|
||||
}
|
||||
return m.WriteModel.Reduce()
|
||||
}
|
||||
|
||||
func (m *MilestonesReachedWriteModel) reduceReachedEvent(e *milestone.ReachedEvent) {
|
||||
switch e.MilestoneType {
|
||||
case milestone.InstanceCreated:
|
||||
m.InstanceCreated = true
|
||||
case milestone.AuthenticationSucceededOnInstance:
|
||||
m.AuthenticationSucceededOnInstance = true
|
||||
case milestone.ProjectCreated:
|
||||
m.ProjectCreated = true
|
||||
case milestone.ApplicationCreated:
|
||||
m.ApplicationCreated = true
|
||||
case milestone.AuthenticationSucceededOnApplication:
|
||||
m.AuthenticationSucceededOnApplication = true
|
||||
case milestone.InstanceDeleted:
|
||||
m.InstanceDeleted = true
|
||||
}
|
||||
}
|
629
internal/command/milestone_test.go
Normal file
629
internal/command/milestone_test.go
Normal file
@@ -0,0 +1,629 @@
|
||||
package command
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/api/authz"
|
||||
"github.com/zitadel/zitadel/internal/cache"
|
||||
"github.com/zitadel/zitadel/internal/cache/gomap"
|
||||
"github.com/zitadel/zitadel/internal/cache/noop"
|
||||
"github.com/zitadel/zitadel/internal/eventstore"
|
||||
"github.com/zitadel/zitadel/internal/repository/milestone"
|
||||
)
|
||||
|
||||
func TestCommands_GetMilestonesReached(t *testing.T) {
|
||||
cached := &MilestonesReached{
|
||||
InstanceID: "cached-id",
|
||||
InstanceCreated: true,
|
||||
AuthenticationSucceededOnInstance: true,
|
||||
}
|
||||
|
||||
ctx := authz.WithInstanceID(context.Background(), "instanceID")
|
||||
aggregate := milestone.NewAggregate(ctx)
|
||||
|
||||
type fields struct {
|
||||
eventstore func(*testing.T) *eventstore.Eventstore
|
||||
}
|
||||
type args struct {
|
||||
ctx context.Context
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
fields fields
|
||||
args args
|
||||
want *MilestonesReached
|
||||
wantErr error
|
||||
}{
|
||||
{
|
||||
name: "cached",
|
||||
fields: fields{
|
||||
eventstore: expectEventstore(),
|
||||
},
|
||||
args: args{
|
||||
ctx: authz.WithInstanceID(context.Background(), "cached-id"),
|
||||
},
|
||||
want: cached,
|
||||
},
|
||||
{
|
||||
name: "filter error",
|
||||
fields: fields{
|
||||
eventstore: expectEventstore(
|
||||
expectFilterError(io.ErrClosedPipe),
|
||||
),
|
||||
},
|
||||
args: args{
|
||||
ctx: ctx,
|
||||
},
|
||||
wantErr: io.ErrClosedPipe,
|
||||
},
|
||||
{
|
||||
name: "no events, all false",
|
||||
fields: fields{
|
||||
eventstore: expectEventstore(
|
||||
expectFilter(),
|
||||
),
|
||||
},
|
||||
args: args{
|
||||
ctx: ctx,
|
||||
},
|
||||
want: &MilestonesReached{
|
||||
InstanceID: "instanceID",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "instance created",
|
||||
fields: fields{
|
||||
eventstore: expectEventstore(
|
||||
expectFilter(
|
||||
eventFromEventPusher(milestone.NewReachedEvent(ctx, aggregate, milestone.InstanceCreated)),
|
||||
),
|
||||
),
|
||||
},
|
||||
args: args{
|
||||
ctx: ctx,
|
||||
},
|
||||
want: &MilestonesReached{
|
||||
InstanceID: "instanceID",
|
||||
InstanceCreated: true,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "instance auth",
|
||||
fields: fields{
|
||||
eventstore: expectEventstore(
|
||||
expectFilter(
|
||||
eventFromEventPusher(milestone.NewReachedEvent(ctx, aggregate, milestone.AuthenticationSucceededOnInstance)),
|
||||
),
|
||||
),
|
||||
},
|
||||
args: args{
|
||||
ctx: ctx,
|
||||
},
|
||||
want: &MilestonesReached{
|
||||
InstanceID: "instanceID",
|
||||
AuthenticationSucceededOnInstance: true,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "project created",
|
||||
fields: fields{
|
||||
eventstore: expectEventstore(
|
||||
expectFilter(
|
||||
eventFromEventPusher(milestone.NewReachedEvent(ctx, aggregate, milestone.ProjectCreated)),
|
||||
),
|
||||
),
|
||||
},
|
||||
args: args{
|
||||
ctx: ctx,
|
||||
},
|
||||
want: &MilestonesReached{
|
||||
InstanceID: "instanceID",
|
||||
ProjectCreated: true,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "app created",
|
||||
fields: fields{
|
||||
eventstore: expectEventstore(
|
||||
expectFilter(
|
||||
eventFromEventPusher(milestone.NewReachedEvent(ctx, aggregate, milestone.ApplicationCreated)),
|
||||
),
|
||||
),
|
||||
},
|
||||
args: args{
|
||||
ctx: ctx,
|
||||
},
|
||||
want: &MilestonesReached{
|
||||
InstanceID: "instanceID",
|
||||
ApplicationCreated: true,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "app auth",
|
||||
fields: fields{
|
||||
eventstore: expectEventstore(
|
||||
expectFilter(
|
||||
eventFromEventPusher(milestone.NewReachedEvent(ctx, aggregate, milestone.AuthenticationSucceededOnApplication)),
|
||||
),
|
||||
),
|
||||
},
|
||||
args: args{
|
||||
ctx: ctx,
|
||||
},
|
||||
want: &MilestonesReached{
|
||||
InstanceID: "instanceID",
|
||||
AuthenticationSucceededOnApplication: true,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "instance deleted",
|
||||
fields: fields{
|
||||
eventstore: expectEventstore(
|
||||
expectFilter(
|
||||
eventFromEventPusher(milestone.NewReachedEvent(ctx, aggregate, milestone.InstanceDeleted)),
|
||||
),
|
||||
),
|
||||
},
|
||||
args: args{
|
||||
ctx: ctx,
|
||||
},
|
||||
want: &MilestonesReached{
|
||||
InstanceID: "instanceID",
|
||||
InstanceDeleted: true,
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
cache := gomap.NewCache[milestoneIndex, string, *MilestonesReached](
|
||||
context.Background(),
|
||||
[]milestoneIndex{milestoneIndexInstanceID},
|
||||
cache.CacheConfig{Connector: "memory"},
|
||||
)
|
||||
cache.Set(context.Background(), cached)
|
||||
|
||||
c := &Commands{
|
||||
eventstore: tt.fields.eventstore(t),
|
||||
caches: &Caches{
|
||||
milestones: cache,
|
||||
},
|
||||
}
|
||||
got, err := c.GetMilestonesReached(tt.args.ctx)
|
||||
require.ErrorIs(t, err, tt.wantErr)
|
||||
assert.Equal(t, tt.want, got)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCommands_milestonesCompleted(t *testing.T) {
|
||||
c := &Commands{
|
||||
caches: &Caches{
|
||||
milestones: noop.NewCache[milestoneIndex, string, *MilestonesReached](),
|
||||
},
|
||||
}
|
||||
ctx := authz.WithInstanceID(context.Background(), "instanceID")
|
||||
arg := &MilestonesReached{
|
||||
InstanceID: "instanceID",
|
||||
InstanceCreated: true,
|
||||
AuthenticationSucceededOnInstance: true,
|
||||
ProjectCreated: true,
|
||||
ApplicationCreated: true,
|
||||
AuthenticationSucceededOnApplication: true,
|
||||
InstanceDeleted: false,
|
||||
}
|
||||
c.setCachedMilestonesReached(ctx, arg)
|
||||
got, ok := c.getCachedMilestonesReached(ctx)
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, arg, got)
|
||||
}
|
||||
|
||||
func TestCommands_MilestonePushed(t *testing.T) {
|
||||
aggregate := milestone.NewInstanceAggregate("instanceID")
|
||||
type fields struct {
|
||||
eventstore func(*testing.T) *eventstore.Eventstore
|
||||
}
|
||||
type args struct {
|
||||
ctx context.Context
|
||||
instanceID string
|
||||
msType milestone.Type
|
||||
endpoints []string
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
fields fields
|
||||
args args
|
||||
wantErr error
|
||||
}{
|
||||
{
|
||||
name: "milestone pushed",
|
||||
fields: fields{
|
||||
eventstore: expectEventstore(
|
||||
expectPush(
|
||||
milestone.NewPushedEvent(
|
||||
context.Background(),
|
||||
aggregate,
|
||||
milestone.ApplicationCreated,
|
||||
[]string{"foo.com", "bar.com"},
|
||||
"example.com",
|
||||
),
|
||||
),
|
||||
),
|
||||
},
|
||||
args: args{
|
||||
ctx: context.Background(),
|
||||
instanceID: "instanceID",
|
||||
msType: milestone.ApplicationCreated,
|
||||
endpoints: []string{"foo.com", "bar.com"},
|
||||
},
|
||||
wantErr: nil,
|
||||
},
|
||||
{
|
||||
name: "pusher error",
|
||||
fields: fields{
|
||||
eventstore: expectEventstore(
|
||||
expectPushFailed(
|
||||
io.ErrClosedPipe,
|
||||
milestone.NewPushedEvent(
|
||||
context.Background(),
|
||||
aggregate,
|
||||
milestone.ApplicationCreated,
|
||||
[]string{"foo.com", "bar.com"},
|
||||
"example.com",
|
||||
),
|
||||
),
|
||||
),
|
||||
},
|
||||
args: args{
|
||||
ctx: context.Background(),
|
||||
instanceID: "instanceID",
|
||||
msType: milestone.ApplicationCreated,
|
||||
endpoints: []string{"foo.com", "bar.com"},
|
||||
},
|
||||
wantErr: io.ErrClosedPipe,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
c := &Commands{
|
||||
eventstore: tt.fields.eventstore(t),
|
||||
externalDomain: "example.com",
|
||||
}
|
||||
err := c.MilestonePushed(tt.args.ctx, tt.args.instanceID, tt.args.msType, tt.args.endpoints)
|
||||
assert.ErrorIs(t, err, tt.wantErr)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestOIDCSessionEvents_SetMilestones(t *testing.T) {
|
||||
ctx := authz.WithInstanceID(context.Background(), "instanceID")
|
||||
ctx = authz.WithConsoleClientID(ctx, "console")
|
||||
aggregate := milestone.NewAggregate(ctx)
|
||||
|
||||
type fields struct {
|
||||
eventstore func(*testing.T) *eventstore.Eventstore
|
||||
}
|
||||
type args struct {
|
||||
ctx context.Context
|
||||
clientID string
|
||||
isHuman bool
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
fields fields
|
||||
args args
|
||||
wantEvents []eventstore.Command
|
||||
wantErr error
|
||||
}{
|
||||
{
|
||||
name: "get error",
|
||||
fields: fields{
|
||||
eventstore: expectEventstore(
|
||||
expectFilterError(io.ErrClosedPipe),
|
||||
),
|
||||
},
|
||||
args: args{
|
||||
ctx: ctx,
|
||||
clientID: "client",
|
||||
isHuman: true,
|
||||
},
|
||||
wantErr: io.ErrClosedPipe,
|
||||
},
|
||||
{
|
||||
name: "milestones already reached",
|
||||
fields: fields{
|
||||
eventstore: expectEventstore(
|
||||
expectFilter(
|
||||
eventFromEventPusher(milestone.NewReachedEvent(ctx, aggregate, milestone.AuthenticationSucceededOnInstance)),
|
||||
eventFromEventPusher(milestone.NewReachedEvent(ctx, aggregate, milestone.AuthenticationSucceededOnApplication)),
|
||||
),
|
||||
),
|
||||
},
|
||||
args: args{
|
||||
ctx: ctx,
|
||||
clientID: "client",
|
||||
isHuman: true,
|
||||
},
|
||||
wantErr: nil,
|
||||
},
|
||||
{
|
||||
name: "auth on instance",
|
||||
fields: fields{
|
||||
eventstore: expectEventstore(
|
||||
expectFilter(),
|
||||
),
|
||||
},
|
||||
args: args{
|
||||
ctx: ctx,
|
||||
clientID: "console",
|
||||
isHuman: true,
|
||||
},
|
||||
wantEvents: []eventstore.Command{
|
||||
milestone.NewReachedEvent(ctx, aggregate, milestone.AuthenticationSucceededOnInstance),
|
||||
},
|
||||
wantErr: nil,
|
||||
},
|
||||
{
|
||||
name: "subsequent console login, no milestone",
|
||||
fields: fields{
|
||||
eventstore: expectEventstore(
|
||||
expectFilter(
|
||||
eventFromEventPusher(milestone.NewReachedEvent(ctx, aggregate, milestone.AuthenticationSucceededOnInstance)),
|
||||
),
|
||||
),
|
||||
},
|
||||
args: args{
|
||||
ctx: ctx,
|
||||
clientID: "console",
|
||||
isHuman: true,
|
||||
},
|
||||
wantErr: nil,
|
||||
},
|
||||
{
|
||||
name: "subsequent machine login, no milestone",
|
||||
fields: fields{
|
||||
eventstore: expectEventstore(
|
||||
expectFilter(
|
||||
eventFromEventPusher(milestone.NewReachedEvent(ctx, aggregate, milestone.AuthenticationSucceededOnInstance)),
|
||||
),
|
||||
),
|
||||
},
|
||||
args: args{
|
||||
ctx: ctx,
|
||||
clientID: "client",
|
||||
isHuman: false,
|
||||
},
|
||||
wantErr: nil,
|
||||
},
|
||||
{
|
||||
name: "auth on app",
|
||||
fields: fields{
|
||||
eventstore: expectEventstore(
|
||||
expectFilter(
|
||||
eventFromEventPusher(milestone.NewReachedEvent(ctx, aggregate, milestone.AuthenticationSucceededOnInstance)),
|
||||
),
|
||||
),
|
||||
},
|
||||
args: args{
|
||||
ctx: ctx,
|
||||
clientID: "client",
|
||||
isHuman: true,
|
||||
},
|
||||
wantEvents: []eventstore.Command{
|
||||
milestone.NewReachedEvent(ctx, aggregate, milestone.AuthenticationSucceededOnApplication),
|
||||
},
|
||||
wantErr: nil,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
c := &Commands{
|
||||
eventstore: tt.fields.eventstore(t),
|
||||
caches: &Caches{
|
||||
milestones: noop.NewCache[milestoneIndex, string, *MilestonesReached](),
|
||||
},
|
||||
}
|
||||
s := &OIDCSessionEvents{
|
||||
commands: c,
|
||||
}
|
||||
postCommit, err := s.SetMilestones(tt.args.ctx, tt.args.clientID, tt.args.isHuman)
|
||||
postCommit(tt.args.ctx)
|
||||
require.ErrorIs(t, err, tt.wantErr)
|
||||
assert.Equal(t, tt.wantEvents, s.events)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCommands_projectCreatedMilestone(t *testing.T) {
|
||||
ctx := authz.WithInstanceID(context.Background(), "instanceID")
|
||||
systemCtx := authz.SetCtxData(ctx, authz.CtxData{
|
||||
SystemMemberships: authz.Memberships{
|
||||
&authz.Membership{
|
||||
MemberType: authz.MemberTypeSystem,
|
||||
},
|
||||
},
|
||||
})
|
||||
aggregate := milestone.NewAggregate(ctx)
|
||||
|
||||
type fields struct {
|
||||
eventstore func(*testing.T) *eventstore.Eventstore
|
||||
}
|
||||
type args struct {
|
||||
ctx context.Context
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
fields fields
|
||||
args args
|
||||
wantEvents []eventstore.Command
|
||||
wantErr error
|
||||
}{
|
||||
{
|
||||
name: "system user",
|
||||
fields: fields{
|
||||
eventstore: expectEventstore(),
|
||||
},
|
||||
args: args{
|
||||
ctx: systemCtx,
|
||||
},
|
||||
wantErr: nil,
|
||||
},
|
||||
{
|
||||
name: "get error",
|
||||
fields: fields{
|
||||
eventstore: expectEventstore(
|
||||
expectFilterError(io.ErrClosedPipe),
|
||||
),
|
||||
},
|
||||
args: args{
|
||||
ctx: ctx,
|
||||
},
|
||||
wantErr: io.ErrClosedPipe,
|
||||
},
|
||||
{
|
||||
name: "milestone already reached",
|
||||
fields: fields{
|
||||
eventstore: expectEventstore(
|
||||
expectFilter(
|
||||
eventFromEventPusher(milestone.NewReachedEvent(ctx, aggregate, milestone.ProjectCreated)),
|
||||
),
|
||||
),
|
||||
},
|
||||
args: args{
|
||||
ctx: ctx,
|
||||
},
|
||||
wantErr: nil,
|
||||
},
|
||||
{
|
||||
name: "milestone reached event",
|
||||
fields: fields{
|
||||
eventstore: expectEventstore(
|
||||
expectFilter(),
|
||||
),
|
||||
},
|
||||
args: args{
|
||||
ctx: ctx,
|
||||
},
|
||||
wantEvents: []eventstore.Command{
|
||||
milestone.NewReachedEvent(ctx, aggregate, milestone.ProjectCreated),
|
||||
},
|
||||
wantErr: nil,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
c := &Commands{
|
||||
eventstore: tt.fields.eventstore(t),
|
||||
caches: &Caches{
|
||||
milestones: noop.NewCache[milestoneIndex, string, *MilestonesReached](),
|
||||
},
|
||||
}
|
||||
var cmds []eventstore.Command
|
||||
postCommit, err := c.projectCreatedMilestone(tt.args.ctx, &cmds)
|
||||
postCommit(tt.args.ctx)
|
||||
require.ErrorIs(t, err, tt.wantErr)
|
||||
assert.Equal(t, tt.wantEvents, cmds)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCommands_applicationCreatedMilestone(t *testing.T) {
|
||||
ctx := authz.WithInstanceID(context.Background(), "instanceID")
|
||||
systemCtx := authz.SetCtxData(ctx, authz.CtxData{
|
||||
SystemMemberships: authz.Memberships{
|
||||
&authz.Membership{
|
||||
MemberType: authz.MemberTypeSystem,
|
||||
},
|
||||
},
|
||||
})
|
||||
aggregate := milestone.NewAggregate(ctx)
|
||||
|
||||
type fields struct {
|
||||
eventstore func(*testing.T) *eventstore.Eventstore
|
||||
}
|
||||
type args struct {
|
||||
ctx context.Context
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
fields fields
|
||||
args args
|
||||
wantEvents []eventstore.Command
|
||||
wantErr error
|
||||
}{
|
||||
{
|
||||
name: "system user",
|
||||
fields: fields{
|
||||
eventstore: expectEventstore(),
|
||||
},
|
||||
args: args{
|
||||
ctx: systemCtx,
|
||||
},
|
||||
wantErr: nil,
|
||||
},
|
||||
{
|
||||
name: "get error",
|
||||
fields: fields{
|
||||
eventstore: expectEventstore(
|
||||
expectFilterError(io.ErrClosedPipe),
|
||||
),
|
||||
},
|
||||
args: args{
|
||||
ctx: ctx,
|
||||
},
|
||||
wantErr: io.ErrClosedPipe,
|
||||
},
|
||||
{
|
||||
name: "milestone already reached",
|
||||
fields: fields{
|
||||
eventstore: expectEventstore(
|
||||
expectFilter(
|
||||
eventFromEventPusher(milestone.NewReachedEvent(ctx, aggregate, milestone.ApplicationCreated)),
|
||||
),
|
||||
),
|
||||
},
|
||||
args: args{
|
||||
ctx: ctx,
|
||||
},
|
||||
wantErr: nil,
|
||||
},
|
||||
{
|
||||
name: "milestone reached event",
|
||||
fields: fields{
|
||||
eventstore: expectEventstore(
|
||||
expectFilter(),
|
||||
),
|
||||
},
|
||||
args: args{
|
||||
ctx: ctx,
|
||||
},
|
||||
wantEvents: []eventstore.Command{
|
||||
milestone.NewReachedEvent(ctx, aggregate, milestone.ApplicationCreated),
|
||||
},
|
||||
wantErr: nil,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
c := &Commands{
|
||||
eventstore: tt.fields.eventstore(t),
|
||||
caches: &Caches{
|
||||
milestones: noop.NewCache[milestoneIndex, string, *MilestonesReached](),
|
||||
},
|
||||
}
|
||||
var cmds []eventstore.Command
|
||||
postCommit, err := c.applicationCreatedMilestone(tt.args.ctx, &cmds)
|
||||
postCommit(tt.args.ctx)
|
||||
require.ErrorIs(t, err, tt.wantErr)
|
||||
assert.Equal(t, tt.wantEvents, cmds)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Commands) setMilestonesCompletedForTest(instanceID string) {
|
||||
c.milestonesCompleted.Store(instanceID, struct{}{})
|
||||
}
|
@@ -71,7 +71,8 @@ func (c *Commands) CreateOIDCSessionFromAuthRequest(ctx context.Context, authReq
|
||||
return nil, "", zerrors.ThrowPreconditionFailed(nil, "COMMAND-Iung5", "Errors.AuthRequest.NoCode")
|
||||
}
|
||||
|
||||
sessionModel := NewSessionWriteModel(authReqModel.SessionID, authz.GetInstance(ctx).InstanceID())
|
||||
instanceID := authz.GetInstance(ctx).InstanceID()
|
||||
sessionModel := NewSessionWriteModel(authReqModel.SessionID, instanceID)
|
||||
err = c.eventstore.FilterToQueryReducer(ctx, sessionModel)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
@@ -118,8 +119,15 @@ func (c *Commands) CreateOIDCSessionFromAuthRequest(ctx context.Context, authReq
|
||||
}
|
||||
}
|
||||
cmd.SetAuthRequestSuccessful(ctx, authReqModel.aggregate)
|
||||
session, err = cmd.PushEvents(ctx)
|
||||
return session, authReqModel.State, err
|
||||
postCommit, err := cmd.SetMilestones(ctx, authReqModel.ClientID, true)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
if session, err = cmd.PushEvents(ctx); err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
postCommit(ctx)
|
||||
return session, authReqModel.State, nil
|
||||
}
|
||||
|
||||
func (c *Commands) CreateOIDCSession(ctx context.Context,
|
||||
@@ -161,7 +169,15 @@ func (c *Commands) CreateOIDCSession(ctx context.Context,
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
return cmd.PushEvents(ctx)
|
||||
postCommit, err := cmd.SetMilestones(ctx, clientID, sessionID != "")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if session, err = cmd.PushEvents(ctx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
postCommit(ctx)
|
||||
return session, nil
|
||||
}
|
||||
|
||||
type RefreshTokenComplianceChecker func(ctx context.Context, wm *OIDCSessionWriteModel, requestedScope []string) (scope []string, err error)
|
||||
@@ -283,7 +299,7 @@ func (c *Commands) newOIDCSessionAddEvents(ctx context.Context, userID, resource
|
||||
}
|
||||
sessionID = IDPrefixV2 + sessionID
|
||||
return &OIDCSessionEvents{
|
||||
eventstore: c.eventstore,
|
||||
commands: c,
|
||||
idGenerator: c.idGenerator,
|
||||
encryptionAlg: c.keyAlgorithm,
|
||||
events: pending,
|
||||
@@ -341,7 +357,7 @@ func (c *Commands) newOIDCSessionUpdateEvents(ctx context.Context, refreshToken
|
||||
return nil, err
|
||||
}
|
||||
return &OIDCSessionEvents{
|
||||
eventstore: c.eventstore,
|
||||
commands: c,
|
||||
idGenerator: c.idGenerator,
|
||||
encryptionAlg: c.keyAlgorithm,
|
||||
oidcSessionWriteModel: sessionWriteModel,
|
||||
@@ -352,7 +368,7 @@ func (c *Commands) newOIDCSessionUpdateEvents(ctx context.Context, refreshToken
|
||||
}
|
||||
|
||||
type OIDCSessionEvents struct {
|
||||
eventstore *eventstore.Eventstore
|
||||
commands *Commands
|
||||
idGenerator id.Generator
|
||||
encryptionAlg crypto.EncryptionAlgorithm
|
||||
events []eventstore.Command
|
||||
@@ -467,7 +483,7 @@ func (c *OIDCSessionEvents) generateRefreshToken(userID string) (refreshTokenID,
|
||||
}
|
||||
|
||||
func (c *OIDCSessionEvents) PushEvents(ctx context.Context) (*OIDCSession, error) {
|
||||
pushedEvents, err := c.eventstore.Push(ctx, c.events...)
|
||||
pushedEvents, err := c.commands.eventstore.Push(ctx, c.events...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -496,7 +512,7 @@ func (c *OIDCSessionEvents) PushEvents(ctx context.Context) (*OIDCSession, error
|
||||
// we need to use `-` as a delimiter because the OIDC library uses `:` and will check for a length of 2 parts
|
||||
session.TokenID = c.oidcSessionWriteModel.AggregateID + TokenDelimiter + c.accessTokenID
|
||||
}
|
||||
activity.Trigger(ctx, c.oidcSessionWriteModel.UserResourceOwner, c.oidcSessionWriteModel.UserID, tokenReasonToActivityMethodType(c.oidcSessionWriteModel.AccessTokenReason), c.eventstore.FilterToQueryReducer)
|
||||
activity.Trigger(ctx, c.oidcSessionWriteModel.UserResourceOwner, c.oidcSessionWriteModel.UserID, tokenReasonToActivityMethodType(c.oidcSessionWriteModel.AccessTokenReason), c.commands.eventstore.FilterToQueryReducer)
|
||||
return session, nil
|
||||
}
|
||||
|
||||
|
@@ -70,7 +70,7 @@ func TestCommands_CreateOIDCSessionFromAuthRequest(t *testing.T) {
|
||||
eventstore: expectEventstore(),
|
||||
},
|
||||
args{
|
||||
ctx: context.Background(),
|
||||
ctx: authz.WithInstanceID(context.Background(), "instanceID"),
|
||||
authRequestID: "",
|
||||
complianceCheck: mockAuthRequestComplianceChecker(nil),
|
||||
},
|
||||
@@ -86,7 +86,7 @@ func TestCommands_CreateOIDCSessionFromAuthRequest(t *testing.T) {
|
||||
),
|
||||
},
|
||||
args{
|
||||
ctx: context.Background(),
|
||||
ctx: authz.WithInstanceID(context.Background(), "instanceID"),
|
||||
authRequestID: "V2_authRequestID",
|
||||
complianceCheck: mockAuthRequestComplianceChecker(nil),
|
||||
},
|
||||
@@ -102,7 +102,7 @@ func TestCommands_CreateOIDCSessionFromAuthRequest(t *testing.T) {
|
||||
),
|
||||
},
|
||||
args{
|
||||
ctx: context.Background(),
|
||||
ctx: authz.WithInstanceID(context.Background(), "instanceID"),
|
||||
authRequestID: "V2_authRequestID",
|
||||
complianceCheck: mockAuthRequestComplianceChecker(nil),
|
||||
},
|
||||
@@ -706,6 +706,7 @@ func TestCommands_CreateOIDCSessionFromAuthRequest(t *testing.T) {
|
||||
defaultRefreshTokenIdleLifetime: tt.fields.defaultRefreshTokenIdleLifetime,
|
||||
keyAlgorithm: tt.fields.keyAlgorithm,
|
||||
}
|
||||
c.setMilestonesCompletedForTest("instanceID")
|
||||
gotSession, gotState, err := c.CreateOIDCSessionFromAuthRequest(tt.args.ctx, tt.args.authRequestID, tt.args.complianceCheck, tt.args.needRefreshToken)
|
||||
require.ErrorIs(t, err, tt.res.err)
|
||||
|
||||
@@ -762,7 +763,7 @@ func TestCommands_CreateOIDCSession(t *testing.T) {
|
||||
),
|
||||
},
|
||||
args: args{
|
||||
ctx: context.Background(),
|
||||
ctx: authz.WithInstanceID(context.Background(), "instanceID"),
|
||||
userID: "userID",
|
||||
resourceOwner: "orgID",
|
||||
clientID: "clientID",
|
||||
@@ -818,7 +819,7 @@ func TestCommands_CreateOIDCSession(t *testing.T) {
|
||||
keyAlgorithm: crypto.CreateMockEncryptionAlg(gomock.NewController(t)),
|
||||
},
|
||||
args: args{
|
||||
ctx: context.Background(),
|
||||
ctx: authz.WithInstanceID(context.Background(), "instanceID"),
|
||||
userID: "userID",
|
||||
resourceOwner: "org1",
|
||||
clientID: "clientID",
|
||||
@@ -892,7 +893,7 @@ func TestCommands_CreateOIDCSession(t *testing.T) {
|
||||
keyAlgorithm: crypto.CreateMockEncryptionAlg(gomock.NewController(t)),
|
||||
},
|
||||
args: args{
|
||||
ctx: context.Background(),
|
||||
ctx: authz.WithInstanceID(context.Background(), "instanceID"),
|
||||
userID: "userID",
|
||||
resourceOwner: "org1",
|
||||
clientID: "clientID",
|
||||
@@ -1089,7 +1090,7 @@ func TestCommands_CreateOIDCSession(t *testing.T) {
|
||||
keyAlgorithm: crypto.CreateMockEncryptionAlg(gomock.NewController(t)),
|
||||
},
|
||||
args: args{
|
||||
ctx: context.Background(),
|
||||
ctx: authz.WithInstanceID(context.Background(), "instanceID"),
|
||||
userID: "userID",
|
||||
resourceOwner: "org1",
|
||||
clientID: "clientID",
|
||||
@@ -1186,7 +1187,7 @@ func TestCommands_CreateOIDCSession(t *testing.T) {
|
||||
keyAlgorithm: crypto.CreateMockEncryptionAlg(gomock.NewController(t)),
|
||||
},
|
||||
args: args{
|
||||
ctx: context.Background(),
|
||||
ctx: authz.WithInstanceID(context.Background(), "instanceID"),
|
||||
userID: "userID",
|
||||
resourceOwner: "org1",
|
||||
clientID: "clientID",
|
||||
@@ -1266,7 +1267,7 @@ func TestCommands_CreateOIDCSession(t *testing.T) {
|
||||
}),
|
||||
},
|
||||
args: args{
|
||||
ctx: context.Background(),
|
||||
ctx: authz.WithInstanceID(context.Background(), "instanceID"),
|
||||
userID: "userID",
|
||||
resourceOwner: "org1",
|
||||
clientID: "clientID",
|
||||
@@ -1347,7 +1348,7 @@ func TestCommands_CreateOIDCSession(t *testing.T) {
|
||||
}),
|
||||
},
|
||||
args: args{
|
||||
ctx: context.Background(),
|
||||
ctx: authz.WithInstanceID(context.Background(), "instanceID"),
|
||||
userID: "userID",
|
||||
resourceOwner: "org1",
|
||||
clientID: "clientID",
|
||||
@@ -1406,6 +1407,7 @@ func TestCommands_CreateOIDCSession(t *testing.T) {
|
||||
keyAlgorithm: tt.fields.keyAlgorithm,
|
||||
checkPermission: tt.fields.checkPermission,
|
||||
}
|
||||
c.setMilestonesCompletedForTest("instanceID")
|
||||
got, err := c.CreateOIDCSession(tt.args.ctx,
|
||||
tt.args.userID,
|
||||
tt.args.resourceOwner,
|
||||
|
@@ -34,7 +34,11 @@ func (c *Commands) AddProjectWithID(ctx context.Context, project *domain.Project
|
||||
if existingProject.State != domain.ProjectStateUnspecified {
|
||||
return nil, zerrors.ThrowInvalidArgument(nil, "COMMAND-opamwu", "Errors.Project.AlreadyExisting")
|
||||
}
|
||||
return c.addProjectWithID(ctx, project, resourceOwner, projectID)
|
||||
project, err = c.addProjectWithID(ctx, project, resourceOwner, projectID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return project, nil
|
||||
}
|
||||
|
||||
func (c *Commands) AddProject(ctx context.Context, project *domain.Project, resourceOwner, ownerUserID string) (_ *domain.Project, err error) {
|
||||
@@ -53,7 +57,11 @@ func (c *Commands) AddProject(ctx context.Context, project *domain.Project, reso
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return c.addProjectWithIDWithOwner(ctx, project, resourceOwner, ownerUserID, projectID)
|
||||
project, err = c.addProjectWithIDWithOwner(ctx, project, resourceOwner, ownerUserID, projectID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return project, nil
|
||||
}
|
||||
|
||||
func (c *Commands) addProjectWithID(ctx context.Context, projectAdd *domain.Project, resourceOwner, projectID string) (_ *domain.Project, err error) {
|
||||
@@ -71,11 +79,15 @@ func (c *Commands) addProjectWithID(ctx context.Context, projectAdd *domain.Proj
|
||||
projectAdd.HasProjectCheck,
|
||||
projectAdd.PrivateLabelingSetting),
|
||||
}
|
||||
|
||||
postCommit, err := c.projectCreatedMilestone(ctx, &events)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
pushedEvents, err := c.eventstore.Push(ctx, events...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
postCommit(ctx)
|
||||
err = AppendAndReduce(addedProject, pushedEvents...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -103,11 +115,15 @@ func (c *Commands) addProjectWithIDWithOwner(ctx context.Context, projectAdd *do
|
||||
projectAdd.PrivateLabelingSetting),
|
||||
project.NewProjectMemberAddedEvent(ctx, projectAgg, ownerUserID, projectRole),
|
||||
}
|
||||
|
||||
postCommit, err := c.projectCreatedMilestone(ctx, &events)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
pushedEvents, err := c.eventstore.Push(ctx, events...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
postCommit(ctx)
|
||||
err = AppendAndReduce(addedProject, pushedEvents...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
@@ -202,10 +202,15 @@ func (c *Commands) addOIDCApplicationWithID(ctx context.Context, oidcApp *domain
|
||||
))
|
||||
|
||||
addedApplication.AppID = oidcApp.AppID
|
||||
postCommit, err := c.applicationCreatedMilestone(ctx, &events)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
pushedEvents, err := c.eventstore.Push(ctx, events...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
postCommit(ctx)
|
||||
err = AppendAndReduce(addedApplication, pushedEvents...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
@@ -11,6 +11,7 @@ import (
|
||||
"github.com/zitadel/passwap"
|
||||
"github.com/zitadel/passwap/bcrypt"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/api/authz"
|
||||
"github.com/zitadel/zitadel/internal/command/preparation"
|
||||
"github.com/zitadel/zitadel/internal/crypto"
|
||||
"github.com/zitadel/zitadel/internal/domain"
|
||||
@@ -418,7 +419,7 @@ func TestCommandSide_AddOIDCApplication(t *testing.T) {
|
||||
eventstore: expectEventstore(),
|
||||
},
|
||||
args: args{
|
||||
ctx: context.Background(),
|
||||
ctx: authz.WithInstanceID(context.Background(), "instanceID"),
|
||||
oidcApp: &domain.OIDCApp{},
|
||||
resourceOwner: "org1",
|
||||
},
|
||||
@@ -434,7 +435,7 @@ func TestCommandSide_AddOIDCApplication(t *testing.T) {
|
||||
),
|
||||
},
|
||||
args: args{
|
||||
ctx: context.Background(),
|
||||
ctx: authz.WithInstanceID(context.Background(), "instanceID"),
|
||||
oidcApp: &domain.OIDCApp{
|
||||
ObjectRoot: models.ObjectRoot{
|
||||
AggregateID: "project1",
|
||||
@@ -463,7 +464,7 @@ func TestCommandSide_AddOIDCApplication(t *testing.T) {
|
||||
),
|
||||
},
|
||||
args: args{
|
||||
ctx: context.Background(),
|
||||
ctx: authz.WithInstanceID(context.Background(), "instanceID"),
|
||||
oidcApp: &domain.OIDCApp{
|
||||
ObjectRoot: models.ObjectRoot{
|
||||
AggregateID: "project1",
|
||||
@@ -521,7 +522,7 @@ func TestCommandSide_AddOIDCApplication(t *testing.T) {
|
||||
idGenerator: id_mock.NewIDGeneratorExpectIDs(t, "app1", "client1"),
|
||||
},
|
||||
args: args{
|
||||
ctx: context.Background(),
|
||||
ctx: authz.WithInstanceID(context.Background(), "instanceID"),
|
||||
oidcApp: &domain.OIDCApp{
|
||||
ObjectRoot: models.ObjectRoot{
|
||||
AggregateID: "project1",
|
||||
@@ -619,7 +620,7 @@ func TestCommandSide_AddOIDCApplication(t *testing.T) {
|
||||
idGenerator: id_mock.NewIDGeneratorExpectIDs(t, "app1", "client1"),
|
||||
},
|
||||
args: args{
|
||||
ctx: context.Background(),
|
||||
ctx: authz.WithInstanceID(context.Background(), "instanceID"),
|
||||
oidcApp: &domain.OIDCApp{
|
||||
ObjectRoot: models.ObjectRoot{
|
||||
AggregateID: "project1",
|
||||
@@ -676,7 +677,7 @@ func TestCommandSide_AddOIDCApplication(t *testing.T) {
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
r := &Commands{
|
||||
c := &Commands{
|
||||
eventstore: tt.fields.eventstore(t),
|
||||
idGenerator: tt.fields.idGenerator,
|
||||
newHashedSecret: mockHashedSecret("secret"),
|
||||
@@ -684,7 +685,8 @@ func TestCommandSide_AddOIDCApplication(t *testing.T) {
|
||||
ClientSecret: emptyConfig,
|
||||
},
|
||||
}
|
||||
got, err := r.AddOIDCApplication(tt.args.ctx, tt.args.oidcApp, tt.args.resourceOwner)
|
||||
c.setMilestonesCompletedForTest("instanceID")
|
||||
got, err := c.AddOIDCApplication(tt.args.ctx, tt.args.oidcApp, tt.args.resourceOwner)
|
||||
if tt.res.err == nil {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
@@ -28,10 +28,15 @@ func (c *Commands) AddSAMLApplication(ctx context.Context, application *domain.S
|
||||
return nil, err
|
||||
}
|
||||
addedApplication.AppID = application.AppID
|
||||
postCommit, err := c.applicationCreatedMilestone(ctx, &events)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
pushedEvents, err := c.eventstore.Push(ctx, events...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
postCommit(ctx)
|
||||
err = AppendAndReduce(addedApplication, pushedEvents...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
@@ -9,6 +9,7 @@ import (
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/api/authz"
|
||||
"github.com/zitadel/zitadel/internal/domain"
|
||||
"github.com/zitadel/zitadel/internal/eventstore"
|
||||
"github.com/zitadel/zitadel/internal/eventstore/v1/models"
|
||||
@@ -76,7 +77,7 @@ func TestCommandSide_AddSAMLApplication(t *testing.T) {
|
||||
),
|
||||
},
|
||||
args: args{
|
||||
ctx: context.Background(),
|
||||
ctx: authz.WithInstanceID(context.Background(), "instanceID"),
|
||||
samlApp: &domain.SAMLApp{},
|
||||
resourceOwner: "org1",
|
||||
},
|
||||
@@ -93,7 +94,7 @@ func TestCommandSide_AddSAMLApplication(t *testing.T) {
|
||||
),
|
||||
},
|
||||
args: args{
|
||||
ctx: context.Background(),
|
||||
ctx: authz.WithInstanceID(context.Background(), "instanceID"),
|
||||
samlApp: &domain.SAMLApp{
|
||||
ObjectRoot: models.ObjectRoot{
|
||||
AggregateID: "project1",
|
||||
@@ -123,7 +124,7 @@ func TestCommandSide_AddSAMLApplication(t *testing.T) {
|
||||
),
|
||||
},
|
||||
args: args{
|
||||
ctx: context.Background(),
|
||||
ctx: authz.WithInstanceID(context.Background(), "instanceID"),
|
||||
samlApp: &domain.SAMLApp{
|
||||
ObjectRoot: models.ObjectRoot{
|
||||
AggregateID: "project1",
|
||||
@@ -154,7 +155,7 @@ func TestCommandSide_AddSAMLApplication(t *testing.T) {
|
||||
idGenerator: id_mock.NewIDGeneratorExpectIDs(t),
|
||||
},
|
||||
args: args{
|
||||
ctx: context.Background(),
|
||||
ctx: authz.WithInstanceID(context.Background(), "instanceID"),
|
||||
samlApp: &domain.SAMLApp{
|
||||
ObjectRoot: models.ObjectRoot{
|
||||
AggregateID: "project1",
|
||||
@@ -201,7 +202,7 @@ func TestCommandSide_AddSAMLApplication(t *testing.T) {
|
||||
idGenerator: id_mock.NewIDGeneratorExpectIDs(t, "app1"),
|
||||
},
|
||||
args: args{
|
||||
ctx: context.Background(),
|
||||
ctx: authz.WithInstanceID(context.Background(), "instanceID"),
|
||||
samlApp: &domain.SAMLApp{
|
||||
ObjectRoot: models.ObjectRoot{
|
||||
AggregateID: "project1",
|
||||
@@ -260,7 +261,7 @@ func TestCommandSide_AddSAMLApplication(t *testing.T) {
|
||||
httpClient: newTestClient(200, testMetadata),
|
||||
},
|
||||
args: args{
|
||||
ctx: context.Background(),
|
||||
ctx: authz.WithInstanceID(context.Background(), "instanceID"),
|
||||
samlApp: &domain.SAMLApp{
|
||||
ObjectRoot: models.ObjectRoot{
|
||||
AggregateID: "project1",
|
||||
@@ -305,7 +306,7 @@ func TestCommandSide_AddSAMLApplication(t *testing.T) {
|
||||
httpClient: newTestClient(http.StatusNotFound, nil),
|
||||
},
|
||||
args: args{
|
||||
ctx: context.Background(),
|
||||
ctx: authz.WithInstanceID(context.Background(), "instanceID"),
|
||||
samlApp: &domain.SAMLApp{
|
||||
ObjectRoot: models.ObjectRoot{
|
||||
AggregateID: "project1",
|
||||
@@ -325,13 +326,13 @@ func TestCommandSide_AddSAMLApplication(t *testing.T) {
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
r := &Commands{
|
||||
c := &Commands{
|
||||
eventstore: tt.fields.eventstore,
|
||||
idGenerator: tt.fields.idGenerator,
|
||||
httpClient: tt.fields.httpClient,
|
||||
}
|
||||
|
||||
got, err := r.AddSAMLApplication(tt.args.ctx, tt.args.samlApp, tt.args.resourceOwner)
|
||||
c.setMilestonesCompletedForTest("instanceID")
|
||||
got, err := c.AddSAMLApplication(tt.args.ctx, tt.args.samlApp, tt.args.resourceOwner)
|
||||
if tt.res.err == nil {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
@@ -6,6 +6,7 @@ import (
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/api/authz"
|
||||
"github.com/zitadel/zitadel/internal/domain"
|
||||
"github.com/zitadel/zitadel/internal/eventstore"
|
||||
"github.com/zitadel/zitadel/internal/eventstore/v1/models"
|
||||
@@ -44,7 +45,7 @@ func TestCommandSide_AddProject(t *testing.T) {
|
||||
),
|
||||
},
|
||||
args: args{
|
||||
ctx: context.Background(),
|
||||
ctx: authz.WithInstanceID(context.Background(), "instanceID"),
|
||||
project: &domain.Project{},
|
||||
resourceOwner: "org1",
|
||||
},
|
||||
@@ -60,7 +61,7 @@ func TestCommandSide_AddProject(t *testing.T) {
|
||||
),
|
||||
},
|
||||
args: args{
|
||||
ctx: context.Background(),
|
||||
ctx: authz.WithInstanceID(context.Background(), "instanceID"),
|
||||
project: &domain.Project{
|
||||
Name: "project",
|
||||
ProjectRoleAssertion: true,
|
||||
@@ -121,7 +122,7 @@ func TestCommandSide_AddProject(t *testing.T) {
|
||||
idGenerator: id_mock.NewIDGeneratorExpectIDs(t, "project1"),
|
||||
},
|
||||
args: args{
|
||||
ctx: context.Background(),
|
||||
ctx: authz.WithInstanceID(context.Background(), "instanceID"),
|
||||
project: &domain.Project{
|
||||
Name: "project",
|
||||
ProjectRoleAssertion: true,
|
||||
@@ -159,7 +160,7 @@ func TestCommandSide_AddProject(t *testing.T) {
|
||||
idGenerator: id_mock.NewIDGeneratorExpectIDs(t, "project1"),
|
||||
},
|
||||
args: args{
|
||||
ctx: context.Background(),
|
||||
ctx: authz.WithInstanceID(context.Background(), "instanceID"),
|
||||
project: &domain.Project{
|
||||
Name: "project",
|
||||
ProjectRoleAssertion: true,
|
||||
@@ -187,11 +188,12 @@ func TestCommandSide_AddProject(t *testing.T) {
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
r := &Commands{
|
||||
c := &Commands{
|
||||
eventstore: tt.fields.eventstore,
|
||||
idGenerator: tt.fields.idGenerator,
|
||||
}
|
||||
got, err := r.AddProject(tt.args.ctx, tt.args.project, tt.args.resourceOwner, tt.args.ownerID)
|
||||
c.setMilestonesCompletedForTest("instanceID")
|
||||
got, err := c.AddProject(tt.args.ctx, tt.args.project, tt.args.resourceOwner, tt.args.ownerID)
|
||||
if tt.res.err == nil {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
Reference in New Issue
Block a user