mirror of
https://github.com/zitadel/zitadel.git
synced 2025-08-11 15:17:33 +00:00
fix(saml): Push AuthenticationSucceededOnApplication milestone for SAML sessions (#10263)
# Which Problems Are Solved The SAML session (v2 login) currently does not push a `AuthenticationSucceededOnApplication` milestone upon successful SAML login for the first time. The changes in this PR address this issue. # How the Problems Are Solved Add a new function to set the appropriate milestone, and call this function after a successful SAML request. # Additional Changes N/A # Additional Context - Closes #9592 --------- Co-authored-by: Stefan Benz <46600784+stebenz@users.noreply.github.com>
This commit is contained in:
@@ -133,6 +133,30 @@ func (s *OIDCSessionEvents) SetMilestones(ctx context.Context, clientID string,
|
||||
return postCommit, nil
|
||||
}
|
||||
|
||||
func (s *SAMLSessionEvents) SetMilestones(ctx context.Context) (postCommit func(ctx context.Context), err error) {
|
||||
postCommit = func(ctx context.Context) {}
|
||||
milestones, err := s.commands.GetMilestonesReached(ctx)
|
||||
if err != nil {
|
||||
return postCommit, err
|
||||
}
|
||||
|
||||
instance := authz.GetInstance(ctx)
|
||||
aggregate := milestone.NewAggregate(ctx)
|
||||
var invalidate bool
|
||||
if !milestones.AuthenticationSucceededOnInstance {
|
||||
s.events = append(s.events, milestone.NewReachedEvent(ctx, aggregate, milestone.AuthenticationSucceededOnInstance))
|
||||
invalidate = true
|
||||
}
|
||||
if !milestones.AuthenticationSucceededOnApplication {
|
||||
s.events = append(s.events, milestone.NewReachedEvent(ctx, aggregate, milestone.AuthenticationSucceededOnApplication))
|
||||
invalidate = true
|
||||
}
|
||||
if invalidate {
|
||||
postCommit = s.commands.invalidateMilestoneCachePostCommit(instance.InstanceID())
|
||||
}
|
||||
return postCommit, nil
|
||||
}
|
||||
|
||||
func (c *Commands) projectCreatedMilestone(ctx context.Context, cmds *[]eventstore.Command) (postCommit func(ctx context.Context), err error) {
|
||||
postCommit = func(ctx context.Context) {}
|
||||
if isSystemUser(ctx) {
|
||||
|
@@ -627,3 +627,84 @@ func TestCommands_applicationCreatedMilestone(t *testing.T) {
|
||||
func (c *Commands) setMilestonesCompletedForTest(instanceID string) {
|
||||
c.milestonesCompleted.Store(instanceID, struct{}{})
|
||||
}
|
||||
|
||||
func TestSAMLSessionEvents_SetMilestones(t *testing.T) {
|
||||
ctx := authz.WithInstanceID(context.Background(), "instanceID")
|
||||
aggregate := milestone.NewAggregate(ctx)
|
||||
|
||||
type fields struct {
|
||||
eventstore func(*testing.T) *eventstore.Eventstore
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
fields fields
|
||||
wantEvents []eventstore.Command
|
||||
wantErr error
|
||||
}{
|
||||
{
|
||||
name: "get error",
|
||||
fields: fields{
|
||||
eventstore: expectEventstore(
|
||||
expectFilterError(io.ErrClosedPipe),
|
||||
),
|
||||
},
|
||||
wantErr: io.ErrClosedPipe,
|
||||
},
|
||||
{
|
||||
name: "auth on instance, auth on application",
|
||||
fields: fields{
|
||||
eventstore: expectEventstore(
|
||||
expectFilter(),
|
||||
),
|
||||
},
|
||||
wantEvents: []eventstore.Command{
|
||||
milestone.NewReachedEvent(ctx, aggregate, milestone.AuthenticationSucceededOnInstance),
|
||||
milestone.NewReachedEvent(ctx, aggregate, milestone.AuthenticationSucceededOnApplication),
|
||||
},
|
||||
wantErr: nil,
|
||||
},
|
||||
{
|
||||
name: "auth on app with a previous auth on instance",
|
||||
fields: fields{
|
||||
eventstore: expectEventstore(
|
||||
expectFilter(
|
||||
eventFromEventPusher(milestone.NewReachedEvent(ctx, aggregate, milestone.AuthenticationSucceededOnInstance)),
|
||||
),
|
||||
),
|
||||
},
|
||||
wantEvents: []eventstore.Command{
|
||||
milestone.NewReachedEvent(ctx, aggregate, milestone.AuthenticationSucceededOnApplication),
|
||||
},
|
||||
wantErr: nil,
|
||||
},
|
||||
{
|
||||
name: "milestones already reached",
|
||||
fields: fields{
|
||||
eventstore: expectEventstore(
|
||||
expectFilter(
|
||||
eventFromEventPusher(milestone.NewReachedEvent(ctx, aggregate, milestone.AuthenticationSucceededOnInstance)),
|
||||
eventFromEventPusher(milestone.NewReachedEvent(ctx, aggregate, milestone.AuthenticationSucceededOnApplication)),
|
||||
),
|
||||
),
|
||||
},
|
||||
wantErr: nil,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
c := &Commands{
|
||||
eventstore: tt.fields.eventstore(t),
|
||||
caches: &Caches{
|
||||
milestones: noop.NewCache[milestoneIndex, string, *MilestonesReached](),
|
||||
},
|
||||
}
|
||||
s := &SAMLSessionEvents{
|
||||
commands: c,
|
||||
}
|
||||
postCommit, err := s.SetMilestones(ctx)
|
||||
postCommit(ctx)
|
||||
require.ErrorIs(t, err, tt.wantErr)
|
||||
assert.Equal(t, tt.wantEvents, s.events)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
@@ -80,7 +80,15 @@ func (c *Commands) CreateSAMLSessionFromSAMLRequest(ctx context.Context, samlReq
|
||||
return err
|
||||
}
|
||||
cmd.SetSAMLRequestSuccessful(ctx, samlReqModel.aggregate)
|
||||
postCommit, err := cmd.SetMilestones(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = cmd.PushEvents(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
postCommit(ctx)
|
||||
return err
|
||||
}
|
||||
|
||||
|
@@ -334,6 +334,7 @@ func TestCommands_CreateSAMLSessionFromSAMLRequest(t *testing.T) {
|
||||
idGenerator: tt.fields.idGenerator,
|
||||
keyAlgorithm: tt.fields.keyAlgorithm,
|
||||
}
|
||||
c.setMilestonesCompletedForTest("instanceID")
|
||||
err := c.CreateSAMLSessionFromSAMLRequest(tt.args.ctx, tt.args.samlRequestID, tt.args.complianceCheck, tt.args.samlResponseID, tt.args.samlResponseLifetime)
|
||||
require.ErrorIs(t, err, tt.res.err)
|
||||
})
|
||||
|
Reference in New Issue
Block a user