fix(saml): Push AuthenticationSucceededOnApplication milestone for SAML sessions (#10263)

# Which Problems Are Solved

The SAML session (v2 login) currently does not push a
`AuthenticationSucceededOnApplication` milestone upon successful SAML
login for the first time. The changes in this PR address this issue.

# How the Problems Are Solved

Add a new function to set the appropriate milestone, and call this
function after a successful SAML request.

# Additional Changes

N/A

# Additional Context

- Closes #9592

---------

Co-authored-by: Stefan Benz <46600784+stebenz@users.noreply.github.com>
This commit is contained in:
Gayathri Vijayan
2025-07-15 18:03:47 +02:00
committed by GitHub
parent e1f112d59b
commit 6d11145c77
4 changed files with 114 additions and 0 deletions

View File

@@ -133,6 +133,30 @@ func (s *OIDCSessionEvents) SetMilestones(ctx context.Context, clientID string,
return postCommit, nil
}
func (s *SAMLSessionEvents) SetMilestones(ctx context.Context) (postCommit func(ctx context.Context), err error) {
postCommit = func(ctx context.Context) {}
milestones, err := s.commands.GetMilestonesReached(ctx)
if err != nil {
return postCommit, err
}
instance := authz.GetInstance(ctx)
aggregate := milestone.NewAggregate(ctx)
var invalidate bool
if !milestones.AuthenticationSucceededOnInstance {
s.events = append(s.events, milestone.NewReachedEvent(ctx, aggregate, milestone.AuthenticationSucceededOnInstance))
invalidate = true
}
if !milestones.AuthenticationSucceededOnApplication {
s.events = append(s.events, milestone.NewReachedEvent(ctx, aggregate, milestone.AuthenticationSucceededOnApplication))
invalidate = true
}
if invalidate {
postCommit = s.commands.invalidateMilestoneCachePostCommit(instance.InstanceID())
}
return postCommit, nil
}
func (c *Commands) projectCreatedMilestone(ctx context.Context, cmds *[]eventstore.Command) (postCommit func(ctx context.Context), err error) {
postCommit = func(ctx context.Context) {}
if isSystemUser(ctx) {

View File

@@ -627,3 +627,84 @@ func TestCommands_applicationCreatedMilestone(t *testing.T) {
func (c *Commands) setMilestonesCompletedForTest(instanceID string) {
c.milestonesCompleted.Store(instanceID, struct{}{})
}
func TestSAMLSessionEvents_SetMilestones(t *testing.T) {
ctx := authz.WithInstanceID(context.Background(), "instanceID")
aggregate := milestone.NewAggregate(ctx)
type fields struct {
eventstore func(*testing.T) *eventstore.Eventstore
}
tests := []struct {
name string
fields fields
wantEvents []eventstore.Command
wantErr error
}{
{
name: "get error",
fields: fields{
eventstore: expectEventstore(
expectFilterError(io.ErrClosedPipe),
),
},
wantErr: io.ErrClosedPipe,
},
{
name: "auth on instance, auth on application",
fields: fields{
eventstore: expectEventstore(
expectFilter(),
),
},
wantEvents: []eventstore.Command{
milestone.NewReachedEvent(ctx, aggregate, milestone.AuthenticationSucceededOnInstance),
milestone.NewReachedEvent(ctx, aggregate, milestone.AuthenticationSucceededOnApplication),
},
wantErr: nil,
},
{
name: "auth on app with a previous auth on instance",
fields: fields{
eventstore: expectEventstore(
expectFilter(
eventFromEventPusher(milestone.NewReachedEvent(ctx, aggregate, milestone.AuthenticationSucceededOnInstance)),
),
),
},
wantEvents: []eventstore.Command{
milestone.NewReachedEvent(ctx, aggregate, milestone.AuthenticationSucceededOnApplication),
},
wantErr: nil,
},
{
name: "milestones already reached",
fields: fields{
eventstore: expectEventstore(
expectFilter(
eventFromEventPusher(milestone.NewReachedEvent(ctx, aggregate, milestone.AuthenticationSucceededOnInstance)),
eventFromEventPusher(milestone.NewReachedEvent(ctx, aggregate, milestone.AuthenticationSucceededOnApplication)),
),
),
},
wantErr: nil,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
c := &Commands{
eventstore: tt.fields.eventstore(t),
caches: &Caches{
milestones: noop.NewCache[milestoneIndex, string, *MilestonesReached](),
},
}
s := &SAMLSessionEvents{
commands: c,
}
postCommit, err := s.SetMilestones(ctx)
postCommit(ctx)
require.ErrorIs(t, err, tt.wantErr)
assert.Equal(t, tt.wantEvents, s.events)
})
}
}

View File

@@ -80,7 +80,15 @@ func (c *Commands) CreateSAMLSessionFromSAMLRequest(ctx context.Context, samlReq
return err
}
cmd.SetSAMLRequestSuccessful(ctx, samlReqModel.aggregate)
postCommit, err := cmd.SetMilestones(ctx)
if err != nil {
return err
}
_, err = cmd.PushEvents(ctx)
if err != nil {
return err
}
postCommit(ctx)
return err
}

View File

@@ -334,6 +334,7 @@ func TestCommands_CreateSAMLSessionFromSAMLRequest(t *testing.T) {
idGenerator: tt.fields.idGenerator,
keyAlgorithm: tt.fields.keyAlgorithm,
}
c.setMilestonesCompletedForTest("instanceID")
err := c.CreateSAMLSessionFromSAMLRequest(tt.args.ctx, tt.args.samlRequestID, tt.args.complianceCheck, tt.args.samlResponseID, tt.args.samlResponseLifetime)
require.ErrorIs(t, err, tt.res.err)
})