From 6d11145c779034789083d592051128d9b75916e8 Mon Sep 17 00:00:00 2001 From: Gayathri Vijayan <66356931+grvijayan@users.noreply.github.com> Date: Tue, 15 Jul 2025 18:03:47 +0200 Subject: [PATCH] fix(saml): Push AuthenticationSucceededOnApplication milestone for SAML sessions (#10263) # Which Problems Are Solved The SAML session (v2 login) currently does not push a `AuthenticationSucceededOnApplication` milestone upon successful SAML login for the first time. The changes in this PR address this issue. # How the Problems Are Solved Add a new function to set the appropriate milestone, and call this function after a successful SAML request. # Additional Changes N/A # Additional Context - Closes #9592 --------- Co-authored-by: Stefan Benz <46600784+stebenz@users.noreply.github.com> --- internal/command/milestone.go | 24 ++++++++ internal/command/milestone_test.go | 81 +++++++++++++++++++++++++++ internal/command/saml_session.go | 8 +++ internal/command/saml_session_test.go | 1 + 4 files changed, 114 insertions(+) diff --git a/internal/command/milestone.go b/internal/command/milestone.go index e2f4fdc9de..9ef3393325 100644 --- a/internal/command/milestone.go +++ b/internal/command/milestone.go @@ -133,6 +133,30 @@ func (s *OIDCSessionEvents) SetMilestones(ctx context.Context, clientID string, return postCommit, nil } +func (s *SAMLSessionEvents) SetMilestones(ctx context.Context) (postCommit func(ctx context.Context), err error) { + postCommit = func(ctx context.Context) {} + milestones, err := s.commands.GetMilestonesReached(ctx) + if err != nil { + return postCommit, err + } + + instance := authz.GetInstance(ctx) + aggregate := milestone.NewAggregate(ctx) + var invalidate bool + if !milestones.AuthenticationSucceededOnInstance { + s.events = append(s.events, milestone.NewReachedEvent(ctx, aggregate, milestone.AuthenticationSucceededOnInstance)) + invalidate = true + } + if !milestones.AuthenticationSucceededOnApplication { + s.events = append(s.events, milestone.NewReachedEvent(ctx, aggregate, milestone.AuthenticationSucceededOnApplication)) + invalidate = true + } + if invalidate { + postCommit = s.commands.invalidateMilestoneCachePostCommit(instance.InstanceID()) + } + return postCommit, nil +} + func (c *Commands) projectCreatedMilestone(ctx context.Context, cmds *[]eventstore.Command) (postCommit func(ctx context.Context), err error) { postCommit = func(ctx context.Context) {} if isSystemUser(ctx) { diff --git a/internal/command/milestone_test.go b/internal/command/milestone_test.go index 3c4bffc704..d70f12f9bc 100644 --- a/internal/command/milestone_test.go +++ b/internal/command/milestone_test.go @@ -627,3 +627,84 @@ func TestCommands_applicationCreatedMilestone(t *testing.T) { func (c *Commands) setMilestonesCompletedForTest(instanceID string) { c.milestonesCompleted.Store(instanceID, struct{}{}) } + +func TestSAMLSessionEvents_SetMilestones(t *testing.T) { + ctx := authz.WithInstanceID(context.Background(), "instanceID") + aggregate := milestone.NewAggregate(ctx) + + type fields struct { + eventstore func(*testing.T) *eventstore.Eventstore + } + tests := []struct { + name string + fields fields + wantEvents []eventstore.Command + wantErr error + }{ + { + name: "get error", + fields: fields{ + eventstore: expectEventstore( + expectFilterError(io.ErrClosedPipe), + ), + }, + wantErr: io.ErrClosedPipe, + }, + { + name: "auth on instance, auth on application", + fields: fields{ + eventstore: expectEventstore( + expectFilter(), + ), + }, + wantEvents: []eventstore.Command{ + milestone.NewReachedEvent(ctx, aggregate, milestone.AuthenticationSucceededOnInstance), + milestone.NewReachedEvent(ctx, aggregate, milestone.AuthenticationSucceededOnApplication), + }, + wantErr: nil, + }, + { + name: "auth on app with a previous auth on instance", + fields: fields{ + eventstore: expectEventstore( + expectFilter( + eventFromEventPusher(milestone.NewReachedEvent(ctx, aggregate, milestone.AuthenticationSucceededOnInstance)), + ), + ), + }, + wantEvents: []eventstore.Command{ + milestone.NewReachedEvent(ctx, aggregate, milestone.AuthenticationSucceededOnApplication), + }, + wantErr: nil, + }, + { + name: "milestones already reached", + fields: fields{ + eventstore: expectEventstore( + expectFilter( + eventFromEventPusher(milestone.NewReachedEvent(ctx, aggregate, milestone.AuthenticationSucceededOnInstance)), + eventFromEventPusher(milestone.NewReachedEvent(ctx, aggregate, milestone.AuthenticationSucceededOnApplication)), + ), + ), + }, + wantErr: nil, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + c := &Commands{ + eventstore: tt.fields.eventstore(t), + caches: &Caches{ + milestones: noop.NewCache[milestoneIndex, string, *MilestonesReached](), + }, + } + s := &SAMLSessionEvents{ + commands: c, + } + postCommit, err := s.SetMilestones(ctx) + postCommit(ctx) + require.ErrorIs(t, err, tt.wantErr) + assert.Equal(t, tt.wantEvents, s.events) + }) + } +} diff --git a/internal/command/saml_session.go b/internal/command/saml_session.go index 6e0c37af9e..6329f35c5f 100644 --- a/internal/command/saml_session.go +++ b/internal/command/saml_session.go @@ -80,7 +80,15 @@ func (c *Commands) CreateSAMLSessionFromSAMLRequest(ctx context.Context, samlReq return err } cmd.SetSAMLRequestSuccessful(ctx, samlReqModel.aggregate) + postCommit, err := cmd.SetMilestones(ctx) + if err != nil { + return err + } _, err = cmd.PushEvents(ctx) + if err != nil { + return err + } + postCommit(ctx) return err } diff --git a/internal/command/saml_session_test.go b/internal/command/saml_session_test.go index 4781381cc4..15445e9e5c 100644 --- a/internal/command/saml_session_test.go +++ b/internal/command/saml_session_test.go @@ -334,6 +334,7 @@ func TestCommands_CreateSAMLSessionFromSAMLRequest(t *testing.T) { idGenerator: tt.fields.idGenerator, keyAlgorithm: tt.fields.keyAlgorithm, } + c.setMilestonesCompletedForTest("instanceID") err := c.CreateSAMLSessionFromSAMLRequest(tt.args.ctx, tt.args.samlRequestID, tt.args.complianceCheck, tt.args.samlResponseID, tt.args.samlResponseLifetime) require.ErrorIs(t, err, tt.res.err) })