mirror of
https://github.com/zitadel/zitadel.git
synced 2025-08-11 17:27:31 +00:00
fix(saml): Push AuthenticationSucceededOnApplication milestone for SAML sessions (#10263)
# Which Problems Are Solved The SAML session (v2 login) currently does not push a `AuthenticationSucceededOnApplication` milestone upon successful SAML login for the first time. The changes in this PR address this issue. # How the Problems Are Solved Add a new function to set the appropriate milestone, and call this function after a successful SAML request. # Additional Changes N/A # Additional Context - Closes #9592 --------- Co-authored-by: Stefan Benz <46600784+stebenz@users.noreply.github.com>
This commit is contained in:
@@ -133,6 +133,30 @@ func (s *OIDCSessionEvents) SetMilestones(ctx context.Context, clientID string,
|
|||||||
return postCommit, nil
|
return postCommit, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *SAMLSessionEvents) SetMilestones(ctx context.Context) (postCommit func(ctx context.Context), err error) {
|
||||||
|
postCommit = func(ctx context.Context) {}
|
||||||
|
milestones, err := s.commands.GetMilestonesReached(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return postCommit, err
|
||||||
|
}
|
||||||
|
|
||||||
|
instance := authz.GetInstance(ctx)
|
||||||
|
aggregate := milestone.NewAggregate(ctx)
|
||||||
|
var invalidate bool
|
||||||
|
if !milestones.AuthenticationSucceededOnInstance {
|
||||||
|
s.events = append(s.events, milestone.NewReachedEvent(ctx, aggregate, milestone.AuthenticationSucceededOnInstance))
|
||||||
|
invalidate = true
|
||||||
|
}
|
||||||
|
if !milestones.AuthenticationSucceededOnApplication {
|
||||||
|
s.events = append(s.events, milestone.NewReachedEvent(ctx, aggregate, milestone.AuthenticationSucceededOnApplication))
|
||||||
|
invalidate = true
|
||||||
|
}
|
||||||
|
if invalidate {
|
||||||
|
postCommit = s.commands.invalidateMilestoneCachePostCommit(instance.InstanceID())
|
||||||
|
}
|
||||||
|
return postCommit, nil
|
||||||
|
}
|
||||||
|
|
||||||
func (c *Commands) projectCreatedMilestone(ctx context.Context, cmds *[]eventstore.Command) (postCommit func(ctx context.Context), err error) {
|
func (c *Commands) projectCreatedMilestone(ctx context.Context, cmds *[]eventstore.Command) (postCommit func(ctx context.Context), err error) {
|
||||||
postCommit = func(ctx context.Context) {}
|
postCommit = func(ctx context.Context) {}
|
||||||
if isSystemUser(ctx) {
|
if isSystemUser(ctx) {
|
||||||
|
@@ -627,3 +627,84 @@ func TestCommands_applicationCreatedMilestone(t *testing.T) {
|
|||||||
func (c *Commands) setMilestonesCompletedForTest(instanceID string) {
|
func (c *Commands) setMilestonesCompletedForTest(instanceID string) {
|
||||||
c.milestonesCompleted.Store(instanceID, struct{}{})
|
c.milestonesCompleted.Store(instanceID, struct{}{})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestSAMLSessionEvents_SetMilestones(t *testing.T) {
|
||||||
|
ctx := authz.WithInstanceID(context.Background(), "instanceID")
|
||||||
|
aggregate := milestone.NewAggregate(ctx)
|
||||||
|
|
||||||
|
type fields struct {
|
||||||
|
eventstore func(*testing.T) *eventstore.Eventstore
|
||||||
|
}
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
fields fields
|
||||||
|
wantEvents []eventstore.Command
|
||||||
|
wantErr error
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "get error",
|
||||||
|
fields: fields{
|
||||||
|
eventstore: expectEventstore(
|
||||||
|
expectFilterError(io.ErrClosedPipe),
|
||||||
|
),
|
||||||
|
},
|
||||||
|
wantErr: io.ErrClosedPipe,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "auth on instance, auth on application",
|
||||||
|
fields: fields{
|
||||||
|
eventstore: expectEventstore(
|
||||||
|
expectFilter(),
|
||||||
|
),
|
||||||
|
},
|
||||||
|
wantEvents: []eventstore.Command{
|
||||||
|
milestone.NewReachedEvent(ctx, aggregate, milestone.AuthenticationSucceededOnInstance),
|
||||||
|
milestone.NewReachedEvent(ctx, aggregate, milestone.AuthenticationSucceededOnApplication),
|
||||||
|
},
|
||||||
|
wantErr: nil,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "auth on app with a previous auth on instance",
|
||||||
|
fields: fields{
|
||||||
|
eventstore: expectEventstore(
|
||||||
|
expectFilter(
|
||||||
|
eventFromEventPusher(milestone.NewReachedEvent(ctx, aggregate, milestone.AuthenticationSucceededOnInstance)),
|
||||||
|
),
|
||||||
|
),
|
||||||
|
},
|
||||||
|
wantEvents: []eventstore.Command{
|
||||||
|
milestone.NewReachedEvent(ctx, aggregate, milestone.AuthenticationSucceededOnApplication),
|
||||||
|
},
|
||||||
|
wantErr: nil,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "milestones already reached",
|
||||||
|
fields: fields{
|
||||||
|
eventstore: expectEventstore(
|
||||||
|
expectFilter(
|
||||||
|
eventFromEventPusher(milestone.NewReachedEvent(ctx, aggregate, milestone.AuthenticationSucceededOnInstance)),
|
||||||
|
eventFromEventPusher(milestone.NewReachedEvent(ctx, aggregate, milestone.AuthenticationSucceededOnApplication)),
|
||||||
|
),
|
||||||
|
),
|
||||||
|
},
|
||||||
|
wantErr: nil,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
c := &Commands{
|
||||||
|
eventstore: tt.fields.eventstore(t),
|
||||||
|
caches: &Caches{
|
||||||
|
milestones: noop.NewCache[milestoneIndex, string, *MilestonesReached](),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
s := &SAMLSessionEvents{
|
||||||
|
commands: c,
|
||||||
|
}
|
||||||
|
postCommit, err := s.SetMilestones(ctx)
|
||||||
|
postCommit(ctx)
|
||||||
|
require.ErrorIs(t, err, tt.wantErr)
|
||||||
|
assert.Equal(t, tt.wantEvents, s.events)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
@@ -80,7 +80,15 @@ func (c *Commands) CreateSAMLSessionFromSAMLRequest(ctx context.Context, samlReq
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
cmd.SetSAMLRequestSuccessful(ctx, samlReqModel.aggregate)
|
cmd.SetSAMLRequestSuccessful(ctx, samlReqModel.aggregate)
|
||||||
|
postCommit, err := cmd.SetMilestones(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
_, err = cmd.PushEvents(ctx)
|
_, err = cmd.PushEvents(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
postCommit(ctx)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@@ -334,6 +334,7 @@ func TestCommands_CreateSAMLSessionFromSAMLRequest(t *testing.T) {
|
|||||||
idGenerator: tt.fields.idGenerator,
|
idGenerator: tt.fields.idGenerator,
|
||||||
keyAlgorithm: tt.fields.keyAlgorithm,
|
keyAlgorithm: tt.fields.keyAlgorithm,
|
||||||
}
|
}
|
||||||
|
c.setMilestonesCompletedForTest("instanceID")
|
||||||
err := c.CreateSAMLSessionFromSAMLRequest(tt.args.ctx, tt.args.samlRequestID, tt.args.complianceCheck, tt.args.samlResponseID, tt.args.samlResponseLifetime)
|
err := c.CreateSAMLSessionFromSAMLRequest(tt.args.ctx, tt.args.samlRequestID, tt.args.complianceCheck, tt.args.samlResponseID, tt.args.samlResponseLifetime)
|
||||||
require.ErrorIs(t, err, tt.res.err)
|
require.ErrorIs(t, err, tt.res.err)
|
||||||
})
|
})
|
||||||
|
Reference in New Issue
Block a user