zitadel/internal/query/access_token.go
Tim Möhlmann a84b259e8c
perf(oidc): nest position clause for session terminated query (#8738)
# Which Problems Are Solved

Optimize the query that checks for terminated sessions in the access
token verifier. The verifier is used in auth middleware, userinfo and
introspection.


# How the Problems Are Solved

The previous implementation built a query for certain events and then
appended a single `PositionAfter` clause. This caused the PostgreSQL
planner to use indexes only for the instance ID, aggregate IDs,
aggregate types and event types, followed by an expensive sequential
scan for the position. This resulted in internal over-fetching of rows
before the final filter was applied.


![Screenshot_20241007_105803](https://github.com/user-attachments/assets/f2d91976-be87-428b-b604-a211399b821c)

Furthermore, the query was searching for events which are not always
applicable. For example, there was always a session ID search and if
there was a user ID, we would also search for a browser fingerprint in
the event payload (expensive), even if those argument strings were empty.

This PR changes:

1. Nest the position query, so that a full `instance_id, aggregate_id,
aggregate_type, event_type, "position"` index can be matched.
2. Redefine the `es_wm` index to include the `position` column.
3. Only search for events for the IDs that actually have a value. Do not
search (noop) if none of session ID, user ID or fingerprint ID are set.

New query plan:


![Screenshot_20241007_110648](https://github.com/user-attachments/assets/c3234c33-1b76-4b33-a4a9-796f69f3d775)


# Additional Changes

- cleanup how we load multi-statement migrations and make that a bit
more reusable.

# Additional Context

- Related to https://github.com/zitadel/zitadel/issues/7639
2024-10-07 12:49:55 +00:00

224 lines
6.8 KiB
Go

package query
import (
"context"
"strings"
"time"
"golang.org/x/text/language"
"github.com/zitadel/zitadel/internal/domain"
"github.com/zitadel/zitadel/internal/eventstore"
"github.com/zitadel/zitadel/internal/repository/oidcsession"
"github.com/zitadel/zitadel/internal/repository/session"
"github.com/zitadel/zitadel/internal/repository/user"
"github.com/zitadel/zitadel/internal/telemetry/tracing"
"github.com/zitadel/zitadel/internal/zerrors"
)
// OIDCSessionAccessTokenReadModel is the state of a single OIDC session and its
// current access token, built by reducing the session's oidcsession events.
type OIDCSessionAccessTokenReadModel struct {
// ReadModel is the embedded eventstore read model; its AggregateID holds the
// OIDC session ID the events are queried for.
eventstore.ReadModel
// The fields UserID through Nonce are copied from the oidcsession.AddedEvent.
UserID string
SessionID string
ClientID string
Audience []string
Scope []string
AuthMethods []domain.UserAuthMethodType
AuthTime time.Time
Nonce string
// State is set to domain.OIDCSessionStateActive when the AddedEvent is reduced.
State domain.OIDCSessionState
// AccessTokenID is the ID of the last added access token.
// It is reset to an empty string when an access or refresh token is revoked.
AccessTokenID string
// AccessTokenCreation is the creation date of the last AccessTokenAddedEvent.
AccessTokenCreation time.Time
// AccessTokenExpiration is the token's creation date plus its lifetime,
// or the creation time of the revocation event once a token was revoked.
AccessTokenExpiration time.Time
// PreferredLanguage and UserAgent are copied from the oidcsession.AddedEvent.
PreferredLanguage *language.Tag
UserAgent *domain.UserAgent
// Reason and Actor are copied from the last oidcsession.AccessTokenAddedEvent.
Reason domain.TokenReason
Actor *domain.TokenActor
}
// newOIDCSessionAccessTokenReadModel creates an empty read model whose
// aggregate ID is set to the passed OIDC session ID.
func newOIDCSessionAccessTokenReadModel(id string) *OIDCSessionAccessTokenReadModel {
	readModel := new(OIDCSessionAccessTokenReadModel)
	readModel.ReadModel.AggregateID = id
	return readModel
}
// Reduce applies all pending events to the read model and then calls
// Reduce on the embedded eventstore.ReadModel, returning its result.
// Events of any type not listed below are skipped.
func (wm *OIDCSessionAccessTokenReadModel) Reduce() error {
	for _, pending := range wm.Events {
		switch typed := pending.(type) {
		case *oidcsession.AddedEvent:
			wm.reduceAdded(typed)
		case *oidcsession.AccessTokenAddedEvent:
			wm.reduceAccessTokenAdded(typed)
		case *oidcsession.AccessTokenRevokedEvent:
			// Both revocation kinds invalidate the current access token.
			wm.reduceTokenRevoked(typed)
		case *oidcsession.RefreshTokenRevokedEvent:
			wm.reduceTokenRevoked(typed)
		}
	}
	return wm.ReadModel.Reduce()
}
// Query builds the eventstore search for this read model: the added,
// access token added, access token revoked and refresh token revoked events
// of the OIDC session aggregate identified by the model's AggregateID.
func (wm *OIDCSessionAccessTokenReadModel) Query() *eventstore.SearchQueryBuilder {
	builder := eventstore.NewSearchQueryBuilder(eventstore.ColumnsEvent)

	// Restrict the search to this single OIDC session aggregate.
	sessionQuery := builder.AddQuery().
		AggregateTypes(oidcsession.AggregateType).
		AggregateIDs(wm.AggregateID)

	// Only the event types handled by Reduce are of interest.
	relevantEvents := sessionQuery.EventTypes(
		oidcsession.AddedType,
		oidcsession.AccessTokenAddedType,
		oidcsession.AccessTokenRevokedType,
		oidcsession.RefreshTokenRevokedType,
	)
	return relevantEvents.Builder()
}
// reduceAdded copies the session data of the added event into the read model
// and marks the OIDC session as active.
func (wm *OIDCSessionAccessTokenReadModel) reduceAdded(e *oidcsession.AddedEvent) {
	wm.State = domain.OIDCSessionStateActive

	// Who the session belongs to and for which client it was created.
	wm.UserID, wm.SessionID, wm.ClientID = e.UserID, e.SessionID, e.ClientID
	wm.Audience, wm.Scope = e.Audience, e.Scope

	// How and when the user authenticated.
	wm.AuthMethods, wm.AuthTime = e.AuthMethods, e.AuthTime
	wm.Nonce = e.Nonce

	// Client / user agent context.
	wm.PreferredLanguage = e.PreferredLanguage
	wm.UserAgent = e.UserAgent
}
// reduceAccessTokenAdded stores the ID, reason and actor of the newly added
// access token. The expiration is the event's creation date plus the token lifetime.
func (wm *OIDCSessionAccessTokenReadModel) reduceAccessTokenAdded(e *oidcsession.AccessTokenAddedEvent) {
	createdAt := e.CreationDate()

	wm.AccessTokenID = e.ID
	wm.AccessTokenCreation = createdAt
	wm.AccessTokenExpiration = createdAt.Add(e.Lifetime)
	wm.Reason, wm.Actor = e.Reason, e.Actor
}
// reduceTokenRevoked invalidates the current access token: the expiration is
// set to the creation time of the revocation event and the token ID is cleared,
// so no token ID can match the read model afterwards.
func (wm *OIDCSessionAccessTokenReadModel) reduceTokenRevoked(e eventstore.Event) {
	wm.AccessTokenExpiration = e.CreatedAt()
	wm.AccessTokenID = ""
}
// ActiveAccessTokenByToken verifies that the passed token refers to an active access token.
// The token must consist of exactly two "-" separated parts: the OIDC session ID and the token ID.
// The OIDC session events are retrieved from the eventstore and an unauthenticated error is returned
// if the token ID does not match the session's current access token (e.g. refreshed or revoked),
// if the token is expired, or if the underlying session was terminated.
func (q *Queries) ActiveAccessTokenByToken(ctx context.Context, token string) (model *OIDCSessionAccessTokenReadModel, err error) {
	ctx, span := tracing.NewSpan(ctx)
	defer func() { span.EndWithError(err) }()

	parts := strings.Split(token, "-")
	if len(parts) != 2 {
		return nil, zerrors.ThrowUnauthenticated(nil, "QUERY-LJK2W", "Errors.OIDCSession.Token.Invalid")
	}
	oidcSessionID, tokenID := parts[0], parts[1]

	if model, err = q.accessTokenByOIDCSessionAndTokenID(ctx, oidcSessionID, tokenID); err != nil {
		return nil, err
	}

	// A token whose expiration equals the current time already counts as expired.
	if !time.Now().Before(model.AccessTokenExpiration) {
		return nil, zerrors.ThrowUnauthenticated(nil, "QUERY-SAF3rf", "Errors.OIDCSession.Token.Expired")
	}

	// Only termination events after the read model's position are relevant.
	fingerprintID := model.UserAgent.GetFingerprintID()
	err = q.checkSessionNotTerminatedAfter(ctx, model.SessionID, model.UserID, model.Position, fingerprintID)
	if err != nil {
		return nil, err
	}
	return model, nil
}
// accessTokenByOIDCSessionAndTokenID loads the read model of the given OIDC session
// from the eventstore. It returns an unauthenticated error if the events cannot be
// loaded or if tokenID is not the ID of the session's current access token.
func (q *Queries) accessTokenByOIDCSessionAndTokenID(ctx context.Context, oidcSessionID, tokenID string) (model *OIDCSessionAccessTokenReadModel, err error) {
	ctx, span := tracing.NewSpan(ctx)
	defer func() { span.EndWithError(err) }()

	readModel := newOIDCSessionAccessTokenReadModel(oidcSessionID)
	if filterErr := q.eventstore.FilterToQueryReducer(ctx, readModel); filterErr != nil {
		return nil, zerrors.ThrowUnauthenticated(filterErr, "QUERY-ASfe2", "Errors.OIDCSession.Token.Invalid")
	}

	// A revoked token leaves an empty AccessTokenID on the read model,
	// a newer token a different one.
	if tokenID != readModel.AccessTokenID {
		return nil, zerrors.ThrowUnauthenticated(nil, "QUERY-M2u9w", "Errors.OIDCSession.Token.Invalid")
	}
	return readModel, nil
}
// checkSessionNotTerminatedAfter checks if a [session.TerminateType] event (or a user event leading
// to a session termination) occurred after the given eventstore position and returns an
// unauthenticated error if so. Nothing is queried when sessionID, userID and fingerprintID are all empty.
//
// NOTE(review): a non-empty fingerprintID on its own passes the guard below, but
// sessionTerminatedModel.Query only uses the fingerprint in combination with a userID,
// so the builder would carry no sub-queries in that case — confirm how the eventstore handles it.
func (q *Queries) checkSessionNotTerminatedAfter(ctx context.Context, sessionID, userID string, position float64, fingerprintID string) (err error) {
	ctx, span := tracing.NewSpan(ctx)
	defer func() { span.EndWithError(err) }()

	hasIdentifier := sessionID != "" || userID != "" || fingerprintID != ""
	if !hasIdentifier {
		return nil
	}

	termination := sessionTerminatedModel{
		position:      position,
		sessionID:     sessionID,
		userID:        userID,
		fingerPrintID: fingerprintID,
	}
	if filterErr := q.eventstore.FilterToQueryReducer(ctx, &termination); filterErr != nil {
		return zerrors.ThrowUnauthenticated(filterErr, "QUERY-SJ642", "Errors.Internal")
	}
	if termination.terminated {
		return zerrors.ThrowUnauthenticated(nil, "QUERY-IJL3H", "Errors.OIDCSession.Token.Invalid")
	}
	return nil
}
// sessionTerminatedModel is passed to the eventstore's FilterToQueryReducer to find out
// whether any terminating event exists after a position. It only counts the events
// matched by its Query; their content is not inspected.
type sessionTerminatedModel struct {
// position is the eventstore position after which events are searched.
position float64
// sessionID, if set, adds a search for session.TerminateType events of that session.
sessionID string
// userID, if set, adds a search for deactivated, locked and removed events of that user.
userID string
// fingerPrintID, if set together with userID, adds a search for signed out
// events of the user with a matching "userAgentID" in the event data.
fingerPrintID string
// events is the number of events appended so far.
events int
// terminated is set by Reduce to true if at least one event was appended.
terminated bool
}
// Reduce marks the model as terminated if at least one event was appended.
// It always returns nil.
func (s *sessionTerminatedModel) Reduce() error {
	foundTerminatingEvent := s.events > 0
	s.terminated = foundTerminatingEvent
	return nil
}
// AppendEvents only keeps track of how many events were passed;
// the events themselves are discarded.
func (s *sessionTerminatedModel) AppendEvents(events ...eventstore.Event) {
	received := len(events)
	s.events = s.events + received
}
// Query builds the search for events after s.position that terminate a session.
// Sub-queries are only added for the identifiers which are set:
//   - sessionID: session terminated events of that session.
//   - userID: deactivated, locked and removed events of that user.
//   - userID and fingerPrintID: signed out events of that user whose
//     "userAgentID" in the event data equals the fingerprint ID.
func (s *sessionTerminatedModel) Query() *eventstore.SearchQueryBuilder {
	builder := eventstore.NewSearchQueryBuilder(eventstore.ColumnsEvent)

	if s.sessionID != "" {
		builder = builder.AddQuery().
			AggregateTypes(session.AggregateType).
			AggregateIDs(s.sessionID).
			EventTypes(session.TerminateType).
			PositionAfter(s.position).
			Builder()
	}

	// All remaining sub-queries target the user aggregate.
	if s.userID == "" {
		return builder
	}
	builder = builder.AddQuery().
		AggregateTypes(user.AggregateType).
		AggregateIDs(s.userID).
		EventTypes(
			user.UserDeactivatedType,
			user.UserLockedType,
			user.UserRemovedType,
		).
		PositionAfter(s.position).
		Builder()

	if s.fingerPrintID == "" {
		return builder
	}
	// for specific logout on v1 sessions from the same user agent
	userAgentFilter := map[string]interface{}{"userAgentID": s.fingerPrintID}
	return builder.AddQuery().
		AggregateTypes(user.AggregateType).
		AggregateIDs(s.userID).
		EventTypes(user.HumanSignedOutType).
		EventData(userAgentFilter).
		PositionAfter(s.position).
		Builder()
}