perf(oidc): nest position clause for session terminated query (#8738)

# Which Problems Are Solved

Optimize the query that checks for terminated sessions in the access
token verifier. The verifier is used in auth middleware, userinfo and
introspection.


# How the Problems Are Solved

The previous implementation built a query for certain events and then
appended a single `PositionAfter` clause. This caused the postgreSQL
planner to use indexes only for the instance ID, aggregate IDs,
aggregate types and event types. Followed by an expensive sequential
scan for the position. This resulting in internal over-fetching of rows
before the final filter was applied.


![Screenshot_20241007_105803](https://github.com/user-attachments/assets/f2d91976-be87-428b-b604-a211399b821c)

Furthermore, the query was searching for events which are not always
applicable. For example, there was always a session ID search and if
there was a user ID, we would also search for a browser fingerprint in
event payload (expensive). Even if those argument string would be empty.

This PR changes:

1. Nest the position query, so that a full `instance_id, aggregate_id,
aggregate_type, event_type, "position"` index can be matched.
2. Redefine the `es_wm` index to include the `position` column.
3. Only search for events for the IDs that actually have a value. Do not
search (noop) if none of session ID, user ID or fingerpint ID are set.

New query plan:


![Screenshot_20241007_110648](https://github.com/user-attachments/assets/c3234c33-1b76-4b33-a4a9-796f69f3d775)


# Additional Changes

- cleanup how we load multi-statement migrations and make that a bit
more reusable.

# Additional Context

- Related to https://github.com/zitadel/zitadel/issues/7639
This commit is contained in:
Tim Möhlmann 2024-10-07 15:49:55 +03:00 committed by GitHub
parent 2bd3f44094
commit a84b259e8c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
11 changed files with 148 additions and 54 deletions

View File

@ -25,15 +25,11 @@ type NewEventsTable struct {
} }
func (mig *NewEventsTable) Execute(ctx context.Context, _ eventstore.Event) error { func (mig *NewEventsTable) Execute(ctx context.Context, _ eventstore.Event) error {
migrations, err := newEventsTable.ReadDir("14/" + mig.dbClient.Type())
if err != nil {
return err
}
// if events already exists events2 is created during a setup job // if events already exists events2 is created during a setup job
var count int var count int
err = mig.dbClient.QueryRow( err := mig.dbClient.QueryRowContext(ctx,
func(row *sql.Row) error { func(row *sql.Row) error {
if err = row.Scan(&count); err != nil { if err := row.Scan(&count); err != nil {
return err return err
} }
return row.Err() return row.Err()
@ -43,16 +39,15 @@ func (mig *NewEventsTable) Execute(ctx context.Context, _ eventstore.Event) erro
if err != nil || count == 1 { if err != nil || count == 1 {
return err return err
} }
for _, migration := range migrations {
stmt, err := readStmt(newEventsTable, "14", mig.dbClient.Type(), migration.Name())
if err != nil {
return err
}
stmt = strings.ReplaceAll(stmt, "{{.username}}", mig.dbClient.Username())
logging.WithFields("migration", mig.String(), "file", migration.Name()).Debug("execute statement") statements, err := readStatements(newEventsTable, "14", mig.dbClient.Type())
if err != nil {
_, err = mig.dbClient.ExecContext(ctx, stmt) return err
}
for _, stmt := range statements {
stmt.query = strings.ReplaceAll(stmt.query, "{{.username}}", mig.dbClient.Username())
logging.WithFields("file", stmt.file, "migration", mig.String()).Info("execute statement")
_, err = mig.dbClient.ExecContext(ctx, stmt.query)
if err != nil { if err != nil {
return err return err
} }

View File

@ -21,19 +21,13 @@ type CurrentProjectionState struct {
} }
func (mig *CurrentProjectionState) Execute(ctx context.Context, _ eventstore.Event) error { func (mig *CurrentProjectionState) Execute(ctx context.Context, _ eventstore.Event) error {
migrations, err := currentProjectionState.ReadDir("15/" + mig.dbClient.Type()) statements, err := readStatements(currentProjectionState, "15", mig.dbClient.Type())
if err != nil { if err != nil {
return err return err
} }
for _, migration := range migrations { for _, stmt := range statements {
stmt, err := readStmt(currentProjectionState, "15", mig.dbClient.Type(), migration.Name()) logging.WithFields("file", stmt.file, "migration", mig.String()).Info("execute statement")
if err != nil { _, err = mig.dbClient.ExecContext(ctx, stmt.query)
return err
}
logging.WithFields("file", migration.Name(), "migration", mig.String()).Info("execute statement")
_, err = mig.dbClient.ExecContext(ctx, stmt)
if err != nil { if err != nil {
return err return err
} }

39
cmd/setup/35.go Normal file
View File

@ -0,0 +1,39 @@
package setup
import (
"context"
"embed"
"fmt"
"github.com/zitadel/logging"
"github.com/zitadel/zitadel/internal/database"
"github.com/zitadel/zitadel/internal/eventstore"
)
var (
//go:embed 35/*.sql
addPositionToEsWmIndex embed.FS
)
type AddPositionToIndexEsWm struct {
dbClient *database.DB
}
func (mig *AddPositionToIndexEsWm) Execute(ctx context.Context, _ eventstore.Event) error {
statements, err := readStatements(addPositionToEsWmIndex, "35", "")
if err != nil {
return err
}
for _, stmt := range statements {
logging.WithFields("file", stmt.file, "migration", mig.String()).Info("execute statement")
if _, err := mig.dbClient.ExecContext(ctx, stmt.query); err != nil {
return fmt.Errorf("%s %s: %w", mig.String(), stmt.file, err)
}
}
return nil
}
func (mig *AddPositionToIndexEsWm) String() string {
return "35_add_position_to_index_es_wm"
}

View File

@ -0,0 +1,2 @@
CREATE INDEX CONCURRENTLY IF NOT EXISTS es_wm_temp
ON eventstore.events2 (instance_id, aggregate_id, aggregate_type, event_type, "position");

View File

@ -0,0 +1 @@
DROP INDEX IF EXISTS eventstore.es_wm;

View File

@ -0,0 +1 @@
ALTER INDEX eventstore.es_wm_temp RENAME TO es_wm;

View File

@ -121,6 +121,7 @@ type Steps struct {
s32AddAuthSessionID *AddAuthSessionID s32AddAuthSessionID *AddAuthSessionID
s33SMSConfigs3TwilioAddVerifyServiceSid *SMSConfigs3TwilioAddVerifyServiceSid s33SMSConfigs3TwilioAddVerifyServiceSid *SMSConfigs3TwilioAddVerifyServiceSid
s34AddCacheSchema *AddCacheSchema s34AddCacheSchema *AddCacheSchema
s35AddPositionToIndexEsWm *AddPositionToIndexEsWm
} }
func MustNewSteps(v *viper.Viper) *Steps { func MustNewSteps(v *viper.Viper) *Steps {

View File

@ -5,6 +5,7 @@ import (
"embed" "embed"
_ "embed" _ "embed"
"net/http" "net/http"
"path/filepath"
"github.com/spf13/cobra" "github.com/spf13/cobra"
"github.com/spf13/viper" "github.com/spf13/viper"
@ -163,6 +164,7 @@ func Setup(ctx context.Context, config *Config, steps *Steps, masterKey string)
steps.s32AddAuthSessionID = &AddAuthSessionID{dbClient: esPusherDBClient} steps.s32AddAuthSessionID = &AddAuthSessionID{dbClient: esPusherDBClient}
steps.s33SMSConfigs3TwilioAddVerifyServiceSid = &SMSConfigs3TwilioAddVerifyServiceSid{dbClient: esPusherDBClient} steps.s33SMSConfigs3TwilioAddVerifyServiceSid = &SMSConfigs3TwilioAddVerifyServiceSid{dbClient: esPusherDBClient}
steps.s34AddCacheSchema = &AddCacheSchema{dbClient: queryDBClient} steps.s34AddCacheSchema = &AddCacheSchema{dbClient: queryDBClient}
steps.s35AddPositionToIndexEsWm = &AddPositionToIndexEsWm{dbClient: esPusherDBClient}
err = projection.Create(ctx, projectionDBClient, eventstoreClient, config.Projections, nil, nil, nil) err = projection.Create(ctx, projectionDBClient, eventstoreClient, config.Projections, nil, nil, nil)
logging.OnError(err).Fatal("unable to start projections") logging.OnError(err).Fatal("unable to start projections")
@ -206,6 +208,7 @@ func Setup(ctx context.Context, config *Config, steps *Steps, masterKey string)
steps.s29FillFieldsForProjectGrant, steps.s29FillFieldsForProjectGrant,
steps.s30FillFieldsForOrgDomainVerified, steps.s30FillFieldsForOrgDomainVerified,
steps.s34AddCacheSchema, steps.s34AddCacheSchema,
steps.s35AddPositionToIndexEsWm,
} { } {
mustExecuteMigration(ctx, eventstoreClient, step, "migration failed") mustExecuteMigration(ctx, eventstoreClient, step, "migration failed")
} }
@ -245,11 +248,41 @@ func mustExecuteMigration(ctx context.Context, eventstoreClient *eventstore.Even
logging.WithFields("name", step.String()).OnError(err).Fatal(errorMsg) logging.WithFields("name", step.String()).OnError(err).Fatal(errorMsg)
} }
// readStmt reads a single file from the embedded FS,
// under the folder/typ/filename path.
// Typ describes the database dialect and may be omitted if no
// dialect specific migration is specified.
func readStmt(fs embed.FS, folder, typ, filename string) (string, error) { func readStmt(fs embed.FS, folder, typ, filename string) (string, error) {
stmt, err := fs.ReadFile(folder + "/" + typ + "/" + filename) stmt, err := fs.ReadFile(filepath.Join(folder, typ, filename))
return string(stmt), err return string(stmt), err
} }
type statement struct {
file string
query string
}
// readStatements reads all files from the embedded FS,
// under the folder/type path.
// Typ describes the database dialect and may be omitted if no
// dialect specific migration is specified.
func readStatements(fs embed.FS, folder, typ string) ([]statement, error) {
basePath := filepath.Join(folder, typ)
dir, err := fs.ReadDir(basePath)
if err != nil {
return nil, err
}
statements := make([]statement, len(dir))
for i, file := range dir {
statements[i].file = file.Name()
statements[i].query, err = readStmt(fs, folder, typ, file.Name())
if err != nil {
return nil, err
}
}
return statements, nil
}
func initProjections( func initProjections(
ctx context.Context, ctx context.Context,
eventstoreClient *eventstore.Eventstore, eventstoreClient *eventstore.Eventstore,

View File

@ -156,6 +156,7 @@ func QueryFromBuilder(builder *eventstore.SearchQueryBuilder) (*SearchQuery, err
aggregateIDFilter, aggregateIDFilter,
eventTypeFilter, eventTypeFilter,
eventDataFilter, eventDataFilter,
eventPositionAfterFilter,
} { } {
filter := f(q) filter := f(q)
if filter == nil { if filter == nil {
@ -275,3 +276,10 @@ func eventDataFilter(query *eventstore.SearchQuery) *Filter {
} }
return NewFilter(FieldEventData, query.GetEventData(), OperationJSONContains) return NewFilter(FieldEventData, query.GetEventData(), OperationJSONContains)
} }
func eventPositionAfterFilter(query *eventstore.SearchQuery) *Filter {
if pos := query.GetPositionAfter(); pos != 0 {
return NewFilter(FieldPosition, pos, OperationGreater)
}
return nil
}

View File

@ -107,6 +107,7 @@ type SearchQuery struct {
aggregateIDs []string aggregateIDs []string
eventTypes []EventType eventTypes []EventType
eventData map[string]interface{} eventData map[string]interface{}
positionAfter float64
} }
func (q SearchQuery) GetAggregateTypes() []AggregateType { func (q SearchQuery) GetAggregateTypes() []AggregateType {
@ -125,6 +126,10 @@ func (q SearchQuery) GetEventData() map[string]interface{} {
return q.eventData return q.eventData
} }
func (q SearchQuery) GetPositionAfter() float64 {
return q.positionAfter
}
// Columns defines which fields of the event are needed for the query // Columns defines which fields of the event are needed for the query
type Columns int8 type Columns int8
@ -344,6 +349,11 @@ func (query *SearchQuery) EventData(data map[string]interface{}) *SearchQuery {
return query return query
} }
func (query *SearchQuery) PositionAfter(position float64) *SearchQuery {
query.positionAfter = position
return query
}
// Builder returns the SearchQueryBuilder of the sub query // Builder returns the SearchQueryBuilder of the sub query
func (query *SearchQuery) Builder() *SearchQueryBuilder { func (query *SearchQuery) Builder() *SearchQueryBuilder {
return query.builder return query.builder

View File

@ -144,6 +144,9 @@ func (q *Queries) checkSessionNotTerminatedAfter(ctx context.Context, sessionID,
ctx, span := tracing.NewSpan(ctx) ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }() defer func() { span.EndWithError(err) }()
if sessionID == "" && userID == "" && fingerprintID == "" {
return nil
}
model := &sessionTerminatedModel{ model := &sessionTerminatedModel{
sessionID: sessionID, sessionID: sessionID,
position: position, position: position,
@ -181,33 +184,40 @@ func (s *sessionTerminatedModel) AppendEvents(events ...eventstore.Event) {
} }
func (s *sessionTerminatedModel) Query() *eventstore.SearchQueryBuilder { func (s *sessionTerminatedModel) Query() *eventstore.SearchQueryBuilder {
query := eventstore.NewSearchQueryBuilder(eventstore.ColumnsEvent). builder := eventstore.NewSearchQueryBuilder(eventstore.ColumnsEvent)
PositionAfter(s.position). if s.sessionID != "" {
AddQuery(). builder = builder.AddQuery().
AggregateTypes(session.AggregateType). AggregateTypes(session.AggregateType).
AggregateIDs(s.sessionID). AggregateIDs(s.sessionID).
EventTypes( EventTypes(
session.TerminateType, session.TerminateType,
). ).
Builder() PositionAfter(s.position).
if s.userID == "" { Builder()
return query
} }
return query. if s.userID != "" {
AddQuery(). builder = builder.AddQuery().
AggregateTypes(user.AggregateType). AggregateTypes(user.AggregateType).
AggregateIDs(s.userID). AggregateIDs(s.userID).
EventTypes( EventTypes(
user.UserDeactivatedType, user.UserDeactivatedType,
user.UserLockedType, user.UserLockedType,
user.UserRemovedType, user.UserRemovedType,
). ).
Or(). // for specific logout on v1 sessions from the same user agent PositionAfter(s.position).
AggregateTypes(user.AggregateType). Builder()
AggregateIDs(s.userID). if s.fingerPrintID != "" {
EventTypes( // for specific logout on v1 sessions from the same user agent
user.HumanSignedOutType, builder = builder.AddQuery().
). AggregateTypes(user.AggregateType).
EventData(map[string]interface{}{"userAgentID": s.fingerPrintID}). AggregateIDs(s.userID).
Builder() EventTypes(
user.HumanSignedOutType,
).
EventData(map[string]interface{}{"userAgentID": s.fingerPrintID}).
PositionAfter(s.position).
Builder()
}
}
return builder
} }