perf(milestones): refactor (#8788)
Some checks are pending
ZITADEL CI/CD / core (push) Waiting to run
ZITADEL CI/CD / console (push) Waiting to run
ZITADEL CI/CD / version (push) Waiting to run
ZITADEL CI/CD / compile (push) Blocked by required conditions
ZITADEL CI/CD / core-unit-test (push) Blocked by required conditions
ZITADEL CI/CD / core-integration-test (push) Blocked by required conditions
ZITADEL CI/CD / lint (push) Blocked by required conditions
ZITADEL CI/CD / container (push) Blocked by required conditions
ZITADEL CI/CD / e2e (push) Blocked by required conditions
ZITADEL CI/CD / release (push) Blocked by required conditions
Code Scanning / CodeQL-Build (go) (push) Waiting to run
Code Scanning / CodeQL-Build (javascript) (push) Waiting to run

# Which Problems Are Solved

Milestones used existing events from a number of aggregates. OIDC
session is one of them. We noticed in load-tests that the reduction of
the oidc_session.added event into the milestone projection is a costly
business with payload based conditionals. A milestone is reached once,
but even then we remain subscribed to the OIDC events. This requires the
projections.current_states to be updated continuously.


# How the Problems Are Solved

The milestone creation is refactored to use dedicated events instead.
The command side decides when a milestone is reached and creates the
reached event once for each milestone when required.

# Additional Changes

In order to prevent reached milestones being created twice, a migration
script is provided. When the old `projections.milestones` table exist,
the state is read from there and `v2` milestone aggregate events are
created, with the original reached and pushed dates.

# Additional Context

- Closes https://github.com/zitadel/zitadel/issues/8800
This commit is contained in:
Tim Möhlmann
2024-10-28 09:29:34 +01:00
committed by GitHub
parent 54f1c0bc50
commit 32bad3feb3
46 changed files with 1612 additions and 756 deletions

82
internal/command/cache.go Normal file
View File

@@ -0,0 +1,82 @@
package command
import (
"context"
"fmt"
"strings"
"github.com/zitadel/zitadel/internal/cache"
"github.com/zitadel/zitadel/internal/cache/gomap"
"github.com/zitadel/zitadel/internal/cache/noop"
"github.com/zitadel/zitadel/internal/cache/pg"
"github.com/zitadel/zitadel/internal/database"
)
type Caches struct {
connectors *cacheConnectors
milestones cache.Cache[milestoneIndex, string, *MilestonesReached]
}
func startCaches(background context.Context, conf *cache.CachesConfig, client *database.DB) (_ *Caches, err error) {
caches := &Caches{
milestones: noop.NewCache[milestoneIndex, string, *MilestonesReached](),
}
if conf == nil {
return caches, nil
}
caches.connectors, err = startCacheConnectors(background, conf, client)
if err != nil {
return nil, err
}
caches.milestones, err = startCache[milestoneIndex, string, *MilestonesReached](background, []milestoneIndex{milestoneIndexInstanceID}, "milestones", conf.Instance, caches.connectors)
if err != nil {
return nil, err
}
return caches, nil
}
type cacheConnectors struct {
memory *cache.AutoPruneConfig
postgres *pgxPoolCacheConnector
}
type pgxPoolCacheConnector struct {
*cache.AutoPruneConfig
client *database.DB
}
func startCacheConnectors(_ context.Context, conf *cache.CachesConfig, client *database.DB) (_ *cacheConnectors, err error) {
connectors := new(cacheConnectors)
if conf.Connectors.Memory.Enabled {
connectors.memory = &conf.Connectors.Memory.AutoPrune
}
if conf.Connectors.Postgres.Enabled {
connectors.postgres = &pgxPoolCacheConnector{
AutoPruneConfig: &conf.Connectors.Postgres.AutoPrune,
client: client,
}
}
return connectors, nil
}
func startCache[I ~int, K ~string, V cache.Entry[I, K]](background context.Context, indices []I, name string, conf *cache.CacheConfig, connectors *cacheConnectors) (cache.Cache[I, K, V], error) {
if conf == nil || conf.Connector == "" {
return noop.NewCache[I, K, V](), nil
}
if strings.EqualFold(conf.Connector, "memory") && connectors.memory != nil {
c := gomap.NewCache[I, K, V](background, indices, *conf)
connectors.memory.StartAutoPrune(background, c, name)
return c, nil
}
if strings.EqualFold(conf.Connector, "postgres") && connectors.postgres != nil {
client := connectors.postgres.client
c, err := pg.NewCache[I, K, V](background, name, *conf, indices, client.Pool, client.Type())
if err != nil {
return nil, fmt.Errorf("query start cache: %w", err)
}
connectors.postgres.StartAutoPrune(background, c, name)
return c, nil
}
return nil, fmt.Errorf("cache connector %q not enabled", conf.Connector)
}

View File

@@ -18,6 +18,7 @@ import (
"github.com/zitadel/zitadel/internal/api/authz"
api_http "github.com/zitadel/zitadel/internal/api/http"
"github.com/zitadel/zitadel/internal/cache"
"github.com/zitadel/zitadel/internal/command/preparation"
sd "github.com/zitadel/zitadel/internal/config/systemdefaults"
"github.com/zitadel/zitadel/internal/crypto"
@@ -88,10 +89,17 @@ type Commands struct {
EventGroupExisting func(group string) bool
GenerateDomain func(instanceName, domain string) (string, error)
caches *Caches
// Store instance IDs where all milestones are reached (except InstanceDeleted).
// These instance's milestones never need to be invalidated,
// so the query and cache overhead can completely eliminated.
milestonesCompleted sync.Map
}
func StartCommands(
es *eventstore.Eventstore,
cachesConfig *cache.CachesConfig,
defaults sd.SystemDefaults,
zitadelRoles []authz.RoleMapping,
staticStore static.Storage,
@@ -123,6 +131,10 @@ func StartCommands(
if err != nil {
return nil, fmt.Errorf("password hasher: %w", err)
}
caches, err := startCaches(context.TODO(), cachesConfig, es.Client())
if err != nil {
return nil, fmt.Errorf("caches: %w", err)
}
repo = &Commands{
eventstore: es,
static: staticStore,
@@ -176,6 +188,7 @@ func StartCommands(
},
},
GenerateDomain: domain.NewGeneratedInstanceDomain,
caches: caches,
}
if defaultSecretGenerators != nil && defaultSecretGenerators.ClientSecret != nil {

View File

@@ -6,6 +6,7 @@ import (
"golang.org/x/text/language"
"github.com/zitadel/logging"
"github.com/zitadel/zitadel/internal/api/authz"
"github.com/zitadel/zitadel/internal/api/http"
"github.com/zitadel/zitadel/internal/command/preparation"
@@ -17,6 +18,7 @@ import (
"github.com/zitadel/zitadel/internal/notification/channels/smtp"
"github.com/zitadel/zitadel/internal/repository/instance"
"github.com/zitadel/zitadel/internal/repository/limits"
"github.com/zitadel/zitadel/internal/repository/milestone"
"github.com/zitadel/zitadel/internal/repository/org"
"github.com/zitadel/zitadel/internal/repository/project"
"github.com/zitadel/zitadel/internal/repository/quota"
@@ -292,7 +294,7 @@ func setUpInstance(ctx context.Context, c *Commands, setup *InstanceSetup) (vali
setupFeatures(&validations, setup.Features, setup.zitadel.instanceID)
setupLimits(c, &validations, limits.NewAggregate(setup.zitadel.limitsID, setup.zitadel.instanceID), setup.Limits)
setupRestrictions(c, &validations, restrictions.NewAggregate(setup.zitadel.restrictionsID, setup.zitadel.instanceID, setup.zitadel.instanceID), setup.Restrictions)
setupInstanceCreatedMilestone(&validations, setup.zitadel.instanceID)
return validations, pat, machineKey, nil
}
@@ -890,7 +892,8 @@ func (c *Commands) RemoveInstance(ctx context.Context, id string) (*domain.Objec
if err != nil {
return nil, err
}
err = c.caches.milestones.Invalidate(ctx, milestoneIndexInstanceID, id)
logging.OnError(err).Error("milestone invalidate")
return &domain.ObjectDetails{
Sequence: events[len(events)-1].Sequence(),
EventDate: events[len(events)-1].CreatedAt(),
@@ -908,10 +911,16 @@ func (c *Commands) prepareRemoveInstance(a *instance.Aggregate) preparation.Vali
if !writeModel.State.Exists() {
return nil, zerrors.ThrowNotFound(err, "COMMA-AE3GS", "Errors.Instance.NotFound")
}
return []eventstore.Command{instance.NewInstanceRemovedEvent(ctx,
&a.Aggregate,
writeModel.Name,
writeModel.Domains)},
milestoneAggregate := milestone.NewInstanceAggregate(a.ID)
return []eventstore.Command{
instance.NewInstanceRemovedEvent(ctx,
&a.Aggregate,
writeModel.Name,
writeModel.Domains),
milestone.NewReachedEvent(ctx,
milestoneAggregate,
milestone.InstanceDeleted),
},
nil
}, nil
}

View File

@@ -13,6 +13,7 @@ import (
"golang.org/x/text/language"
"github.com/zitadel/zitadel/internal/api/authz"
"github.com/zitadel/zitadel/internal/cache/noop"
"github.com/zitadel/zitadel/internal/command/preparation"
"github.com/zitadel/zitadel/internal/crypto"
"github.com/zitadel/zitadel/internal/domain"
@@ -20,6 +21,7 @@ import (
"github.com/zitadel/zitadel/internal/id"
id_mock "github.com/zitadel/zitadel/internal/id/mock"
"github.com/zitadel/zitadel/internal/repository/instance"
"github.com/zitadel/zitadel/internal/repository/milestone"
"github.com/zitadel/zitadel/internal/repository/org"
"github.com/zitadel/zitadel/internal/repository/project"
"github.com/zitadel/zitadel/internal/repository/user"
@@ -372,6 +374,7 @@ func setupInstanceEvents(ctx context.Context, instanceID, orgID, projectID, appI
setupInstanceElementsEvents(ctx, instanceID, instanceName, defaultLanguage),
orgEvents(ctx, instanceID, orgID, orgName, projectID, domain, externalSecure, true, true),
generatedDomainEvents(ctx, instanceID, orgID, projectID, appID, domain),
instanceCreatedMilestoneEvent(ctx, instanceID),
)
}
@@ -401,6 +404,12 @@ func generatedDomainEvents(ctx context.Context, instanceID, orgID, projectID, ap
}
}
func instanceCreatedMilestoneEvent(ctx context.Context, instanceID string) []eventstore.Command {
return []eventstore.Command{
milestone.NewReachedEvent(ctx, milestone.NewInstanceAggregate(instanceID), milestone.InstanceCreated),
}
}
func generatedDomainFilters(instanceID, orgID, projectID, appID, generatedDomain string) []expect {
return []expect{
expectFilter(),
@@ -1378,7 +1387,7 @@ func TestCommandSide_UpdateInstance(t *testing.T) {
func TestCommandSide_RemoveInstance(t *testing.T) {
type fields struct {
eventstore *eventstore.Eventstore
eventstore func(t *testing.T) *eventstore.Eventstore
}
type args struct {
ctx context.Context
@@ -1397,8 +1406,7 @@ func TestCommandSide_RemoveInstance(t *testing.T) {
{
name: "instance not existing, not found error",
fields: fields{
eventstore: eventstoreExpect(
t,
eventstore: expectEventstore(
expectFilter(),
),
},
@@ -1413,8 +1421,7 @@ func TestCommandSide_RemoveInstance(t *testing.T) {
{
name: "instance removed, not found error",
fields: fields{
eventstore: eventstoreExpect(
t,
eventstore: expectEventstore(
expectFilter(
eventFromEventPusher(
instance.NewInstanceAddedEvent(
@@ -1444,8 +1451,7 @@ func TestCommandSide_RemoveInstance(t *testing.T) {
{
name: "instance remove, ok",
fields: fields{
eventstore: eventstoreExpect(
t,
eventstore: expectEventstore(
expectFilter(
eventFromEventPusherWithInstanceID(
"INSTANCE",
@@ -1480,6 +1486,10 @@ func TestCommandSide_RemoveInstance(t *testing.T) {
"custom.domain",
},
),
milestone.NewReachedEvent(context.Background(),
milestone.NewInstanceAggregate("INSTANCE"),
milestone.InstanceDeleted,
),
),
),
},
@@ -1497,7 +1507,10 @@ func TestCommandSide_RemoveInstance(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
r := &Commands{
eventstore: tt.fields.eventstore,
eventstore: tt.fields.eventstore(t),
caches: &Caches{
milestones: noop.NewCache[milestoneIndex, string, *MilestonesReached](),
},
}
got, err := r.RemoveInstance(tt.args.ctx, tt.args.instanceID)
if tt.res.err == nil {

View File

@@ -3,20 +3,176 @@ package command
import (
"context"
"github.com/zitadel/logging"
"github.com/zitadel/zitadel/internal/api/authz"
"github.com/zitadel/zitadel/internal/command/preparation"
"github.com/zitadel/zitadel/internal/eventstore"
"github.com/zitadel/zitadel/internal/repository/milestone"
)
// MilestonePushed writes a new milestone.PushedEvent with a new milestone.Aggregate to the eventstore
type milestoneIndex int
const (
milestoneIndexInstanceID milestoneIndex = iota
)
type MilestonesReached struct {
InstanceID string
InstanceCreated bool
AuthenticationSucceededOnInstance bool
ProjectCreated bool
ApplicationCreated bool
AuthenticationSucceededOnApplication bool
InstanceDeleted bool
}
// complete returns true if all milestones except InstanceDeleted are reached.
func (m *MilestonesReached) complete() bool {
return m.InstanceCreated &&
m.AuthenticationSucceededOnInstance &&
m.ProjectCreated &&
m.ApplicationCreated &&
m.AuthenticationSucceededOnApplication
}
// GetMilestonesReached finds the milestone state for the current instance.
func (c *Commands) GetMilestonesReached(ctx context.Context) (*MilestonesReached, error) {
milestones, ok := c.getCachedMilestonesReached(ctx)
if ok {
return milestones, nil
}
model := NewMilestonesReachedWriteModel(authz.GetInstance(ctx).InstanceID())
if err := c.eventstore.FilterToQueryReducer(ctx, model); err != nil {
return nil, err
}
milestones = &model.MilestonesReached
c.setCachedMilestonesReached(ctx, milestones)
return milestones, nil
}
// getCachedMilestonesReached checks for milestone completeness on an instance and returns a filled
// [MilestonesReached] object.
// Otherwise it looks for the object in the milestone cache.
func (c *Commands) getCachedMilestonesReached(ctx context.Context) (*MilestonesReached, bool) {
instanceID := authz.GetInstance(ctx).InstanceID()
if _, ok := c.milestonesCompleted.Load(instanceID); ok {
return &MilestonesReached{
InstanceID: instanceID,
InstanceCreated: true,
AuthenticationSucceededOnInstance: true,
ProjectCreated: true,
ApplicationCreated: true,
AuthenticationSucceededOnApplication: true,
InstanceDeleted: false,
}, ok
}
return c.caches.milestones.Get(ctx, milestoneIndexInstanceID, instanceID)
}
// setCachedMilestonesReached stores the current milestones state in the milestones cache.
// If the milestones are complete, the instance ID is stored in milestonesCompleted instead.
func (c *Commands) setCachedMilestonesReached(ctx context.Context, milestones *MilestonesReached) {
if milestones.complete() {
c.milestonesCompleted.Store(milestones.InstanceID, struct{}{})
return
}
c.caches.milestones.Set(ctx, milestones)
}
// Keys implements cache.Entry
func (c *MilestonesReached) Keys(i milestoneIndex) []string {
if i == milestoneIndexInstanceID {
return []string{c.InstanceID}
}
return nil
}
// MilestonePushed writes a new milestone.PushedEvent with the milestone.Aggregate to the eventstore
func (c *Commands) MilestonePushed(
ctx context.Context,
instanceID string,
msType milestone.Type,
endpoints []string,
primaryDomain string,
) error {
id, err := c.idGenerator.Next()
if err != nil {
return err
}
_, err = c.eventstore.Push(ctx, milestone.NewPushedEvent(ctx, milestone.NewAggregate(ctx, id), msType, endpoints, primaryDomain, c.externalDomain))
_, err := c.eventstore.Push(ctx, milestone.NewPushedEvent(ctx, milestone.NewInstanceAggregate(instanceID), msType, endpoints, c.externalDomain))
return err
}
func setupInstanceCreatedMilestone(validations *[]preparation.Validation, instanceID string) {
*validations = append(*validations, func() (preparation.CreateCommands, error) {
return func(ctx context.Context, _ preparation.FilterToQueryReducer) ([]eventstore.Command, error) {
return []eventstore.Command{
milestone.NewReachedEvent(ctx, milestone.NewInstanceAggregate(instanceID), milestone.InstanceCreated),
}, nil
}, nil
})
}
func (s *OIDCSessionEvents) SetMilestones(ctx context.Context, clientID string, isHuman bool) (postCommit func(ctx context.Context), err error) {
postCommit = func(ctx context.Context) {}
milestones, err := s.commands.GetMilestonesReached(ctx)
if err != nil {
return postCommit, err
}
instance := authz.GetInstance(ctx)
aggregate := milestone.NewAggregate(ctx)
var invalidate bool
if !milestones.AuthenticationSucceededOnInstance {
s.events = append(s.events, milestone.NewReachedEvent(ctx, aggregate, milestone.AuthenticationSucceededOnInstance))
invalidate = true
}
if !milestones.AuthenticationSucceededOnApplication && isHuman && clientID != instance.ConsoleClientID() {
s.events = append(s.events, milestone.NewReachedEvent(ctx, aggregate, milestone.AuthenticationSucceededOnApplication))
invalidate = true
}
if invalidate {
postCommit = s.commands.invalidateMilestoneCachePostCommit(instance.InstanceID())
}
return postCommit, nil
}
func (c *Commands) projectCreatedMilestone(ctx context.Context, cmds *[]eventstore.Command) (postCommit func(ctx context.Context), err error) {
postCommit = func(ctx context.Context) {}
if isSystemUser(ctx) {
return postCommit, nil
}
milestones, err := c.GetMilestonesReached(ctx)
if err != nil {
return postCommit, err
}
if milestones.ProjectCreated {
return postCommit, nil
}
aggregate := milestone.NewAggregate(ctx)
*cmds = append(*cmds, milestone.NewReachedEvent(ctx, aggregate, milestone.ProjectCreated))
return c.invalidateMilestoneCachePostCommit(aggregate.InstanceID), nil
}
func (c *Commands) applicationCreatedMilestone(ctx context.Context, cmds *[]eventstore.Command) (postCommit func(ctx context.Context), err error) {
postCommit = func(ctx context.Context) {}
if isSystemUser(ctx) {
return postCommit, nil
}
milestones, err := c.GetMilestonesReached(ctx)
if err != nil {
return postCommit, err
}
if milestones.ApplicationCreated {
return postCommit, nil
}
aggregate := milestone.NewAggregate(ctx)
*cmds = append(*cmds, milestone.NewReachedEvent(ctx, aggregate, milestone.ApplicationCreated))
return c.invalidateMilestoneCachePostCommit(aggregate.InstanceID), nil
}
func (c *Commands) invalidateMilestoneCachePostCommit(instanceID string) func(ctx context.Context) {
return func(ctx context.Context) {
err := c.caches.milestones.Invalidate(ctx, milestoneIndexInstanceID, instanceID)
logging.WithFields("instance_id", instanceID).OnError(err).Error("failed to invalidate milestone cache")
}
}
func isSystemUser(ctx context.Context) bool {
return authz.GetCtxData(ctx).SystemMemberships != nil
}

View File

@@ -0,0 +1,58 @@
package command
import (
"github.com/zitadel/zitadel/internal/eventstore"
"github.com/zitadel/zitadel/internal/repository/milestone"
)
type MilestonesReachedWriteModel struct {
eventstore.WriteModel
MilestonesReached
}
func NewMilestonesReachedWriteModel(instanceID string) *MilestonesReachedWriteModel {
return &MilestonesReachedWriteModel{
WriteModel: eventstore.WriteModel{
AggregateID: instanceID,
InstanceID: instanceID,
},
MilestonesReached: MilestonesReached{
InstanceID: instanceID,
},
}
}
func (m *MilestonesReachedWriteModel) Query() *eventstore.SearchQueryBuilder {
return eventstore.NewSearchQueryBuilder(eventstore.ColumnsEvent).
AddQuery().
AggregateTypes(milestone.AggregateType).
AggregateIDs(m.AggregateID).
EventTypes(milestone.ReachedEventType, milestone.PushedEventType).
Builder()
}
func (m *MilestonesReachedWriteModel) Reduce() error {
for _, event := range m.Events {
if e, ok := event.(*milestone.ReachedEvent); ok {
m.reduceReachedEvent(e)
}
}
return m.WriteModel.Reduce()
}
func (m *MilestonesReachedWriteModel) reduceReachedEvent(e *milestone.ReachedEvent) {
switch e.MilestoneType {
case milestone.InstanceCreated:
m.InstanceCreated = true
case milestone.AuthenticationSucceededOnInstance:
m.AuthenticationSucceededOnInstance = true
case milestone.ProjectCreated:
m.ProjectCreated = true
case milestone.ApplicationCreated:
m.ApplicationCreated = true
case milestone.AuthenticationSucceededOnApplication:
m.AuthenticationSucceededOnApplication = true
case milestone.InstanceDeleted:
m.InstanceDeleted = true
}
}

View File

@@ -0,0 +1,629 @@
package command
import (
"context"
"io"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/zitadel/zitadel/internal/api/authz"
"github.com/zitadel/zitadel/internal/cache"
"github.com/zitadel/zitadel/internal/cache/gomap"
"github.com/zitadel/zitadel/internal/cache/noop"
"github.com/zitadel/zitadel/internal/eventstore"
"github.com/zitadel/zitadel/internal/repository/milestone"
)
func TestCommands_GetMilestonesReached(t *testing.T) {
cached := &MilestonesReached{
InstanceID: "cached-id",
InstanceCreated: true,
AuthenticationSucceededOnInstance: true,
}
ctx := authz.WithInstanceID(context.Background(), "instanceID")
aggregate := milestone.NewAggregate(ctx)
type fields struct {
eventstore func(*testing.T) *eventstore.Eventstore
}
type args struct {
ctx context.Context
}
tests := []struct {
name string
fields fields
args args
want *MilestonesReached
wantErr error
}{
{
name: "cached",
fields: fields{
eventstore: expectEventstore(),
},
args: args{
ctx: authz.WithInstanceID(context.Background(), "cached-id"),
},
want: cached,
},
{
name: "filter error",
fields: fields{
eventstore: expectEventstore(
expectFilterError(io.ErrClosedPipe),
),
},
args: args{
ctx: ctx,
},
wantErr: io.ErrClosedPipe,
},
{
name: "no events, all false",
fields: fields{
eventstore: expectEventstore(
expectFilter(),
),
},
args: args{
ctx: ctx,
},
want: &MilestonesReached{
InstanceID: "instanceID",
},
},
{
name: "instance created",
fields: fields{
eventstore: expectEventstore(
expectFilter(
eventFromEventPusher(milestone.NewReachedEvent(ctx, aggregate, milestone.InstanceCreated)),
),
),
},
args: args{
ctx: ctx,
},
want: &MilestonesReached{
InstanceID: "instanceID",
InstanceCreated: true,
},
},
{
name: "instance auth",
fields: fields{
eventstore: expectEventstore(
expectFilter(
eventFromEventPusher(milestone.NewReachedEvent(ctx, aggregate, milestone.AuthenticationSucceededOnInstance)),
),
),
},
args: args{
ctx: ctx,
},
want: &MilestonesReached{
InstanceID: "instanceID",
AuthenticationSucceededOnInstance: true,
},
},
{
name: "project created",
fields: fields{
eventstore: expectEventstore(
expectFilter(
eventFromEventPusher(milestone.NewReachedEvent(ctx, aggregate, milestone.ProjectCreated)),
),
),
},
args: args{
ctx: ctx,
},
want: &MilestonesReached{
InstanceID: "instanceID",
ProjectCreated: true,
},
},
{
name: "app created",
fields: fields{
eventstore: expectEventstore(
expectFilter(
eventFromEventPusher(milestone.NewReachedEvent(ctx, aggregate, milestone.ApplicationCreated)),
),
),
},
args: args{
ctx: ctx,
},
want: &MilestonesReached{
InstanceID: "instanceID",
ApplicationCreated: true,
},
},
{
name: "app auth",
fields: fields{
eventstore: expectEventstore(
expectFilter(
eventFromEventPusher(milestone.NewReachedEvent(ctx, aggregate, milestone.AuthenticationSucceededOnApplication)),
),
),
},
args: args{
ctx: ctx,
},
want: &MilestonesReached{
InstanceID: "instanceID",
AuthenticationSucceededOnApplication: true,
},
},
{
name: "instance deleted",
fields: fields{
eventstore: expectEventstore(
expectFilter(
eventFromEventPusher(milestone.NewReachedEvent(ctx, aggregate, milestone.InstanceDeleted)),
),
),
},
args: args{
ctx: ctx,
},
want: &MilestonesReached{
InstanceID: "instanceID",
InstanceDeleted: true,
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
cache := gomap.NewCache[milestoneIndex, string, *MilestonesReached](
context.Background(),
[]milestoneIndex{milestoneIndexInstanceID},
cache.CacheConfig{Connector: "memory"},
)
cache.Set(context.Background(), cached)
c := &Commands{
eventstore: tt.fields.eventstore(t),
caches: &Caches{
milestones: cache,
},
}
got, err := c.GetMilestonesReached(tt.args.ctx)
require.ErrorIs(t, err, tt.wantErr)
assert.Equal(t, tt.want, got)
})
}
}
func TestCommands_milestonesCompleted(t *testing.T) {
c := &Commands{
caches: &Caches{
milestones: noop.NewCache[milestoneIndex, string, *MilestonesReached](),
},
}
ctx := authz.WithInstanceID(context.Background(), "instanceID")
arg := &MilestonesReached{
InstanceID: "instanceID",
InstanceCreated: true,
AuthenticationSucceededOnInstance: true,
ProjectCreated: true,
ApplicationCreated: true,
AuthenticationSucceededOnApplication: true,
InstanceDeleted: false,
}
c.setCachedMilestonesReached(ctx, arg)
got, ok := c.getCachedMilestonesReached(ctx)
assert.True(t, ok)
assert.Equal(t, arg, got)
}
func TestCommands_MilestonePushed(t *testing.T) {
aggregate := milestone.NewInstanceAggregate("instanceID")
type fields struct {
eventstore func(*testing.T) *eventstore.Eventstore
}
type args struct {
ctx context.Context
instanceID string
msType milestone.Type
endpoints []string
}
tests := []struct {
name string
fields fields
args args
wantErr error
}{
{
name: "milestone pushed",
fields: fields{
eventstore: expectEventstore(
expectPush(
milestone.NewPushedEvent(
context.Background(),
aggregate,
milestone.ApplicationCreated,
[]string{"foo.com", "bar.com"},
"example.com",
),
),
),
},
args: args{
ctx: context.Background(),
instanceID: "instanceID",
msType: milestone.ApplicationCreated,
endpoints: []string{"foo.com", "bar.com"},
},
wantErr: nil,
},
{
name: "pusher error",
fields: fields{
eventstore: expectEventstore(
expectPushFailed(
io.ErrClosedPipe,
milestone.NewPushedEvent(
context.Background(),
aggregate,
milestone.ApplicationCreated,
[]string{"foo.com", "bar.com"},
"example.com",
),
),
),
},
args: args{
ctx: context.Background(),
instanceID: "instanceID",
msType: milestone.ApplicationCreated,
endpoints: []string{"foo.com", "bar.com"},
},
wantErr: io.ErrClosedPipe,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
c := &Commands{
eventstore: tt.fields.eventstore(t),
externalDomain: "example.com",
}
err := c.MilestonePushed(tt.args.ctx, tt.args.instanceID, tt.args.msType, tt.args.endpoints)
assert.ErrorIs(t, err, tt.wantErr)
})
}
}
func TestOIDCSessionEvents_SetMilestones(t *testing.T) {
ctx := authz.WithInstanceID(context.Background(), "instanceID")
ctx = authz.WithConsoleClientID(ctx, "console")
aggregate := milestone.NewAggregate(ctx)
type fields struct {
eventstore func(*testing.T) *eventstore.Eventstore
}
type args struct {
ctx context.Context
clientID string
isHuman bool
}
tests := []struct {
name string
fields fields
args args
wantEvents []eventstore.Command
wantErr error
}{
{
name: "get error",
fields: fields{
eventstore: expectEventstore(
expectFilterError(io.ErrClosedPipe),
),
},
args: args{
ctx: ctx,
clientID: "client",
isHuman: true,
},
wantErr: io.ErrClosedPipe,
},
{
name: "milestones already reached",
fields: fields{
eventstore: expectEventstore(
expectFilter(
eventFromEventPusher(milestone.NewReachedEvent(ctx, aggregate, milestone.AuthenticationSucceededOnInstance)),
eventFromEventPusher(milestone.NewReachedEvent(ctx, aggregate, milestone.AuthenticationSucceededOnApplication)),
),
),
},
args: args{
ctx: ctx,
clientID: "client",
isHuman: true,
},
wantErr: nil,
},
{
name: "auth on instance",
fields: fields{
eventstore: expectEventstore(
expectFilter(),
),
},
args: args{
ctx: ctx,
clientID: "console",
isHuman: true,
},
wantEvents: []eventstore.Command{
milestone.NewReachedEvent(ctx, aggregate, milestone.AuthenticationSucceededOnInstance),
},
wantErr: nil,
},
{
name: "subsequent console login, no milestone",
fields: fields{
eventstore: expectEventstore(
expectFilter(
eventFromEventPusher(milestone.NewReachedEvent(ctx, aggregate, milestone.AuthenticationSucceededOnInstance)),
),
),
},
args: args{
ctx: ctx,
clientID: "console",
isHuman: true,
},
wantErr: nil,
},
{
name: "subsequent machine login, no milestone",
fields: fields{
eventstore: expectEventstore(
expectFilter(
eventFromEventPusher(milestone.NewReachedEvent(ctx, aggregate, milestone.AuthenticationSucceededOnInstance)),
),
),
},
args: args{
ctx: ctx,
clientID: "client",
isHuman: false,
},
wantErr: nil,
},
{
name: "auth on app",
fields: fields{
eventstore: expectEventstore(
expectFilter(
eventFromEventPusher(milestone.NewReachedEvent(ctx, aggregate, milestone.AuthenticationSucceededOnInstance)),
),
),
},
args: args{
ctx: ctx,
clientID: "client",
isHuman: true,
},
wantEvents: []eventstore.Command{
milestone.NewReachedEvent(ctx, aggregate, milestone.AuthenticationSucceededOnApplication),
},
wantErr: nil,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
c := &Commands{
eventstore: tt.fields.eventstore(t),
caches: &Caches{
milestones: noop.NewCache[milestoneIndex, string, *MilestonesReached](),
},
}
s := &OIDCSessionEvents{
commands: c,
}
postCommit, err := s.SetMilestones(tt.args.ctx, tt.args.clientID, tt.args.isHuman)
postCommit(tt.args.ctx)
require.ErrorIs(t, err, tt.wantErr)
assert.Equal(t, tt.wantEvents, s.events)
})
}
}
func TestCommands_projectCreatedMilestone(t *testing.T) {
ctx := authz.WithInstanceID(context.Background(), "instanceID")
systemCtx := authz.SetCtxData(ctx, authz.CtxData{
SystemMemberships: authz.Memberships{
&authz.Membership{
MemberType: authz.MemberTypeSystem,
},
},
})
aggregate := milestone.NewAggregate(ctx)
type fields struct {
eventstore func(*testing.T) *eventstore.Eventstore
}
type args struct {
ctx context.Context
}
tests := []struct {
name string
fields fields
args args
wantEvents []eventstore.Command
wantErr error
}{
{
name: "system user",
fields: fields{
eventstore: expectEventstore(),
},
args: args{
ctx: systemCtx,
},
wantErr: nil,
},
{
name: "get error",
fields: fields{
eventstore: expectEventstore(
expectFilterError(io.ErrClosedPipe),
),
},
args: args{
ctx: ctx,
},
wantErr: io.ErrClosedPipe,
},
{
name: "milestone already reached",
fields: fields{
eventstore: expectEventstore(
expectFilter(
eventFromEventPusher(milestone.NewReachedEvent(ctx, aggregate, milestone.ProjectCreated)),
),
),
},
args: args{
ctx: ctx,
},
wantErr: nil,
},
{
name: "milestone reached event",
fields: fields{
eventstore: expectEventstore(
expectFilter(),
),
},
args: args{
ctx: ctx,
},
wantEvents: []eventstore.Command{
milestone.NewReachedEvent(ctx, aggregate, milestone.ProjectCreated),
},
wantErr: nil,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
c := &Commands{
eventstore: tt.fields.eventstore(t),
caches: &Caches{
milestones: noop.NewCache[milestoneIndex, string, *MilestonesReached](),
},
}
var cmds []eventstore.Command
postCommit, err := c.projectCreatedMilestone(tt.args.ctx, &cmds)
postCommit(tt.args.ctx)
require.ErrorIs(t, err, tt.wantErr)
assert.Equal(t, tt.wantEvents, cmds)
})
}
}
func TestCommands_applicationCreatedMilestone(t *testing.T) {
ctx := authz.WithInstanceID(context.Background(), "instanceID")
systemCtx := authz.SetCtxData(ctx, authz.CtxData{
SystemMemberships: authz.Memberships{
&authz.Membership{
MemberType: authz.MemberTypeSystem,
},
},
})
aggregate := milestone.NewAggregate(ctx)
type fields struct {
eventstore func(*testing.T) *eventstore.Eventstore
}
type args struct {
ctx context.Context
}
tests := []struct {
name string
fields fields
args args
wantEvents []eventstore.Command
wantErr error
}{
{
name: "system user",
fields: fields{
eventstore: expectEventstore(),
},
args: args{
ctx: systemCtx,
},
wantErr: nil,
},
{
name: "get error",
fields: fields{
eventstore: expectEventstore(
expectFilterError(io.ErrClosedPipe),
),
},
args: args{
ctx: ctx,
},
wantErr: io.ErrClosedPipe,
},
{
name: "milestone already reached",
fields: fields{
eventstore: expectEventstore(
expectFilter(
eventFromEventPusher(milestone.NewReachedEvent(ctx, aggregate, milestone.ApplicationCreated)),
),
),
},
args: args{
ctx: ctx,
},
wantErr: nil,
},
{
name: "milestone reached event",
fields: fields{
eventstore: expectEventstore(
expectFilter(),
),
},
args: args{
ctx: ctx,
},
wantEvents: []eventstore.Command{
milestone.NewReachedEvent(ctx, aggregate, milestone.ApplicationCreated),
},
wantErr: nil,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
c := &Commands{
eventstore: tt.fields.eventstore(t),
caches: &Caches{
milestones: noop.NewCache[milestoneIndex, string, *MilestonesReached](),
},
}
var cmds []eventstore.Command
postCommit, err := c.applicationCreatedMilestone(tt.args.ctx, &cmds)
postCommit(tt.args.ctx)
require.ErrorIs(t, err, tt.wantErr)
assert.Equal(t, tt.wantEvents, cmds)
})
}
}
func (c *Commands) setMilestonesCompletedForTest(instanceID string) {
c.milestonesCompleted.Store(instanceID, struct{}{})
}

View File

@@ -71,7 +71,8 @@ func (c *Commands) CreateOIDCSessionFromAuthRequest(ctx context.Context, authReq
return nil, "", zerrors.ThrowPreconditionFailed(nil, "COMMAND-Iung5", "Errors.AuthRequest.NoCode")
}
sessionModel := NewSessionWriteModel(authReqModel.SessionID, authz.GetInstance(ctx).InstanceID())
instanceID := authz.GetInstance(ctx).InstanceID()
sessionModel := NewSessionWriteModel(authReqModel.SessionID, instanceID)
err = c.eventstore.FilterToQueryReducer(ctx, sessionModel)
if err != nil {
return nil, "", err
@@ -118,8 +119,15 @@ func (c *Commands) CreateOIDCSessionFromAuthRequest(ctx context.Context, authReq
}
}
cmd.SetAuthRequestSuccessful(ctx, authReqModel.aggregate)
session, err = cmd.PushEvents(ctx)
return session, authReqModel.State, err
postCommit, err := cmd.SetMilestones(ctx, authReqModel.ClientID, true)
if err != nil {
return nil, "", err
}
if session, err = cmd.PushEvents(ctx); err != nil {
return nil, "", err
}
postCommit(ctx)
return session, authReqModel.State, nil
}
func (c *Commands) CreateOIDCSession(ctx context.Context,
@@ -161,7 +169,15 @@ func (c *Commands) CreateOIDCSession(ctx context.Context,
return nil, err
}
}
return cmd.PushEvents(ctx)
postCommit, err := cmd.SetMilestones(ctx, clientID, sessionID != "")
if err != nil {
return nil, err
}
if session, err = cmd.PushEvents(ctx); err != nil {
return nil, err
}
postCommit(ctx)
return session, nil
}
type RefreshTokenComplianceChecker func(ctx context.Context, wm *OIDCSessionWriteModel, requestedScope []string) (scope []string, err error)
@@ -283,7 +299,7 @@ func (c *Commands) newOIDCSessionAddEvents(ctx context.Context, userID, resource
}
sessionID = IDPrefixV2 + sessionID
return &OIDCSessionEvents{
eventstore: c.eventstore,
commands: c,
idGenerator: c.idGenerator,
encryptionAlg: c.keyAlgorithm,
events: pending,
@@ -341,7 +357,7 @@ func (c *Commands) newOIDCSessionUpdateEvents(ctx context.Context, refreshToken
return nil, err
}
return &OIDCSessionEvents{
eventstore: c.eventstore,
commands: c,
idGenerator: c.idGenerator,
encryptionAlg: c.keyAlgorithm,
oidcSessionWriteModel: sessionWriteModel,
@@ -352,7 +368,7 @@ func (c *Commands) newOIDCSessionUpdateEvents(ctx context.Context, refreshToken
}
type OIDCSessionEvents struct {
eventstore *eventstore.Eventstore
commands *Commands
idGenerator id.Generator
encryptionAlg crypto.EncryptionAlgorithm
events []eventstore.Command
@@ -467,7 +483,7 @@ func (c *OIDCSessionEvents) generateRefreshToken(userID string) (refreshTokenID,
}
func (c *OIDCSessionEvents) PushEvents(ctx context.Context) (*OIDCSession, error) {
pushedEvents, err := c.eventstore.Push(ctx, c.events...)
pushedEvents, err := c.commands.eventstore.Push(ctx, c.events...)
if err != nil {
return nil, err
}
@@ -496,7 +512,7 @@ func (c *OIDCSessionEvents) PushEvents(ctx context.Context) (*OIDCSession, error
// we need to use `-` as a delimiter because the OIDC library uses `:` and will check for a length of 2 parts
session.TokenID = c.oidcSessionWriteModel.AggregateID + TokenDelimiter + c.accessTokenID
}
activity.Trigger(ctx, c.oidcSessionWriteModel.UserResourceOwner, c.oidcSessionWriteModel.UserID, tokenReasonToActivityMethodType(c.oidcSessionWriteModel.AccessTokenReason), c.eventstore.FilterToQueryReducer)
activity.Trigger(ctx, c.oidcSessionWriteModel.UserResourceOwner, c.oidcSessionWriteModel.UserID, tokenReasonToActivityMethodType(c.oidcSessionWriteModel.AccessTokenReason), c.commands.eventstore.FilterToQueryReducer)
return session, nil
}

View File

@@ -70,7 +70,7 @@ func TestCommands_CreateOIDCSessionFromAuthRequest(t *testing.T) {
eventstore: expectEventstore(),
},
args{
ctx: context.Background(),
ctx: authz.WithInstanceID(context.Background(), "instanceID"),
authRequestID: "",
complianceCheck: mockAuthRequestComplianceChecker(nil),
},
@@ -86,7 +86,7 @@ func TestCommands_CreateOIDCSessionFromAuthRequest(t *testing.T) {
),
},
args{
ctx: context.Background(),
ctx: authz.WithInstanceID(context.Background(), "instanceID"),
authRequestID: "V2_authRequestID",
complianceCheck: mockAuthRequestComplianceChecker(nil),
},
@@ -102,7 +102,7 @@ func TestCommands_CreateOIDCSessionFromAuthRequest(t *testing.T) {
),
},
args{
ctx: context.Background(),
ctx: authz.WithInstanceID(context.Background(), "instanceID"),
authRequestID: "V2_authRequestID",
complianceCheck: mockAuthRequestComplianceChecker(nil),
},
@@ -706,6 +706,7 @@ func TestCommands_CreateOIDCSessionFromAuthRequest(t *testing.T) {
defaultRefreshTokenIdleLifetime: tt.fields.defaultRefreshTokenIdleLifetime,
keyAlgorithm: tt.fields.keyAlgorithm,
}
c.setMilestonesCompletedForTest("instanceID")
gotSession, gotState, err := c.CreateOIDCSessionFromAuthRequest(tt.args.ctx, tt.args.authRequestID, tt.args.complianceCheck, tt.args.needRefreshToken)
require.ErrorIs(t, err, tt.res.err)
@@ -762,7 +763,7 @@ func TestCommands_CreateOIDCSession(t *testing.T) {
),
},
args: args{
ctx: context.Background(),
ctx: authz.WithInstanceID(context.Background(), "instanceID"),
userID: "userID",
resourceOwner: "orgID",
clientID: "clientID",
@@ -818,7 +819,7 @@ func TestCommands_CreateOIDCSession(t *testing.T) {
keyAlgorithm: crypto.CreateMockEncryptionAlg(gomock.NewController(t)),
},
args: args{
ctx: context.Background(),
ctx: authz.WithInstanceID(context.Background(), "instanceID"),
userID: "userID",
resourceOwner: "org1",
clientID: "clientID",
@@ -892,7 +893,7 @@ func TestCommands_CreateOIDCSession(t *testing.T) {
keyAlgorithm: crypto.CreateMockEncryptionAlg(gomock.NewController(t)),
},
args: args{
ctx: context.Background(),
ctx: authz.WithInstanceID(context.Background(), "instanceID"),
userID: "userID",
resourceOwner: "org1",
clientID: "clientID",
@@ -1089,7 +1090,7 @@ func TestCommands_CreateOIDCSession(t *testing.T) {
keyAlgorithm: crypto.CreateMockEncryptionAlg(gomock.NewController(t)),
},
args: args{
ctx: context.Background(),
ctx: authz.WithInstanceID(context.Background(), "instanceID"),
userID: "userID",
resourceOwner: "org1",
clientID: "clientID",
@@ -1186,7 +1187,7 @@ func TestCommands_CreateOIDCSession(t *testing.T) {
keyAlgorithm: crypto.CreateMockEncryptionAlg(gomock.NewController(t)),
},
args: args{
ctx: context.Background(),
ctx: authz.WithInstanceID(context.Background(), "instanceID"),
userID: "userID",
resourceOwner: "org1",
clientID: "clientID",
@@ -1266,7 +1267,7 @@ func TestCommands_CreateOIDCSession(t *testing.T) {
}),
},
args: args{
ctx: context.Background(),
ctx: authz.WithInstanceID(context.Background(), "instanceID"),
userID: "userID",
resourceOwner: "org1",
clientID: "clientID",
@@ -1347,7 +1348,7 @@ func TestCommands_CreateOIDCSession(t *testing.T) {
}),
},
args: args{
ctx: context.Background(),
ctx: authz.WithInstanceID(context.Background(), "instanceID"),
userID: "userID",
resourceOwner: "org1",
clientID: "clientID",
@@ -1406,6 +1407,7 @@ func TestCommands_CreateOIDCSession(t *testing.T) {
keyAlgorithm: tt.fields.keyAlgorithm,
checkPermission: tt.fields.checkPermission,
}
c.setMilestonesCompletedForTest("instanceID")
got, err := c.CreateOIDCSession(tt.args.ctx,
tt.args.userID,
tt.args.resourceOwner,

View File

@@ -34,7 +34,11 @@ func (c *Commands) AddProjectWithID(ctx context.Context, project *domain.Project
if existingProject.State != domain.ProjectStateUnspecified {
return nil, zerrors.ThrowInvalidArgument(nil, "COMMAND-opamwu", "Errors.Project.AlreadyExisting")
}
return c.addProjectWithID(ctx, project, resourceOwner, projectID)
project, err = c.addProjectWithID(ctx, project, resourceOwner, projectID)
if err != nil {
return nil, err
}
return project, nil
}
func (c *Commands) AddProject(ctx context.Context, project *domain.Project, resourceOwner, ownerUserID string) (_ *domain.Project, err error) {
@@ -53,7 +57,11 @@ func (c *Commands) AddProject(ctx context.Context, project *domain.Project, reso
return nil, err
}
return c.addProjectWithIDWithOwner(ctx, project, resourceOwner, ownerUserID, projectID)
project, err = c.addProjectWithIDWithOwner(ctx, project, resourceOwner, ownerUserID, projectID)
if err != nil {
return nil, err
}
return project, nil
}
func (c *Commands) addProjectWithID(ctx context.Context, projectAdd *domain.Project, resourceOwner, projectID string) (_ *domain.Project, err error) {
@@ -71,11 +79,15 @@ func (c *Commands) addProjectWithID(ctx context.Context, projectAdd *domain.Proj
projectAdd.HasProjectCheck,
projectAdd.PrivateLabelingSetting),
}
postCommit, err := c.projectCreatedMilestone(ctx, &events)
if err != nil {
return nil, err
}
pushedEvents, err := c.eventstore.Push(ctx, events...)
if err != nil {
return nil, err
}
postCommit(ctx)
err = AppendAndReduce(addedProject, pushedEvents...)
if err != nil {
return nil, err
@@ -103,11 +115,15 @@ func (c *Commands) addProjectWithIDWithOwner(ctx context.Context, projectAdd *do
projectAdd.PrivateLabelingSetting),
project.NewProjectMemberAddedEvent(ctx, projectAgg, ownerUserID, projectRole),
}
postCommit, err := c.projectCreatedMilestone(ctx, &events)
if err != nil {
return nil, err
}
pushedEvents, err := c.eventstore.Push(ctx, events...)
if err != nil {
return nil, err
}
postCommit(ctx)
err = AppendAndReduce(addedProject, pushedEvents...)
if err != nil {
return nil, err

View File

@@ -202,10 +202,15 @@ func (c *Commands) addOIDCApplicationWithID(ctx context.Context, oidcApp *domain
))
addedApplication.AppID = oidcApp.AppID
postCommit, err := c.applicationCreatedMilestone(ctx, &events)
if err != nil {
return nil, err
}
pushedEvents, err := c.eventstore.Push(ctx, events...)
if err != nil {
return nil, err
}
postCommit(ctx)
err = AppendAndReduce(addedApplication, pushedEvents...)
if err != nil {
return nil, err

View File

@@ -11,6 +11,7 @@ import (
"github.com/zitadel/passwap"
"github.com/zitadel/passwap/bcrypt"
"github.com/zitadel/zitadel/internal/api/authz"
"github.com/zitadel/zitadel/internal/command/preparation"
"github.com/zitadel/zitadel/internal/crypto"
"github.com/zitadel/zitadel/internal/domain"
@@ -418,7 +419,7 @@ func TestCommandSide_AddOIDCApplication(t *testing.T) {
eventstore: expectEventstore(),
},
args: args{
ctx: context.Background(),
ctx: authz.WithInstanceID(context.Background(), "instanceID"),
oidcApp: &domain.OIDCApp{},
resourceOwner: "org1",
},
@@ -434,7 +435,7 @@ func TestCommandSide_AddOIDCApplication(t *testing.T) {
),
},
args: args{
ctx: context.Background(),
ctx: authz.WithInstanceID(context.Background(), "instanceID"),
oidcApp: &domain.OIDCApp{
ObjectRoot: models.ObjectRoot{
AggregateID: "project1",
@@ -463,7 +464,7 @@ func TestCommandSide_AddOIDCApplication(t *testing.T) {
),
},
args: args{
ctx: context.Background(),
ctx: authz.WithInstanceID(context.Background(), "instanceID"),
oidcApp: &domain.OIDCApp{
ObjectRoot: models.ObjectRoot{
AggregateID: "project1",
@@ -521,7 +522,7 @@ func TestCommandSide_AddOIDCApplication(t *testing.T) {
idGenerator: id_mock.NewIDGeneratorExpectIDs(t, "app1", "client1"),
},
args: args{
ctx: context.Background(),
ctx: authz.WithInstanceID(context.Background(), "instanceID"),
oidcApp: &domain.OIDCApp{
ObjectRoot: models.ObjectRoot{
AggregateID: "project1",
@@ -619,7 +620,7 @@ func TestCommandSide_AddOIDCApplication(t *testing.T) {
idGenerator: id_mock.NewIDGeneratorExpectIDs(t, "app1", "client1"),
},
args: args{
ctx: context.Background(),
ctx: authz.WithInstanceID(context.Background(), "instanceID"),
oidcApp: &domain.OIDCApp{
ObjectRoot: models.ObjectRoot{
AggregateID: "project1",
@@ -676,7 +677,7 @@ func TestCommandSide_AddOIDCApplication(t *testing.T) {
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
r := &Commands{
c := &Commands{
eventstore: tt.fields.eventstore(t),
idGenerator: tt.fields.idGenerator,
newHashedSecret: mockHashedSecret("secret"),
@@ -684,7 +685,8 @@ func TestCommandSide_AddOIDCApplication(t *testing.T) {
ClientSecret: emptyConfig,
},
}
got, err := r.AddOIDCApplication(tt.args.ctx, tt.args.oidcApp, tt.args.resourceOwner)
c.setMilestonesCompletedForTest("instanceID")
got, err := c.AddOIDCApplication(tt.args.ctx, tt.args.oidcApp, tt.args.resourceOwner)
if tt.res.err == nil {
assert.NoError(t, err)
}

View File

@@ -28,10 +28,15 @@ func (c *Commands) AddSAMLApplication(ctx context.Context, application *domain.S
return nil, err
}
addedApplication.AppID = application.AppID
postCommit, err := c.applicationCreatedMilestone(ctx, &events)
if err != nil {
return nil, err
}
pushedEvents, err := c.eventstore.Push(ctx, events...)
if err != nil {
return nil, err
}
postCommit(ctx)
err = AppendAndReduce(addedApplication, pushedEvents...)
if err != nil {
return nil, err

View File

@@ -9,6 +9,7 @@ import (
"github.com/stretchr/testify/assert"
"github.com/zitadel/zitadel/internal/api/authz"
"github.com/zitadel/zitadel/internal/domain"
"github.com/zitadel/zitadel/internal/eventstore"
"github.com/zitadel/zitadel/internal/eventstore/v1/models"
@@ -76,7 +77,7 @@ func TestCommandSide_AddSAMLApplication(t *testing.T) {
),
},
args: args{
ctx: context.Background(),
ctx: authz.WithInstanceID(context.Background(), "instanceID"),
samlApp: &domain.SAMLApp{},
resourceOwner: "org1",
},
@@ -93,7 +94,7 @@ func TestCommandSide_AddSAMLApplication(t *testing.T) {
),
},
args: args{
ctx: context.Background(),
ctx: authz.WithInstanceID(context.Background(), "instanceID"),
samlApp: &domain.SAMLApp{
ObjectRoot: models.ObjectRoot{
AggregateID: "project1",
@@ -123,7 +124,7 @@ func TestCommandSide_AddSAMLApplication(t *testing.T) {
),
},
args: args{
ctx: context.Background(),
ctx: authz.WithInstanceID(context.Background(), "instanceID"),
samlApp: &domain.SAMLApp{
ObjectRoot: models.ObjectRoot{
AggregateID: "project1",
@@ -154,7 +155,7 @@ func TestCommandSide_AddSAMLApplication(t *testing.T) {
idGenerator: id_mock.NewIDGeneratorExpectIDs(t),
},
args: args{
ctx: context.Background(),
ctx: authz.WithInstanceID(context.Background(), "instanceID"),
samlApp: &domain.SAMLApp{
ObjectRoot: models.ObjectRoot{
AggregateID: "project1",
@@ -201,7 +202,7 @@ func TestCommandSide_AddSAMLApplication(t *testing.T) {
idGenerator: id_mock.NewIDGeneratorExpectIDs(t, "app1"),
},
args: args{
ctx: context.Background(),
ctx: authz.WithInstanceID(context.Background(), "instanceID"),
samlApp: &domain.SAMLApp{
ObjectRoot: models.ObjectRoot{
AggregateID: "project1",
@@ -260,7 +261,7 @@ func TestCommandSide_AddSAMLApplication(t *testing.T) {
httpClient: newTestClient(200, testMetadata),
},
args: args{
ctx: context.Background(),
ctx: authz.WithInstanceID(context.Background(), "instanceID"),
samlApp: &domain.SAMLApp{
ObjectRoot: models.ObjectRoot{
AggregateID: "project1",
@@ -305,7 +306,7 @@ func TestCommandSide_AddSAMLApplication(t *testing.T) {
httpClient: newTestClient(http.StatusNotFound, nil),
},
args: args{
ctx: context.Background(),
ctx: authz.WithInstanceID(context.Background(), "instanceID"),
samlApp: &domain.SAMLApp{
ObjectRoot: models.ObjectRoot{
AggregateID: "project1",
@@ -325,13 +326,13 @@ func TestCommandSide_AddSAMLApplication(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
r := &Commands{
c := &Commands{
eventstore: tt.fields.eventstore,
idGenerator: tt.fields.idGenerator,
httpClient: tt.fields.httpClient,
}
got, err := r.AddSAMLApplication(tt.args.ctx, tt.args.samlApp, tt.args.resourceOwner)
c.setMilestonesCompletedForTest("instanceID")
got, err := c.AddSAMLApplication(tt.args.ctx, tt.args.samlApp, tt.args.resourceOwner)
if tt.res.err == nil {
assert.NoError(t, err)
}

View File

@@ -6,6 +6,7 @@ import (
"github.com/stretchr/testify/assert"
"github.com/zitadel/zitadel/internal/api/authz"
"github.com/zitadel/zitadel/internal/domain"
"github.com/zitadel/zitadel/internal/eventstore"
"github.com/zitadel/zitadel/internal/eventstore/v1/models"
@@ -44,7 +45,7 @@ func TestCommandSide_AddProject(t *testing.T) {
),
},
args: args{
ctx: context.Background(),
ctx: authz.WithInstanceID(context.Background(), "instanceID"),
project: &domain.Project{},
resourceOwner: "org1",
},
@@ -60,7 +61,7 @@ func TestCommandSide_AddProject(t *testing.T) {
),
},
args: args{
ctx: context.Background(),
ctx: authz.WithInstanceID(context.Background(), "instanceID"),
project: &domain.Project{
Name: "project",
ProjectRoleAssertion: true,
@@ -121,7 +122,7 @@ func TestCommandSide_AddProject(t *testing.T) {
idGenerator: id_mock.NewIDGeneratorExpectIDs(t, "project1"),
},
args: args{
ctx: context.Background(),
ctx: authz.WithInstanceID(context.Background(), "instanceID"),
project: &domain.Project{
Name: "project",
ProjectRoleAssertion: true,
@@ -159,7 +160,7 @@ func TestCommandSide_AddProject(t *testing.T) {
idGenerator: id_mock.NewIDGeneratorExpectIDs(t, "project1"),
},
args: args{
ctx: context.Background(),
ctx: authz.WithInstanceID(context.Background(), "instanceID"),
project: &domain.Project{
Name: "project",
ProjectRoleAssertion: true,
@@ -187,11 +188,12 @@ func TestCommandSide_AddProject(t *testing.T) {
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
r := &Commands{
c := &Commands{
eventstore: tt.fields.eventstore,
idGenerator: tt.fields.idGenerator,
}
got, err := r.AddProject(tt.args.ctx, tt.args.project, tt.args.resourceOwner, tt.args.ownerID)
c.setMilestonesCompletedForTest("instanceID")
got, err := c.AddProject(tt.args.ctx, tt.args.project, tt.args.resourceOwner, tt.args.ownerID)
if tt.res.err == nil {
assert.NoError(t, err)
}