From 68af4f59c94783e224938c5897d66bc9af28f347 Mon Sep 17 00:00:00 2001 From: Livio Spring Date: Wed, 28 Feb 2024 10:30:05 +0100 Subject: [PATCH] fix(api): handle user disabling events correctly in session API (#7380) This PR makes sure that user disabling events (deactivate, locked, ...) are correctly checked for sessions. --- internal/api/grpc/session/v2/session.go | 3 + .../session/v2/session_integration_test.go | 40 ++++++++- .../api/oidc/auth_request_integration_test.go | 30 +++---- internal/api/oidc/client_integration_test.go | 6 +- internal/api/oidc/oidc_integration_test.go | 87 +++++++++++++++++-- .../eventsourcing/eventstore/auth_request.go | 2 +- internal/domain/user.go | 2 +- internal/query/access_token.go | 69 ++++++++++++--- 8 files changed, 193 insertions(+), 46 deletions(-) diff --git a/internal/api/grpc/session/v2/session.go b/internal/api/grpc/session/v2/session.go index 68d98e57ee..428cd94517 100644 --- a/internal/api/grpc/session/v2/session.go +++ b/internal/api/grpc/session/v2/session.go @@ -351,6 +351,9 @@ func (s *Server) checksToCommand(ctx context.Context, checks *session.Checks) ([ if err != nil { return nil, err } + if !user.State.IsEnabled() { + return nil, zerrors.ThrowPreconditionFailed(nil, "SESSION-Gj4ko", "Errors.User.NotActive") + } sessionChecks = append(sessionChecks, command.CheckUser(user.ID, user.ResourceOwner)) } if password := checks.GetPassword(); password != nil { diff --git a/internal/api/grpc/session/v2/session_integration_test.go b/internal/api/grpc/session/v2/session_integration_test.go index d1ad5b1a7f..5cdc350f88 100644 --- a/internal/api/grpc/session/v2/session_integration_test.go +++ b/internal/api/grpc/session/v2/session_integration_test.go @@ -24,10 +24,12 @@ import ( ) var ( - CTX context.Context - Tester *integration.Tester - Client session.SessionServiceClient - User *user.AddHumanUserResponse + CTX context.Context + Tester *integration.Tester + Client session.SessionServiceClient + User *user.AddHumanUserResponse + DeactivatedUser *user.AddHumanUserResponse + LockedUser *user.AddHumanUserResponse ) func TestMain(m *testing.M) { @@ -51,6 +53,10 @@ func TestMain(m *testing.M) { }) Tester.SetUserPassword(CTX, User.GetUserId(), integration.UserPassword) Tester.RegisterUserPasskey(CTX, User.GetUserId()) + DeactivatedUser = Tester.CreateHumanUser(CTX) + Tester.Client.UserV2.DeactivateUser(CTX, &user.DeactivateUserRequest{UserId: DeactivatedUser.GetUserId()}) + LockedUser = Tester.CreateHumanUser(CTX) + Tester.Client.UserV2.LockUser(CTX, &user.LockUserRequest{UserId: LockedUser.GetUserId()}) return m.Run() }()) } @@ -229,6 +235,32 @@ func TestServer_CreateSession(t *testing.T) { }, wantFactors: []wantFactor{wantUserFactor}, }, + { + name: "deactivated user", + req: &session.CreateSessionRequest{ + Checks: &session.Checks{ + User: &session.CheckUser{ + Search: &session.CheckUser_UserId{ + UserId: DeactivatedUser.GetUserId(), + }, + }, + }, + }, + wantErr: true, + }, + { + name: "locked user", + req: &session.CreateSessionRequest{ + Checks: &session.Checks{ + User: &session.CheckUser{ + Search: &session.CheckUser_UserId{ + UserId: LockedUser.GetUserId(), + }, + }, + }, + }, + wantErr: true, + }, { name: "password without user error", req: &session.CreateSessionRequest{ diff --git a/internal/api/oidc/auth_request_integration_test.go b/internal/api/oidc/auth_request_integration_test.go index 44a4b82e6b..75306fbfb7 100644 --- a/internal/api/oidc/auth_request_integration_test.go +++ b/internal/api/oidc/auth_request_integration_test.go @@ -54,7 +54,7 @@ func TestOPStorage_CreateAccessToken_code(t *testing.T) { tokens, err := exchangeTokens(t, clientID, code, redirectURI) require.NoError(t, err) assertTokens(t, tokens, false) - assertIDTokenClaims(t, tokens.IDTokenClaims, armPasskey, startTime, changeTime) + assertIDTokenClaims(t, tokens.IDTokenClaims, User.GetUserId(), armPasskey, startTime, changeTime) // callback on a succeeded request must fail linkResp, err = Tester.Client.OIDCv2.CreateCallback(CTXLOGIN, &oidc_pb.CreateCallbackRequest{ @@ -108,7 +108,7 @@ func TestOPStorage_CreateAccessToken_implicit(t *testing.T) { require.NoError(t, err) claims, err := rp.VerifyTokens[*oidc.IDTokenClaims](context.Background(), accessToken, idToken, provider.IDTokenVerifier()) require.NoError(t, err) - assertIDTokenClaims(t, claims, armPasskey, startTime, changeTime) + assertIDTokenClaims(t, claims, User.GetUserId(), armPasskey, startTime, changeTime) // callback on a succeeded request must fail linkResp, err = Tester.Client.OIDCv2.CreateCallback(CTXLOGIN, &oidc_pb.CreateCallbackRequest{ @@ -143,7 +143,7 @@ func TestOPStorage_CreateAccessAndRefreshTokens_code(t *testing.T) { tokens, err := exchangeTokens(t, clientID, code, redirectURI) require.NoError(t, err) assertTokens(t, tokens, true) - assertIDTokenClaims(t, tokens.IDTokenClaims, armPasskey, startTime, changeTime) + assertIDTokenClaims(t, tokens.IDTokenClaims, User.GetUserId(), armPasskey, startTime, changeTime) } func TestOPStorage_CreateAccessAndRefreshTokens_refresh(t *testing.T) { @@ -168,14 +168,14 @@ func TestOPStorage_CreateAccessAndRefreshTokens_refresh(t *testing.T) { tokens, err := exchangeTokens(t, clientID, code, redirectURI) require.NoError(t, err) assertTokens(t, tokens, true) - assertIDTokenClaims(t, tokens.IDTokenClaims, armPasskey, startTime, changeTime) + assertIDTokenClaims(t, tokens.IDTokenClaims, User.GetUserId(), armPasskey, startTime, changeTime) // test actual refresh grant newTokens, err := refreshTokens(t, clientID, tokens.RefreshToken) require.NoError(t, err) assertTokens(t, newTokens, true) // auth time must still be the initial - assertIDTokenClaims(t, newTokens.IDTokenClaims, armPasskey, startTime, changeTime) + assertIDTokenClaims(t, newTokens.IDTokenClaims, User.GetUserId(), armPasskey, startTime, changeTime) // refresh with an old refresh_token must fail _, err = rp.RefreshTokens[*oidc.IDTokenClaims](CTX, provider, tokens.RefreshToken, "", "") @@ -204,7 +204,7 @@ func TestOPStorage_RevokeToken_access_token(t *testing.T) { tokens, err := exchangeTokens(t, clientID, code, redirectURI) require.NoError(t, err) assertTokens(t, tokens, true) - assertIDTokenClaims(t, tokens.IDTokenClaims, armPasskey, startTime, changeTime) + assertIDTokenClaims(t, tokens.IDTokenClaims, User.GetUserId(), armPasskey, startTime, changeTime) // revoke access token err = rp.RevokeToken(CTX, provider, tokens.AccessToken, "access_token") @@ -247,7 +247,7 @@ func TestOPStorage_RevokeToken_access_token_invalid_token_hint_type(t *testing.T tokens, err := exchangeTokens(t, clientID, code, redirectURI) require.NoError(t, err) assertTokens(t, tokens, true) - assertIDTokenClaims(t, tokens.IDTokenClaims, armPasskey, startTime, changeTime) + assertIDTokenClaims(t, tokens.IDTokenClaims, User.GetUserId(), armPasskey, startTime, changeTime) // revoke access token err = rp.RevokeToken(CTX, provider, tokens.AccessToken, "refresh_token") @@ -284,7 +284,7 @@ func TestOPStorage_RevokeToken_refresh_token(t *testing.T) { tokens, err := exchangeTokens(t, clientID, code, redirectURI) require.NoError(t, err) assertTokens(t, tokens, true) - assertIDTokenClaims(t, tokens.IDTokenClaims, armPasskey, startTime, changeTime) + assertIDTokenClaims(t, tokens.IDTokenClaims, User.GetUserId(), armPasskey, startTime, changeTime) // revoke refresh token -> invalidates also access token err = rp.RevokeToken(CTX, provider, tokens.RefreshToken, "refresh_token") @@ -327,7 +327,7 @@ func TestOPStorage_RevokeToken_refresh_token_invalid_token_type_hint(t *testing. tokens, err := exchangeTokens(t, clientID, code, redirectURI) require.NoError(t, err) assertTokens(t, tokens, true) - assertIDTokenClaims(t, tokens.IDTokenClaims, armPasskey, startTime, changeTime) + assertIDTokenClaims(t, tokens.IDTokenClaims, User.GetUserId(), armPasskey, startTime, changeTime) // revoke refresh token even with a wrong hint err = rp.RevokeToken(CTX, provider, tokens.RefreshToken, "access_token") @@ -362,7 +362,7 @@ func TestOPStorage_RevokeToken_invalid_client(t *testing.T) { tokens, err := exchangeTokens(t, clientID, code, redirectURI) require.NoError(t, err) assertTokens(t, tokens, true) - assertIDTokenClaims(t, tokens.IDTokenClaims, armPasskey, startTime, changeTime) + assertIDTokenClaims(t, tokens.IDTokenClaims, User.GetUserId(), armPasskey, startTime, changeTime) // simulate second client (not part of the audience) trying to revoke the token otherClientID := createClient(t) @@ -394,7 +394,7 @@ func TestOPStorage_TerminateSession(t *testing.T) { tokens, err := exchangeTokens(t, clientID, code, redirectURI) require.NoError(t, err) assertTokens(t, tokens, false) - assertIDTokenClaims(t, tokens.IDTokenClaims, armPasskey, startTime, changeTime) + assertIDTokenClaims(t, tokens.IDTokenClaims, User.GetUserId(), armPasskey, startTime, changeTime) // userinfo must not fail _, err = rp.Userinfo[*oidc.UserInfo](CTX, tokens.AccessToken, tokens.TokenType, tokens.IDTokenClaims.Subject, provider) @@ -431,7 +431,7 @@ func TestOPStorage_TerminateSession_refresh_grant(t *testing.T) { tokens, err := exchangeTokens(t, clientID, code, redirectURI) require.NoError(t, err) assertTokens(t, tokens, true) - assertIDTokenClaims(t, tokens.IDTokenClaims, armPasskey, startTime, changeTime) + assertIDTokenClaims(t, tokens.IDTokenClaims, User.GetUserId(), armPasskey, startTime, changeTime) // userinfo must not fail _, err = rp.Userinfo[*oidc.UserInfo](CTX, tokens.AccessToken, tokens.TokenType, tokens.IDTokenClaims.Subject, provider) @@ -475,7 +475,7 @@ func TestOPStorage_TerminateSession_empty_id_token_hint(t *testing.T) { tokens, err := exchangeTokens(t, clientID, code, redirectURI) require.NoError(t, err) assertTokens(t, tokens, false) - assertIDTokenClaims(t, tokens.IDTokenClaims, armPasskey, startTime, changeTime) + assertIDTokenClaims(t, tokens.IDTokenClaims, User.GetUserId(), armPasskey, startTime, changeTime) postLogoutRedirect, err := rp.EndSession(CTX, provider, "", logoutRedirectURI, "state") require.NoError(t, err) @@ -530,8 +530,8 @@ func assertTokens(t *testing.T, tokens *oidc.Tokens[*oidc.IDTokenClaims], requir } } -func assertIDTokenClaims(t *testing.T, claims *oidc.IDTokenClaims, arm []string, sessionStart, sessionChange time.Time) { - assert.Equal(t, User.GetUserId(), claims.Subject) +func assertIDTokenClaims(t *testing.T, claims *oidc.IDTokenClaims, userID string, arm []string, sessionStart, sessionChange time.Time) { + assert.Equal(t, userID, claims.Subject) assert.Equal(t, arm, claims.AuthenticationMethodsReferences) assertOIDCTimeRange(t, claims.AuthTime, sessionStart, sessionChange) } diff --git a/internal/api/oidc/client_integration_test.go b/internal/api/oidc/client_integration_test.go index 2605c812ce..3faa930c45 100644 --- a/internal/api/oidc/client_integration_test.go +++ b/internal/api/oidc/client_integration_test.go @@ -45,7 +45,7 @@ func TestOPStorage_SetUserinfoFromToken(t *testing.T) { tokens, err := exchangeTokens(t, clientID, code, redirectURI) require.NoError(t, err) assertTokens(t, tokens, true) - assertIDTokenClaims(t, tokens.IDTokenClaims, armPasskey, startTime, changeTime) + assertIDTokenClaims(t, tokens.IDTokenClaims, User.GetUserId(), armPasskey, startTime, changeTime) // test actual userinfo provider, err := Tester.CreateRelyingParty(CTX, clientID, redirectURI) @@ -154,7 +154,7 @@ func TestServer_Introspect(t *testing.T) { tokens, err := exchangeTokens(t, app.GetClientId(), code, redirectURI) require.NoError(t, err) assertTokens(t, tokens, true) - assertIDTokenClaims(t, tokens.IDTokenClaims, armPasskey, startTime, changeTime) + assertIDTokenClaims(t, tokens.IDTokenClaims, User.GetUserId(), armPasskey, startTime, changeTime) // test actual introspection introspection, err := rs.Introspect[*oidc.IntrospectionResponse](context.Background(), resourceServer, tokens.AccessToken) @@ -360,7 +360,7 @@ func TestServer_VerifyClient(t *testing.T) { } require.NoError(t, err) assertTokens(t, tokens, false) - assertIDTokenClaims(t, tokens.IDTokenClaims, armPasskey, startTime, changeTime) + assertIDTokenClaims(t, tokens.IDTokenClaims, User.GetUserId(), armPasskey, startTime, changeTime) }) } } diff --git a/internal/api/oidc/oidc_integration_test.go b/internal/api/oidc/oidc_integration_test.go index 94b3d937b5..a5352a3b6e 100644 --- a/internal/api/oidc/oidc_integration_test.go +++ b/internal/api/oidc/oidc_integration_test.go @@ -39,7 +39,7 @@ const ( func TestMain(m *testing.M) { os.Exit(func() int { - ctx, errCtx, cancel := integration.Contexts(5 * time.Minute) + ctx, errCtx, cancel := integration.Contexts(10 * time.Minute) defer cancel() Tester = integration.NewTester(ctx) @@ -74,7 +74,7 @@ func Test_ZITADEL_API_missing_audience_scope(t *testing.T) { tokens, err := exchangeTokens(t, clientID, code, redirectURI) require.NoError(t, err) assertTokens(t, tokens, false) - assertIDTokenClaims(t, tokens.IDTokenClaims, armPasskey, startTime, changeTime) + assertIDTokenClaims(t, tokens.IDTokenClaims, User.GetUserId(), armPasskey, startTime, changeTime) ctx := metadata.AppendToOutgoingContext(context.Background(), "Authorization", fmt.Sprintf("%s %s", tokens.TokenType, tokens.AccessToken)) @@ -136,7 +136,7 @@ func Test_ZITADEL_API_missing_mfa(t *testing.T) { code := assertCodeResponse(t, linkResp.GetCallbackUrl()) tokens, err := exchangeTokens(t, clientID, code, redirectURI) require.NoError(t, err) - assertIDTokenClaims(t, tokens.IDTokenClaims, armPassword, startTime, changeTime) + assertIDTokenClaims(t, tokens.IDTokenClaims, User.GetUserId(), armPassword, startTime, changeTime) ctx := metadata.AppendToOutgoingContext(context.Background(), "Authorization", fmt.Sprintf("%s %s", tokens.TokenType, tokens.AccessToken)) @@ -165,7 +165,7 @@ func Test_ZITADEL_API_success(t *testing.T) { tokens, err := exchangeTokens(t, clientID, code, redirectURI) require.NoError(t, err) assertTokens(t, tokens, false) - assertIDTokenClaims(t, tokens.IDTokenClaims, armPasskey, startTime, changeTime) + assertIDTokenClaims(t, tokens.IDTokenClaims, User.GetUserId(), armPasskey, startTime, changeTime) ctx := metadata.AppendToOutgoingContext(context.Background(), "Authorization", fmt.Sprintf("%s %s", tokens.TokenType, tokens.AccessToken)) @@ -199,7 +199,7 @@ func Test_ZITADEL_API_glob_redirects(t *testing.T) { tokens, err := exchangeTokens(t, clientID, code, redirectURI) require.NoError(t, err) assertTokens(t, tokens, false) - assertIDTokenClaims(t, tokens.IDTokenClaims, armPasskey, startTime, changeTime) + assertIDTokenClaims(t, tokens.IDTokenClaims, User.GetUserId(), armPasskey, startTime, changeTime) ctx := metadata.AppendToOutgoingContext(context.Background(), "Authorization", fmt.Sprintf("%s %s", tokens.TokenType, tokens.AccessToken)) @@ -228,7 +228,7 @@ func Test_ZITADEL_API_inactive_access_token(t *testing.T) { tokens, err := exchangeTokens(t, clientID, code, redirectURI) require.NoError(t, err) assertTokens(t, tokens, true) - assertIDTokenClaims(t, tokens.IDTokenClaims, armPasskey, startTime, changeTime) + assertIDTokenClaims(t, tokens.IDTokenClaims, User.GetUserId(), armPasskey, startTime, changeTime) // make sure token works ctx := metadata.AppendToOutgoingContext(context.Background(), "Authorization", fmt.Sprintf("%s %s", tokens.TokenType, tokens.AccessToken)) @@ -270,7 +270,7 @@ func Test_ZITADEL_API_terminated_session(t *testing.T) { tokens, err := exchangeTokens(t, clientID, code, redirectURI) require.NoError(t, err) assertTokens(t, tokens, true) - assertIDTokenClaims(t, tokens.IDTokenClaims, armPasskey, startTime, changeTime) + assertIDTokenClaims(t, tokens.IDTokenClaims, User.GetUserId(), armPasskey, startTime, changeTime) // make sure token works ctx := metadata.AppendToOutgoingContext(context.Background(), "Authorization", fmt.Sprintf("%s %s", tokens.TokenType, tokens.AccessToken)) @@ -278,7 +278,7 @@ func Test_ZITADEL_API_terminated_session(t *testing.T) { require.NoError(t, err) require.Equal(t, User.GetUserId(), myUserResp.GetUser().GetId()) - // refresh token + // end session postLogoutRedirect, err := rp.EndSession(CTX, provider, tokens.IDToken, logoutRedirectURI, "state") require.NoError(t, err) assert.Equal(t, logoutRedirectURI+"?state=state", postLogoutRedirect.String()) @@ -290,6 +290,77 @@ func Test_ZITADEL_API_terminated_session(t *testing.T) { require.Nil(t, myUserResp) } +func Test_ZITADEL_API_terminated_session_user_disabled(t *testing.T) { + clientID := createClient(t) + tests := []struct { + name string + disable func(userID string) error + }{ + { + name: "deactivated", + disable: func(userID string) error { + _, err := Tester.Client.UserV2.DeactivateUser(CTX, &user.DeactivateUserRequest{UserId: userID}) + return err + }, + }, + { + name: "locked", + disable: func(userID string) error { + _, err := Tester.Client.UserV2.LockUser(CTX, &user.LockUserRequest{UserId: userID}) + return err + }, + }, + { + name: "deleted", + disable: func(userID string) error { + _, err := Tester.Client.UserV2.DeleteUser(CTX, &user.DeleteUserRequest{UserId: userID}) + return err + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + disabledUser := Tester.CreateHumanUser(CTX) + Tester.SetUserPassword(CTX, disabledUser.GetUserId(), integration.UserPassword) + authRequestID := createAuthRequest(t, clientID, redirectURI, oidc.ScopeOpenID, oidc.ScopeOfflineAccess, zitadelAudienceScope) + sessionID, sessionToken, startTime, changeTime := Tester.CreatePasswordSession(t, CTXLOGIN, disabledUser.GetUserId(), integration.UserPassword) + linkResp, err := Tester.Client.OIDCv2.CreateCallback(CTXLOGIN, &oidc_pb.CreateCallbackRequest{ + AuthRequestId: authRequestID, + CallbackKind: &oidc_pb.CreateCallbackRequest_Session{ + Session: &oidc_pb.Session{ + SessionId: sessionID, + SessionToken: sessionToken, + }, + }, + }) + require.NoError(t, err) + + // code exchange + code := assertCodeResponse(t, linkResp.GetCallbackUrl()) + tokens, err := exchangeTokens(t, clientID, code, redirectURI) + require.NoError(t, err) + assertTokens(t, tokens, true) + assertIDTokenClaims(t, tokens.IDTokenClaims, disabledUser.GetUserId(), armPassword, startTime, changeTime) + + // make sure token works + ctx := metadata.AppendToOutgoingContext(context.Background(), "Authorization", fmt.Sprintf("%s %s", tokens.TokenType, tokens.AccessToken)) + myUserResp, err := Tester.Client.Auth.GetMyUser(ctx, &auth.GetMyUserRequest{}) + require.NoError(t, err) + require.Equal(t, disabledUser.GetUserId(), myUserResp.GetUser().GetId()) + + // deactivate user + err = tt.disable(disabledUser.GetUserId()) + require.NoError(t, err) + + // use token from deactivated user + ctx = metadata.AppendToOutgoingContext(context.Background(), "Authorization", fmt.Sprintf("%s %s", tokens.TokenType, tokens.AccessToken)) + myUserResp, err = Tester.Client.Auth.GetMyUser(ctx, &auth.GetMyUserRequest{}) + require.Error(t, err) + require.Nil(t, myUserResp) + }) + } +} + func createClient(t testing.TB) string { return createClientWithOpts(t, clientOpts{ redirectURI: redirectURI, diff --git a/internal/auth/repository/eventsourcing/eventstore/auth_request.go b/internal/auth/repository/eventsourcing/eventstore/auth_request.go index 30b7f332a2..7dec242792 100644 --- a/internal/auth/repository/eventsourcing/eventstore/auth_request.go +++ b/internal/auth/repository/eventsourcing/eventstore/auth_request.go @@ -743,7 +743,7 @@ func (repo *AuthRequestRepo) checkLoginName(ctx context.Context, request *domain return err } // if there's an active (human) user, let's use it - if user != nil && !user.HumanView.IsZero() && domain.UserState(user.State).NotDisabled() { + if user != nil && !user.HumanView.IsZero() && domain.UserState(user.State).IsEnabled() { request.SetUserInfo(user.ID, loginNameInput, user.PreferredLoginName, "", "", user.ResourceOwner) return nil } diff --git a/internal/domain/user.go b/internal/domain/user.go index 7450d06417..24427b2d57 100644 --- a/internal/domain/user.go +++ b/internal/domain/user.go @@ -18,7 +18,7 @@ func (s UserState) Exists() bool { return s != UserStateUnspecified && s != UserStateDeleted } -func (s UserState) NotDisabled() bool { +func (s UserState) IsEnabled() bool { return s == UserStateActive || s == UserStateInitial } diff --git a/internal/query/access_token.go b/internal/query/access_token.go index 379b561bf5..7ed46d85e1 100644 --- a/internal/query/access_token.go +++ b/internal/query/access_token.go @@ -9,6 +9,7 @@ import ( "github.com/zitadel/zitadel/internal/eventstore" "github.com/zitadel/zitadel/internal/repository/oidcsession" "github.com/zitadel/zitadel/internal/repository/session" + "github.com/zitadel/zitadel/internal/repository/user" "github.com/zitadel/zitadel/internal/telemetry/tracing" "github.com/zitadel/zitadel/internal/zerrors" ) @@ -107,7 +108,7 @@ func (q *Queries) ActiveAccessTokenByToken(ctx context.Context, token string) (m if !model.AccessTokenExpiration.After(time.Now()) { return nil, zerrors.ThrowPermissionDenied(nil, "QUERY-SAF3rf", "Errors.OIDCSession.Token.Expired") } - if err = q.checkSessionNotTerminatedAfter(ctx, model.SessionID, model.AccessTokenCreation); err != nil { + if err = q.checkSessionNotTerminatedAfter(ctx, model.SessionID, model.UserID, model.AccessTokenCreation); err != nil { return nil, err } return model, nil @@ -129,26 +130,66 @@ func (q *Queries) accessTokenByOIDCSessionAndTokenID(ctx context.Context, oidcSe // checkSessionNotTerminatedAfter checks if a [session.TerminateType] event occurred after a certain time // and will return an error if so. -func (q *Queries) checkSessionNotTerminatedAfter(ctx context.Context, sessionID string, creation time.Time) (err error) { +func (q *Queries) checkSessionNotTerminatedAfter(ctx context.Context, sessionID, userID string, creation time.Time) (err error) { ctx, span := tracing.NewSpan(ctx) defer func() { span.EndWithError(err) }() - events, err := q.eventstore.Filter(ctx, eventstore.NewSearchQueryBuilder(eventstore.ColumnsEvent). - AwaitOpenTransactions(). - AllowTimeTravel(). - CreationDateAfter(creation). - AddQuery(). - AggregateTypes(session.AggregateType). - AggregateIDs(sessionID). - EventTypes( - session.TerminateType, - ). - Builder()) + model := &sessionTerminatedModel{ + sessionID: sessionID, + creation: creation, + userID: userID, + } + err = q.eventstore.FilterToQueryReducer(ctx, model) if err != nil { return zerrors.ThrowPermissionDenied(err, "QUERY-SJ642", "Errors.Internal") } - if len(events) > 0 { + + if model.terminated { return zerrors.ThrowPermissionDenied(nil, "QUERY-IJL3H", "Errors.OIDCSession.Token.Invalid") } return nil } + +type sessionTerminatedModel struct { + creation time.Time + sessionID string + userID string + + events int + terminated bool +} + +func (s *sessionTerminatedModel) Reduce() error { + s.terminated = s.events > 0 + return nil +} + +func (s *sessionTerminatedModel) AppendEvents(events ...eventstore.Event) { + s.events += len(events) +} + +func (s *sessionTerminatedModel) Query() *eventstore.SearchQueryBuilder { + query := eventstore.NewSearchQueryBuilder(eventstore.ColumnsEvent). + AwaitOpenTransactions(). + CreationDateAfter(s.creation). + AddQuery(). + AggregateTypes(session.AggregateType). + AggregateIDs(s.sessionID). + EventTypes( + session.TerminateType, + ). + Builder() + if s.userID == "" { + return query + } + return query. + AddQuery(). + AggregateTypes(user.AggregateType). + AggregateIDs(s.userID). + EventTypes( + user.UserDeactivatedType, + user.UserLockedType, + user.UserRemovedType, + ). + Builder() +}