mirror of
https://github.com/zitadel/zitadel.git
synced 2024-12-13 03:24:26 +00:00
a84b259e8c
# Which Problems Are Solved Optimize the query that checks for terminated sessions in the access token verifier. The verifier is used in auth middleware, userinfo and introspection. # How the Problems Are Solved The previous implementation built a query for certain events and then appended a single `PositionAfter` clause. This caused the postgreSQL planner to use indexes only for the instance ID, aggregate IDs, aggregate types and event types. Followed by an expensive sequential scan for the position. This resulting in internal over-fetching of rows before the final filter was applied. ![Screenshot_20241007_105803](https://github.com/user-attachments/assets/f2d91976-be87-428b-b604-a211399b821c) Furthermore, the query was searching for events which are not always applicable. For example, there was always a session ID search and if there was a user ID, we would also search for a browser fingerprint in event payload (expensive). Even if those argument string would be empty. This PR changes: 1. Nest the position query, so that a full `instance_id, aggregate_id, aggregate_type, event_type, "position"` index can be matched. 2. Redefine the `es_wm` index to include the `position` column. 3. Only search for events for the IDs that actually have a value. Do not search (noop) if none of session ID, user ID or fingerpint ID are set. New query plan: ![Screenshot_20241007_110648](https://github.com/user-attachments/assets/c3234c33-1b76-4b33-a4a9-796f69f3d775) # Additional Changes - cleanup how we load multi-statement migrations and make that a bit more reusable. # Additional Context - Related to https://github.com/zitadel/zitadel/issues/7639
224 lines
6.8 KiB
Go
224 lines
6.8 KiB
Go
package query
|
|
|
|
import (
|
|
"context"
|
|
"strings"
|
|
"time"
|
|
|
|
"golang.org/x/text/language"
|
|
|
|
"github.com/zitadel/zitadel/internal/domain"
|
|
"github.com/zitadel/zitadel/internal/eventstore"
|
|
"github.com/zitadel/zitadel/internal/repository/oidcsession"
|
|
"github.com/zitadel/zitadel/internal/repository/session"
|
|
"github.com/zitadel/zitadel/internal/repository/user"
|
|
"github.com/zitadel/zitadel/internal/telemetry/tracing"
|
|
"github.com/zitadel/zitadel/internal/zerrors"
|
|
)
|
|
|
|
type OIDCSessionAccessTokenReadModel struct {
|
|
eventstore.ReadModel
|
|
|
|
UserID string
|
|
SessionID string
|
|
ClientID string
|
|
Audience []string
|
|
Scope []string
|
|
AuthMethods []domain.UserAuthMethodType
|
|
AuthTime time.Time
|
|
Nonce string
|
|
State domain.OIDCSessionState
|
|
AccessTokenID string
|
|
AccessTokenCreation time.Time
|
|
AccessTokenExpiration time.Time
|
|
PreferredLanguage *language.Tag
|
|
UserAgent *domain.UserAgent
|
|
Reason domain.TokenReason
|
|
Actor *domain.TokenActor
|
|
}
|
|
|
|
func newOIDCSessionAccessTokenReadModel(id string) *OIDCSessionAccessTokenReadModel {
|
|
return &OIDCSessionAccessTokenReadModel{
|
|
ReadModel: eventstore.ReadModel{
|
|
AggregateID: id,
|
|
},
|
|
}
|
|
}
|
|
|
|
func (wm *OIDCSessionAccessTokenReadModel) Reduce() error {
|
|
for _, event := range wm.Events {
|
|
switch e := event.(type) {
|
|
case *oidcsession.AddedEvent:
|
|
wm.reduceAdded(e)
|
|
case *oidcsession.AccessTokenAddedEvent:
|
|
wm.reduceAccessTokenAdded(e)
|
|
case *oidcsession.AccessTokenRevokedEvent,
|
|
*oidcsession.RefreshTokenRevokedEvent:
|
|
wm.reduceTokenRevoked(event)
|
|
}
|
|
}
|
|
return wm.ReadModel.Reduce()
|
|
}
|
|
|
|
func (wm *OIDCSessionAccessTokenReadModel) Query() *eventstore.SearchQueryBuilder {
|
|
return eventstore.NewSearchQueryBuilder(eventstore.ColumnsEvent).
|
|
AddQuery().
|
|
AggregateTypes(oidcsession.AggregateType).
|
|
AggregateIDs(wm.AggregateID).
|
|
EventTypes(
|
|
oidcsession.AddedType,
|
|
oidcsession.AccessTokenAddedType,
|
|
oidcsession.AccessTokenRevokedType,
|
|
oidcsession.RefreshTokenRevokedType,
|
|
).
|
|
Builder()
|
|
}
|
|
|
|
func (wm *OIDCSessionAccessTokenReadModel) reduceAdded(e *oidcsession.AddedEvent) {
|
|
wm.UserID = e.UserID
|
|
wm.SessionID = e.SessionID
|
|
wm.ClientID = e.ClientID
|
|
wm.Audience = e.Audience
|
|
wm.Scope = e.Scope
|
|
wm.AuthMethods = e.AuthMethods
|
|
wm.AuthTime = e.AuthTime
|
|
wm.Nonce = e.Nonce
|
|
wm.PreferredLanguage = e.PreferredLanguage
|
|
wm.UserAgent = e.UserAgent
|
|
wm.State = domain.OIDCSessionStateActive
|
|
}
|
|
|
|
func (wm *OIDCSessionAccessTokenReadModel) reduceAccessTokenAdded(e *oidcsession.AccessTokenAddedEvent) {
|
|
wm.AccessTokenID = e.ID
|
|
wm.AccessTokenCreation = e.CreationDate()
|
|
wm.AccessTokenExpiration = e.CreationDate().Add(e.Lifetime)
|
|
wm.Reason = e.Reason
|
|
wm.Actor = e.Actor
|
|
}
|
|
|
|
func (wm *OIDCSessionAccessTokenReadModel) reduceTokenRevoked(e eventstore.Event) {
|
|
wm.AccessTokenID = ""
|
|
wm.AccessTokenExpiration = e.CreatedAt()
|
|
}
|
|
|
|
// ActiveAccessTokenByToken will check if the token is active by retrieving the OIDCSession events from the eventstore.
|
|
// Refreshed or expired tokens will return an error as well as if the underlying sessions has been terminated.
|
|
func (q *Queries) ActiveAccessTokenByToken(ctx context.Context, token string) (model *OIDCSessionAccessTokenReadModel, err error) {
|
|
ctx, span := tracing.NewSpan(ctx)
|
|
defer func() { span.EndWithError(err) }()
|
|
|
|
split := strings.Split(token, "-")
|
|
if len(split) != 2 {
|
|
return nil, zerrors.ThrowUnauthenticated(nil, "QUERY-LJK2W", "Errors.OIDCSession.Token.Invalid")
|
|
}
|
|
model, err = q.accessTokenByOIDCSessionAndTokenID(ctx, split[0], split[1])
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
if !model.AccessTokenExpiration.After(time.Now()) {
|
|
return nil, zerrors.ThrowUnauthenticated(nil, "QUERY-SAF3rf", "Errors.OIDCSession.Token.Expired")
|
|
}
|
|
if err = q.checkSessionNotTerminatedAfter(ctx, model.SessionID, model.UserID, model.Position, model.UserAgent.GetFingerprintID()); err != nil {
|
|
return nil, err
|
|
}
|
|
return model, nil
|
|
}
|
|
|
|
func (q *Queries) accessTokenByOIDCSessionAndTokenID(ctx context.Context, oidcSessionID, tokenID string) (model *OIDCSessionAccessTokenReadModel, err error) {
|
|
ctx, span := tracing.NewSpan(ctx)
|
|
defer func() { span.EndWithError(err) }()
|
|
|
|
model = newOIDCSessionAccessTokenReadModel(oidcSessionID)
|
|
if err = q.eventstore.FilterToQueryReducer(ctx, model); err != nil {
|
|
return nil, zerrors.ThrowUnauthenticated(err, "QUERY-ASfe2", "Errors.OIDCSession.Token.Invalid")
|
|
}
|
|
if model.AccessTokenID != tokenID {
|
|
return nil, zerrors.ThrowUnauthenticated(nil, "QUERY-M2u9w", "Errors.OIDCSession.Token.Invalid")
|
|
}
|
|
return model, nil
|
|
}
|
|
|
|
// checkSessionNotTerminatedAfter checks if a [session.TerminateType] event (or user events leading to a session termination)
|
|
// occurred after a certain time and will return an error if so.
|
|
func (q *Queries) checkSessionNotTerminatedAfter(ctx context.Context, sessionID, userID string, position float64, fingerprintID string) (err error) {
|
|
ctx, span := tracing.NewSpan(ctx)
|
|
defer func() { span.EndWithError(err) }()
|
|
|
|
if sessionID == "" && userID == "" && fingerprintID == "" {
|
|
return nil
|
|
}
|
|
model := &sessionTerminatedModel{
|
|
sessionID: sessionID,
|
|
position: position,
|
|
userID: userID,
|
|
fingerPrintID: fingerprintID,
|
|
}
|
|
err = q.eventstore.FilterToQueryReducer(ctx, model)
|
|
if err != nil {
|
|
return zerrors.ThrowUnauthenticated(err, "QUERY-SJ642", "Errors.Internal")
|
|
}
|
|
|
|
if model.terminated {
|
|
return zerrors.ThrowUnauthenticated(nil, "QUERY-IJL3H", "Errors.OIDCSession.Token.Invalid")
|
|
}
|
|
return nil
|
|
}
|
|
|
|
type sessionTerminatedModel struct {
|
|
position float64
|
|
sessionID string
|
|
userID string
|
|
fingerPrintID string
|
|
|
|
events int
|
|
terminated bool
|
|
}
|
|
|
|
func (s *sessionTerminatedModel) Reduce() error {
|
|
s.terminated = s.events > 0
|
|
return nil
|
|
}
|
|
|
|
func (s *sessionTerminatedModel) AppendEvents(events ...eventstore.Event) {
|
|
s.events += len(events)
|
|
}
|
|
|
|
func (s *sessionTerminatedModel) Query() *eventstore.SearchQueryBuilder {
|
|
builder := eventstore.NewSearchQueryBuilder(eventstore.ColumnsEvent)
|
|
if s.sessionID != "" {
|
|
builder = builder.AddQuery().
|
|
AggregateTypes(session.AggregateType).
|
|
AggregateIDs(s.sessionID).
|
|
EventTypes(
|
|
session.TerminateType,
|
|
).
|
|
PositionAfter(s.position).
|
|
Builder()
|
|
}
|
|
if s.userID != "" {
|
|
builder = builder.AddQuery().
|
|
AggregateTypes(user.AggregateType).
|
|
AggregateIDs(s.userID).
|
|
EventTypes(
|
|
user.UserDeactivatedType,
|
|
user.UserLockedType,
|
|
user.UserRemovedType,
|
|
).
|
|
PositionAfter(s.position).
|
|
Builder()
|
|
if s.fingerPrintID != "" {
|
|
// for specific logout on v1 sessions from the same user agent
|
|
builder = builder.AddQuery().
|
|
AggregateTypes(user.AggregateType).
|
|
AggregateIDs(s.userID).
|
|
EventTypes(
|
|
user.HumanSignedOutType,
|
|
).
|
|
EventData(map[string]interface{}{"userAgentID": s.fingerPrintID}).
|
|
PositionAfter(s.position).
|
|
Builder()
|
|
}
|
|
}
|
|
return builder
|
|
}
|