From ed697bbd69b7e9596e9cd53d8f37aad09403d87a Mon Sep 17 00:00:00 2001 From: Livio Spring Date: Tue, 11 Mar 2025 15:19:09 +0100 Subject: [PATCH] fix(OIDC): back channel logout work for custom UI (#9487) # Which Problems Are Solved When using a custom / new login UI and an OIDC application with registered BackChannelLogoutUI, no logout requests were sent to the URI when the user signed out. Additionally, as described in #9427, an error was logged: `level=error msg="event of type *session.TerminateEvent doesn't implement OriginEvent" caller="/home/runner/work/zitadel/zitadel/internal/notification/handlers/origin.go:24"` # How the Problems Are Solved - Properly pass `TriggerOrigin` information to session.TerminateEvent creation and implement `OriginEvent` interface. - Implemented `RegisterLogout` in `CreateOIDCSessionFromAuthRequest` and `CreateOIDCSessionFromDeviceAuth`, both used when interacting with the OIDC v2 API. - Both functions now receive the `BackChannelLogoutURI` of the client from the OIDC layer. # Additional Changes None # Additional Context - closes #9427 --- internal/api/oidc/auth_request.go | 1 + internal/api/oidc/token_code.go | 1 + internal/api/oidc/token_device.go | 2 +- internal/command/device_auth.go | 3 +- internal/command/device_auth_test.go | 126 +++++++++++++++++++- internal/command/oidc_session.go | 9 +- internal/command/oidc_session_test.go | 156 ++++++++++++++++++++++++- internal/repository/session/session.go | 7 +- 8 files changed, 293 insertions(+), 12 deletions(-) diff --git a/internal/api/oidc/auth_request.go b/internal/api/oidc/auth_request.go index 793001045c..d433603cd8 100644 --- a/internal/api/oidc/auth_request.go +++ b/internal/api/oidc/auth_request.go @@ -568,6 +568,7 @@ func (s *Server) CreateTokenCallbackURL(ctx context.Context, req op.AuthRequest) req.GetID(), implicitFlowComplianceChecker(), slices.Contains(client.GrantTypes(), oidc.GrantTypeRefreshToken), + client.client.BackChannelLogoutURI, ) if err != nil { return "", err diff --git a/internal/api/oidc/token_code.go b/internal/api/oidc/token_code.go index ee3585be69..033f2453b9 100644 --- a/internal/api/oidc/token_code.go +++ b/internal/api/oidc/token_code.go @@ -41,6 +41,7 @@ func (s *Server) CodeExchange(ctx context.Context, r *op.ClientRequest[oidc.Acce plainCode, codeExchangeComplianceChecker(client, r.Data), slices.Contains(client.GrantTypes(), oidc.GrantTypeRefreshToken), + client.client.BackChannelLogoutURI, ) } else { session, err = s.codeExchangeV1(ctx, client, r.Data, r.Data.Code) diff --git a/internal/api/oidc/token_device.go b/internal/api/oidc/token_device.go index 8e0f8dc993..8f42bb3ac4 100644 --- a/internal/api/oidc/token_device.go +++ b/internal/api/oidc/token_device.go @@ -25,7 +25,7 @@ func (s *Server) DeviceToken(ctx context.Context, r *op.ClientRequest[oidc.Devic if !ok { return nil, zerrors.ThrowInternal(nil, "OIDC-Ae2ph", "Error.Internal") } - session, err := s.command.CreateOIDCSessionFromDeviceAuth(ctx, r.Data.DeviceCode) + session, err := s.command.CreateOIDCSessionFromDeviceAuth(ctx, r.Data.DeviceCode, client.client.BackChannelLogoutURI) if err == nil { return response(s.accessTokenResponseFromSession(ctx, client, session, "", client.client.ProjectID, client.client.ProjectRoleAssertion, client.client.AccessTokenRoleAssertion, client.client.IDTokenRoleAssertion, client.client.IDTokenUserinfoAssertion)) } diff --git a/internal/command/device_auth.go b/internal/command/device_auth.go index d3588660be..ef6b069cc9 100644 --- a/internal/command/device_auth.go +++ b/internal/command/device_auth.go @@ -174,7 +174,7 @@ func (e DeviceAuthStateError) Error() string { // As devices can poll at various intervals, an explicit state takes precedence over expiry. // This is to prevent cases where users might approve or deny the authorization on time, but the next poll // happens after expiry. -func (c *Commands) CreateOIDCSessionFromDeviceAuth(ctx context.Context, deviceCode string) (_ *OIDCSession, err error) { +func (c *Commands) CreateOIDCSessionFromDeviceAuth(ctx context.Context, deviceCode, backChannelLogoutURI string) (_ *OIDCSession, err error) { ctx, span := tracing.NewSpan(ctx) defer func() { span.EndWithError(err) }() @@ -219,6 +219,7 @@ func (c *Commands) CreateOIDCSessionFromDeviceAuth(ctx context.Context, deviceCo deviceAuthModel.PreferredLanguage, deviceAuthModel.UserAgent, ) + cmd.RegisterLogout(ctx, deviceAuthModel.SessionID, deviceAuthModel.UserID, deviceAuthModel.ClientID, backChannelLogoutURI) if err = cmd.AddAccessToken(ctx, deviceAuthModel.Scopes, deviceAuthModel.UserID, deviceAuthModel.UserOrgID, domain.TokenReasonAuthRequest, nil); err != nil { return nil, err } diff --git a/internal/command/device_auth_test.go b/internal/command/device_auth_test.go index 508ca10571..021ae25d36 100644 --- a/internal/command/device_auth_test.go +++ b/internal/command/device_auth_test.go @@ -19,11 +19,13 @@ import ( "github.com/zitadel/zitadel/internal/crypto" "github.com/zitadel/zitadel/internal/domain" "github.com/zitadel/zitadel/internal/eventstore" + "github.com/zitadel/zitadel/internal/feature" "github.com/zitadel/zitadel/internal/id" "github.com/zitadel/zitadel/internal/id/mock" "github.com/zitadel/zitadel/internal/repository/deviceauth" "github.com/zitadel/zitadel/internal/repository/oidcsession" "github.com/zitadel/zitadel/internal/repository/session" + "github.com/zitadel/zitadel/internal/repository/sessionlogout" "github.com/zitadel/zitadel/internal/repository/user" "github.com/zitadel/zitadel/internal/zerrors" ) @@ -704,8 +706,9 @@ func TestCommands_CreateOIDCSessionFromDeviceAuth(t *testing.T) { keyAlgorithm crypto.EncryptionAlgorithm } type args struct { - ctx context.Context - deviceCode string + ctx context.Context + deviceCode string + backChannelLogoutURI string } tests := []struct { name string @@ -724,6 +727,7 @@ func TestCommands_CreateOIDCSessionFromDeviceAuth(t *testing.T) { args: args{ ctx, "device1", + "", }, wantErr: io.ErrClosedPipe, }, @@ -748,6 +752,7 @@ func TestCommands_CreateOIDCSessionFromDeviceAuth(t *testing.T) { args: args{ ctx, "123", + "", }, wantErr: DeviceAuthStateError(domain.DeviceAuthStateInitiated), }, @@ -761,6 +766,7 @@ func TestCommands_CreateOIDCSessionFromDeviceAuth(t *testing.T) { args: args{ ctx, "123", + "", }, wantErr: zerrors.ThrowNotFound(nil, "COMMAND-ua1Vo", "Errors.DeviceAuth.NotFound"), }, @@ -789,6 +795,7 @@ func TestCommands_CreateOIDCSessionFromDeviceAuth(t *testing.T) { args: args{ ctx, "123", + "", }, wantErr: DeviceAuthStateError(domain.DeviceAuthStateExpired), }, @@ -820,6 +827,7 @@ func TestCommands_CreateOIDCSessionFromDeviceAuth(t *testing.T) { args: args{ ctx, "123", + "", }, wantErr: DeviceAuthStateError(domain.DeviceAuthStateExpired), }, @@ -851,6 +859,7 @@ func TestCommands_CreateOIDCSessionFromDeviceAuth(t *testing.T) { args: args{ ctx, "123", + "", }, wantErr: DeviceAuthStateError(domain.DeviceAuthStateDenied), }, @@ -888,6 +897,7 @@ func TestCommands_CreateOIDCSessionFromDeviceAuth(t *testing.T) { args: args{ ctx, "123", + "", }, wantErr: DeviceAuthStateError(domain.DeviceAuthStateDone), }, @@ -951,6 +961,7 @@ func TestCommands_CreateOIDCSessionFromDeviceAuth(t *testing.T) { args: args{ ctx, "123", + "", }, wantErr: zerrors.ThrowPreconditionFailed(nil, "OIDCS-kj3g2", "Errors.User.NotActive"), }, @@ -1030,6 +1041,114 @@ func TestCommands_CreateOIDCSessionFromDeviceAuth(t *testing.T) { args: args{ ctx, "123", + "", + }, + want: &OIDCSession{ + TokenID: "V2_oidcSessionID-at_accessTokenID", + ClientID: "clientID", + UserID: "userID", + Audience: []string{"audience"}, + Expiration: time.Time{}.Add(time.Hour), + Scope: []string{"openid", "offline_access"}, + AuthMethods: []domain.UserAuthMethodType{domain.UserAuthMethodTypePassword}, + AuthTime: testNow, + PreferredLanguage: &language.Afrikaans, + UserAgent: &domain.UserAgent{ + FingerprintID: gu.Ptr("fp1"), + IP: net.ParseIP("1.2.3.4"), + Description: gu.Ptr("firefox"), + Header: http.Header{"foo": []string{"bar"}}, + }, + Reason: domain.TokenReasonAuthRequest, + SessionID: "sessionID", + }, + }, + { + name: "approved with backChannelLogout (feature enabled), success", + fields: fields{ + eventstore: expectEventstore( + expectFilter( + eventFromEventPusherWithInstanceID( + "instance1", + deviceauth.NewAddedEvent( + ctx, + deviceauth.NewAggregate("123", "instance1"), + "clientID", "123", "456", time.Now().Add(-time.Minute), + []string{"openid", "offline_access"}, + []string{"audience"}, false, + ), + ), + eventFromEventPusherWithInstanceID( + "instance1", + deviceauth.NewApprovedEvent(ctx, + deviceauth.NewAggregate("123", "instance1"), + "userID", "org1", + []domain.UserAuthMethodType{domain.UserAuthMethodTypePassword}, + testNow, &language.Afrikaans, &domain.UserAgent{ + FingerprintID: gu.Ptr("fp1"), + IP: net.ParseIP("1.2.3.4"), + Description: gu.Ptr("firefox"), + Header: http.Header{"foo": []string{"bar"}}, + }, + "sessionID", + ), + ), + ), + expectFilter( + user.NewHumanAddedEvent( + ctx, + &user.NewAggregate("userID", "org1").Aggregate, + "username", + "firstname", + "lastname", + "nickname", + "displayname", + language.English, + domain.GenderUnspecified, + "email", + false, + ), + ), + expectFilter(), // token lifetime + expectPush( + oidcsession.NewAddedEvent(context.Background(), &oidcsession.NewAggregate("V2_oidcSessionID", "org1").Aggregate, + "userID", "org1", "sessionID", "clientID", []string{"audience"}, []string{"openid", "offline_access"}, + []domain.UserAuthMethodType{domain.UserAuthMethodTypePassword}, testNow, "", &language.Afrikaans, &domain.UserAgent{ + FingerprintID: gu.Ptr("fp1"), + IP: net.ParseIP("1.2.3.4"), + Description: gu.Ptr("firefox"), + Header: http.Header{"foo": []string{"bar"}}, + }, + ), + sessionlogout.NewBackChannelLogoutRegisteredEvent(context.Background(), + &sessionlogout.NewAggregate("sessionID", "instance1").Aggregate, + "V2_oidcSessionID", + "userID", + "clientID", + "backChannelLogoutURI", + ), + oidcsession.NewAccessTokenAddedEvent(context.Background(), + &oidcsession.NewAggregate("V2_oidcSessionID", "org1").Aggregate, + "at_accessTokenID", []string{"openid", "offline_access"}, time.Hour, domain.TokenReasonAuthRequest, nil, + ), + user.NewUserTokenV2AddedEvent(context.Background(), &user.NewAggregate("userID", "org1").Aggregate, "at_accessTokenID"), + deviceauth.NewDoneEvent(ctx, + deviceauth.NewAggregate("123", "instance1"), + ), + ), + ), + idGenerator: mock.NewIDGeneratorExpectIDs(t, "oidcSessionID", "accessTokenID"), + defaultAccessTokenLifetime: time.Hour, + defaultRefreshTokenLifetime: 7 * 24 * time.Hour, + defaultRefreshTokenIdleLifetime: 24 * time.Hour, + keyAlgorithm: crypto.CreateMockEncryptionAlg(gomock.NewController(t)), + }, + args: args{ + authz.WithFeatures(ctx, feature.Features{ + EnableBackChannelLogout: true, + }), + "123", + "backChannelLogoutURI", }, want: &OIDCSession{ TokenID: "V2_oidcSessionID-at_accessTokenID", @@ -1130,6 +1249,7 @@ func TestCommands_CreateOIDCSessionFromDeviceAuth(t *testing.T) { args: args{ ctx, "123", + "", }, want: &OIDCSession{ TokenID: "V2_oidcSessionID-at_accessTokenID", @@ -1163,7 +1283,7 @@ func TestCommands_CreateOIDCSessionFromDeviceAuth(t *testing.T) { defaultRefreshTokenIdleLifetime: tt.fields.defaultRefreshTokenIdleLifetime, keyAlgorithm: tt.fields.keyAlgorithm, } - got, err := c.CreateOIDCSessionFromDeviceAuth(tt.args.ctx, tt.args.deviceCode) + got, err := c.CreateOIDCSessionFromDeviceAuth(tt.args.ctx, tt.args.deviceCode, tt.args.backChannelLogoutURI) c.jobs.Wait() require.ErrorIs(t, err, tt.wantErr) diff --git a/internal/command/oidc_session.go b/internal/command/oidc_session.go index bea17986ea..492d89bc2d 100644 --- a/internal/command/oidc_session.go +++ b/internal/command/oidc_session.go @@ -55,7 +55,13 @@ type AuthRequestComplianceChecker func(context.Context, *AuthRequestWriteModel) // CreateOIDCSessionFromAuthRequest creates a new OIDC Session, creates an access token and refresh token. // It returns the access token id, expiration and the refresh token. // If the underlying [AuthRequest] is a OIDC Auth Code Flow, it will set the code as exchanged. -func (c *Commands) CreateOIDCSessionFromAuthRequest(ctx context.Context, authReqId string, complianceCheck AuthRequestComplianceChecker, needRefreshToken bool) (session *OIDCSession, state string, err error) { +func (c *Commands) CreateOIDCSessionFromAuthRequest( + ctx context.Context, + authReqId string, + complianceCheck AuthRequestComplianceChecker, + needRefreshToken bool, + backChannelLogoutURI string, +) (session *OIDCSession, state string, err error) { ctx, span := tracing.NewSpan(ctx) defer func() { span.EndWithError(err) }() @@ -108,6 +114,7 @@ func (c *Commands) CreateOIDCSessionFromAuthRequest(ctx context.Context, authReq sessionModel.PreferredLanguage, sessionModel.UserAgent, ) + cmd.RegisterLogout(ctx, sessionModel.AggregateID, sessionModel.UserID, authReqModel.ClientID, backChannelLogoutURI) if authReqModel.ResponseType != domain.OIDCResponseTypeIDToken { if err = cmd.AddAccessToken(ctx, authReqModel.Scope, sessionModel.UserID, sessionModel.UserResourceOwner, domain.TokenReasonAuthRequest, nil); err != nil { diff --git a/internal/command/oidc_session_test.go b/internal/command/oidc_session_test.go index 18a115eb00..af1874a6bb 100644 --- a/internal/command/oidc_session_test.go +++ b/internal/command/oidc_session_test.go @@ -49,10 +49,11 @@ func TestCommands_CreateOIDCSessionFromAuthRequest(t *testing.T) { keyAlgorithm crypto.EncryptionAlgorithm } type args struct { - ctx context.Context - authRequestID string - complianceCheck AuthRequestComplianceChecker - needRefreshToken bool + ctx context.Context + authRequestID string + complianceCheck AuthRequestComplianceChecker + needRefreshToken bool + backChannelLogoutURI string } type res struct { session *OIDCSession @@ -438,6 +439,151 @@ func TestCommands_CreateOIDCSessionFromAuthRequest(t *testing.T) { state: "state", }, }, + { + "add successful, backChannelLogout (feature enabled)", + fields{ + eventstore: expectEventstore( + expectFilter( + eventFromEventPusher( + authrequest.NewAddedEvent(context.Background(), &authrequest.NewAggregate("V2_authRequestID", "instanceID").Aggregate, + "loginClient", + "clientID", + "redirectURI", + "state", + "nonce", + []string{"openid", "offline_access"}, + []string{"audience"}, + domain.OIDCResponseTypeCode, + domain.OIDCResponseModeQuery, + &domain.OIDCCodeChallenge{ + Challenge: "challenge", + Method: domain.CodeChallengeMethodS256, + }, + []domain.Prompt{domain.PromptNone}, + []string{"en", "de"}, + gu.Ptr(time.Duration(0)), + gu.Ptr("loginHint"), + gu.Ptr("hintUserID"), + true, + ), + ), + eventFromEventPusher( + authrequest.NewCodeAddedEvent(context.Background(), &authrequest.NewAggregate("V2_authRequestID", "instanceID").Aggregate), + ), + eventFromEventPusher( + authrequest.NewSessionLinkedEvent(context.Background(), &authrequest.NewAggregate("V2_authRequestID", "instanceID").Aggregate, + "sessionID", + "userID", + testNow, + []domain.UserAuthMethodType{domain.UserAuthMethodTypePassword}, + ), + ), + ), + expectFilter( + eventFromEventPusher( + session.NewAddedEvent(context.Background(), + &session.NewAggregate("sessionID", "instance1").Aggregate, + &domain.UserAgent{ + FingerprintID: gu.Ptr("fp1"), + IP: net.ParseIP("1.2.3.4"), + Description: gu.Ptr("firefox"), + Header: http.Header{"foo": []string{"bar"}}, + }, + ), + ), + eventFromEventPusher( + session.NewUserCheckedEvent(context.Background(), &session.NewAggregate("sessionID", "instanceID").Aggregate, + "userID", "org1", testNow, &language.Afrikaans), + ), + eventFromEventPusher( + session.NewPasswordCheckedEvent(context.Background(), &session.NewAggregate("sessionID", "instanceID").Aggregate, + testNow), + ), + ), + expectFilter( + user.NewHumanAddedEvent( + context.Background(), + &user.NewAggregate("userID", "org1").Aggregate, + "username", + "firstname", + "lastname", + "nickname", + "displayname", + language.Afrikaans, + domain.GenderUnspecified, + "email", + false, + ), + ), + expectFilter(), // token lifetime + expectPush( + authrequest.NewCodeExchangedEvent(context.Background(), &authrequest.NewAggregate("V2_authRequestID", "instanceID").Aggregate), + oidcsession.NewAddedEvent(context.Background(), &oidcsession.NewAggregate("V2_oidcSessionID", "org1").Aggregate, + "userID", "org1", "sessionID", "clientID", []string{"audience"}, []string{"openid", "offline_access"}, + []domain.UserAuthMethodType{domain.UserAuthMethodTypePassword}, testNow, "nonce", &language.Afrikaans, + &domain.UserAgent{ + FingerprintID: gu.Ptr("fp1"), + IP: net.ParseIP("1.2.3.4"), + Description: gu.Ptr("firefox"), + Header: http.Header{"foo": []string{"bar"}}, + }, + ), + sessionlogout.NewBackChannelLogoutRegisteredEvent(context.Background(), + &sessionlogout.NewAggregate("sessionID", "instanceID").Aggregate, + "V2_oidcSessionID", + "userID", + "clientID", + "backChannelLogoutURI", + ), + oidcsession.NewAccessTokenAddedEvent(context.Background(), &oidcsession.NewAggregate("V2_oidcSessionID", "org1").Aggregate, + "at_accessTokenID", []string{"openid", "offline_access"}, time.Hour, domain.TokenReasonAuthRequest, nil), + user.NewUserTokenV2AddedEvent(context.Background(), &user.NewAggregate("userID", "org1").Aggregate, "at_accessTokenID"), + oidcsession.NewRefreshTokenAddedEvent(context.Background(), &oidcsession.NewAggregate("V2_oidcSessionID", "org1").Aggregate, + "rt_refreshTokenID", 7*24*time.Hour, 24*time.Hour), + authrequest.NewSucceededEvent(context.Background(), &authrequest.NewAggregate("V2_authRequestID", "instanceID").Aggregate), + ), + ), + idGenerator: mock.NewIDGeneratorExpectIDs(t, "oidcSessionID", "accessTokenID", "refreshTokenID"), + defaultAccessTokenLifetime: time.Hour, + defaultRefreshTokenLifetime: 7 * 24 * time.Hour, + defaultRefreshTokenIdleLifetime: 24 * time.Hour, + keyAlgorithm: crypto.CreateMockEncryptionAlg(gomock.NewController(t)), + }, + args{ + + ctx: authz.WithFeatures(authz.WithInstanceID(context.Background(), "instanceID"), feature.Features{ + EnableBackChannelLogout: true, + }), + authRequestID: "V2_authRequestID", + complianceCheck: mockAuthRequestComplianceChecker(nil), + needRefreshToken: true, + backChannelLogoutURI: "backChannelLogoutURI", + }, + res{ + session: &OIDCSession{ + SessionID: "sessionID", + TokenID: "V2_oidcSessionID-at_accessTokenID", + ClientID: "clientID", + UserID: "userID", + Audience: []string{"audience"}, + Expiration: time.Time{}.Add(time.Hour), + Scope: []string{"openid", "offline_access"}, + AuthMethods: []domain.UserAuthMethodType{domain.UserAuthMethodTypePassword}, + AuthTime: testNow, + Nonce: "nonce", + PreferredLanguage: &language.Afrikaans, + UserAgent: &domain.UserAgent{ + FingerprintID: gu.Ptr("fp1"), + IP: net.ParseIP("1.2.3.4"), + Description: gu.Ptr("firefox"), + Header: http.Header{"foo": []string{"bar"}}, + }, + Reason: domain.TokenReasonAuthRequest, + RefreshToken: "VjJfb2lkY1Nlc3Npb25JRC1ydF9yZWZyZXNoVG9rZW5JRDp1c2VySUQ", //V2_oidcSessionID-rt_refreshTokenID:userID + }, + state: "state", + }, + }, { "disable user token event", fields{ @@ -708,7 +854,7 @@ func TestCommands_CreateOIDCSessionFromAuthRequest(t *testing.T) { keyAlgorithm: tt.fields.keyAlgorithm, } c.setMilestonesCompletedForTest("instanceID") - gotSession, gotState, err := c.CreateOIDCSessionFromAuthRequest(tt.args.ctx, tt.args.authRequestID, tt.args.complianceCheck, tt.args.needRefreshToken) + gotSession, gotState, err := c.CreateOIDCSessionFromAuthRequest(tt.args.ctx, tt.args.authRequestID, tt.args.complianceCheck, tt.args.needRefreshToken, tt.args.backChannelLogoutURI) require.ErrorIs(t, err, tt.res.err) if gotSession != nil { diff --git a/internal/repository/session/session.go b/internal/repository/session/session.go index 42304aca8e..7aad348841 100644 --- a/internal/repository/session/session.go +++ b/internal/repository/session/session.go @@ -660,7 +660,7 @@ func NewLifetimeSetEvent( type TerminateEvent struct { eventstore.BaseEvent `json:"-"` - TriggerOrigin string `json:"triggerOrigin,omitempty"` + TriggeredAtOrigin string `json:"triggerOrigin,omitempty"` } func (e *TerminateEvent) Payload() interface{} { @@ -671,6 +671,10 @@ func (e *TerminateEvent) UniqueConstraints() []*eventstore.UniqueConstraint { return nil } +func (e *TerminateEvent) TriggerOrigin() string { + return e.TriggeredAtOrigin +} + func NewTerminateEvent( ctx context.Context, aggregate *eventstore.Aggregate, @@ -681,6 +685,7 @@ func NewTerminateEvent( aggregate, TerminateType, ), + TriggeredAtOrigin: http.DomainContext(ctx).Origin(), } }