fix(api): handle user disabling events correctly in session API (#7380)

This PR makes sure that user disabling events (deactivate, locked, ...) are correctly checked for sessions.
This commit is contained in:
Livio Spring
2024-02-28 10:30:05 +01:00
committed by GitHub
parent 26d1563643
commit 68af4f59c9
8 changed files with 193 additions and 46 deletions

View File

@@ -9,6 +9,7 @@ import (
"github.com/zitadel/zitadel/internal/eventstore"
"github.com/zitadel/zitadel/internal/repository/oidcsession"
"github.com/zitadel/zitadel/internal/repository/session"
"github.com/zitadel/zitadel/internal/repository/user"
"github.com/zitadel/zitadel/internal/telemetry/tracing"
"github.com/zitadel/zitadel/internal/zerrors"
)
@@ -107,7 +108,7 @@ func (q *Queries) ActiveAccessTokenByToken(ctx context.Context, token string) (m
if !model.AccessTokenExpiration.After(time.Now()) {
return nil, zerrors.ThrowPermissionDenied(nil, "QUERY-SAF3rf", "Errors.OIDCSession.Token.Expired")
}
if err = q.checkSessionNotTerminatedAfter(ctx, model.SessionID, model.AccessTokenCreation); err != nil {
if err = q.checkSessionNotTerminatedAfter(ctx, model.SessionID, model.UserID, model.AccessTokenCreation); err != nil {
return nil, err
}
return model, nil
@@ -129,26 +130,66 @@ func (q *Queries) accessTokenByOIDCSessionAndTokenID(ctx context.Context, oidcSe
// checkSessionNotTerminatedAfter checks if a [session.TerminateType] event occurred after a certain time
// and will return an error if so.
func (q *Queries) checkSessionNotTerminatedAfter(ctx context.Context, sessionID string, creation time.Time) (err error) {
func (q *Queries) checkSessionNotTerminatedAfter(ctx context.Context, sessionID, userID string, creation time.Time) (err error) {
ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }()
events, err := q.eventstore.Filter(ctx, eventstore.NewSearchQueryBuilder(eventstore.ColumnsEvent).
AwaitOpenTransactions().
AllowTimeTravel().
CreationDateAfter(creation).
AddQuery().
AggregateTypes(session.AggregateType).
AggregateIDs(sessionID).
EventTypes(
session.TerminateType,
).
Builder())
model := &sessionTerminatedModel{
sessionID: sessionID,
creation: creation,
userID: userID,
}
err = q.eventstore.FilterToQueryReducer(ctx, model)
if err != nil {
return zerrors.ThrowPermissionDenied(err, "QUERY-SJ642", "Errors.Internal")
}
if len(events) > 0 {
if model.terminated {
return zerrors.ThrowPermissionDenied(nil, "QUERY-IJL3H", "Errors.OIDCSession.Token.Invalid")
}
return nil
}
type sessionTerminatedModel struct {
creation time.Time
sessionID string
userID string
events int
terminated bool
}
func (s *sessionTerminatedModel) Reduce() error {
s.terminated = s.events > 0
return nil
}
func (s *sessionTerminatedModel) AppendEvents(events ...eventstore.Event) {
s.events += len(events)
}
func (s *sessionTerminatedModel) Query() *eventstore.SearchQueryBuilder {
query := eventstore.NewSearchQueryBuilder(eventstore.ColumnsEvent).
AwaitOpenTransactions().
CreationDateAfter(s.creation).
AddQuery().
AggregateTypes(session.AggregateType).
AggregateIDs(s.sessionID).
EventTypes(
session.TerminateType,
).
Builder()
if s.userID == "" {
return query
}
return query.
AddQuery().
AggregateTypes(user.AggregateType).
AggregateIDs(s.userID).
EventTypes(
user.UserDeactivatedType,
user.UserLockedType,
user.UserRemovedType,
).
Builder()
}