perf(oidc): nest position clause for session terminated query (#8738)

# Which Problems Are Solved

Optimize the query that checks for terminated sessions in the access
token verifier. The verifier is used by the auth middleware, userinfo,
and introspection.


# How the Problems Are Solved

The previous implementation built a query for certain events and then
appended a single `PositionAfter` clause. This caused the PostgreSQL
planner to use indexes only for the instance ID, aggregate IDs,
aggregate types, and event types, followed by an expensive sequential
scan for the position. This resulted in internal over-fetching of rows
before the final filter was applied.
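
Schematically, the old query had roughly this shape (a simplified
sketch, not the verbatim generated SQL; placeholders and event-type
lists are abbreviated):

```sql
-- Sketch of the previous query shape: the position predicate sits outside
-- the OR'ed event filters, so an index can serve the inner filters, but
-- "position" must be checked for every candidate row afterwards.
SELECT *
FROM eventstore.events2
WHERE instance_id = $1
  AND (
        (aggregate_type = 'session' AND aggregate_id = $2 AND event_type = $3)
     OR (aggregate_type = 'user'    AND aggregate_id = $4 AND event_type = ANY($5))
  )
  AND "position" > $6; -- the single appended PositionAfter clause
```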


![Screenshot_20241007_105803](https://github.com/user-attachments/assets/f2d91976-be87-428b-b604-a211399b821c)

Furthermore, the query searched for events that are not always
applicable. For example, there was always a session ID search, and if
there was a user ID, we would also search for a browser fingerprint in
the event payload (expensive), even when those argument strings were
empty.

This PR changes:

1. Nest the position query, so that a full `instance_id, aggregate_id,
aggregate_type, event_type, "position"` index can be matched (see the
sketch after this list).
2. Redefine the `es_wm` index to include the `position` column.
3. Only search for events for the IDs that actually have a value. Do not
search (noop) if none of session ID, user ID, or fingerprint ID are set.
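
For illustration, after change 1 the position predicate is repeated
inside each event filter, so every OR branch on its own matches the
redefined five-column `es_wm` index (again a simplified sketch, not the
verbatim generated SQL):

```sql
-- Sketch of the nested query shape: each branch now ends with "position",
-- matching the index on
-- (instance_id, aggregate_id, aggregate_type, event_type, "position").
SELECT *
FROM eventstore.events2
WHERE instance_id = $1
  AND (
        (aggregate_type = 'session' AND aggregate_id = $2 AND event_type = $3 AND "position" > $6)
     OR (aggregate_type = 'user' AND aggregate_id = $4 AND event_type = ANY($5) AND "position" > $6)
  );
```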

New query plan:


![Screenshot_20241007_110648](https://github.com/user-attachments/assets/c3234c33-1b76-4b33-a4a9-796f69f3d775)


# Additional Changes

- Clean up how we load multi-statement migrations and make that a bit
more reusable (a sketch of the new pattern follows below).
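
A minimal sketch of the new pattern, assuming an embedded directory of
SQL files (the `migrationFS` variable, folder name, and plain
`database/sql` handle here are illustrative; the real helper lives in
the setup diff below):

```go
package migrate

import (
	"context"
	"database/sql"
	"embed"
	"log"
	"path/filepath"
)

//go:embed 35/*.sql
var migrationFS embed.FS // illustrative embed; each step embeds its own folder

// statement pairs a query with the file it was read from, for logging.
type statement struct {
	file  string
	query string
}

// readStatements loads every file under folder/typ from the embedded FS,
// in file-name order, so multi-statement migrations stay ordered.
func readStatements(fsys embed.FS, folder, typ string) ([]statement, error) {
	dir, err := fsys.ReadDir(filepath.Join(folder, typ))
	if err != nil {
		return nil, err
	}
	statements := make([]statement, len(dir))
	for i, file := range dir {
		raw, err := fsys.ReadFile(filepath.Join(folder, typ, file.Name()))
		if err != nil {
			return nil, err
		}
		statements[i] = statement{file: file.Name(), query: string(raw)}
	}
	return statements, nil
}

// execute runs each loaded statement in order, logging its source file.
func execute(ctx context.Context, db *sql.DB) error {
	statements, err := readStatements(migrationFS, "35", "")
	if err != nil {
		return err
	}
	for _, stmt := range statements {
		log.Printf("executing %s", stmt.file)
		if _, err := db.ExecContext(ctx, stmt.query); err != nil {
			return err
		}
	}
	return nil
}
```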

# Additional Context

- Related to https://github.com/zitadel/zitadel/issues/7639
11 changed files with 148 additions and 54 deletions:

**cmd/setup/14.go**

```diff
@@ -25,15 +25,11 @@ type NewEventsTable struct {
 }
 
 func (mig *NewEventsTable) Execute(ctx context.Context, _ eventstore.Event) error {
-	migrations, err := newEventsTable.ReadDir("14/" + mig.dbClient.Type())
-	if err != nil {
-		return err
-	}
 	// if events already exists events2 is created during a setup job
 	var count int
-	err = mig.dbClient.QueryRow(
+	err := mig.dbClient.QueryRowContext(ctx,
 		func(row *sql.Row) error {
-			if err = row.Scan(&count); err != nil {
+			if err := row.Scan(&count); err != nil {
 				return err
 			}
 			return row.Err()
@@ -43,16 +39,15 @@ func (mig *NewEventsTable) Execute(ctx context.Context, _ eventstore.Event) erro
 	if err != nil || count == 1 {
 		return err
 	}
-	for _, migration := range migrations {
-		stmt, err := readStmt(newEventsTable, "14", mig.dbClient.Type(), migration.Name())
-		if err != nil {
-			return err
-		}
-		stmt = strings.ReplaceAll(stmt, "{{.username}}", mig.dbClient.Username())
-		logging.WithFields("migration", mig.String(), "file", migration.Name()).Debug("execute statement")
-		_, err = mig.dbClient.ExecContext(ctx, stmt)
+	statements, err := readStatements(newEventsTable, "14", mig.dbClient.Type())
+	if err != nil {
+		return err
+	}
+	for _, stmt := range statements {
+		stmt.query = strings.ReplaceAll(stmt.query, "{{.username}}", mig.dbClient.Username())
+		logging.WithFields("file", stmt.file, "migration", mig.String()).Info("execute statement")
+		_, err = mig.dbClient.ExecContext(ctx, stmt.query)
 		if err != nil {
 			return err
 		}
```

**cmd/setup/15.go**

```diff
@@ -21,19 +21,13 @@ type CurrentProjectionState struct {
 }
 
 func (mig *CurrentProjectionState) Execute(ctx context.Context, _ eventstore.Event) error {
-	migrations, err := currentProjectionState.ReadDir("15/" + mig.dbClient.Type())
+	statements, err := readStatements(currentProjectionState, "15", mig.dbClient.Type())
 	if err != nil {
 		return err
 	}
-	for _, migration := range migrations {
-		stmt, err := readStmt(currentProjectionState, "15", mig.dbClient.Type(), migration.Name())
-		if err != nil {
-			return err
-		}
-		logging.WithFields("file", migration.Name(), "migration", mig.String()).Info("execute statement")
-		_, err = mig.dbClient.ExecContext(ctx, stmt)
+	for _, stmt := range statements {
+		logging.WithFields("file", stmt.file, "migration", mig.String()).Info("execute statement")
+		_, err = mig.dbClient.ExecContext(ctx, stmt.query)
 		if err != nil {
 			return err
 		}
```

**cmd/setup/35.go** (new file, 39 lines)

```diff
@@ -0,0 +1,39 @@
+package setup
+
+import (
+	"context"
+	"embed"
+	"fmt"
+
+	"github.com/zitadel/logging"
+
+	"github.com/zitadel/zitadel/internal/database"
+	"github.com/zitadel/zitadel/internal/eventstore"
+)
+
+var (
+	//go:embed 35/*.sql
+	addPositionToEsWmIndex embed.FS
+)
+
+type AddPositionToIndexEsWm struct {
+	dbClient *database.DB
+}
+
+func (mig *AddPositionToIndexEsWm) Execute(ctx context.Context, _ eventstore.Event) error {
+	statements, err := readStatements(addPositionToEsWmIndex, "35", "")
+	if err != nil {
+		return err
+	}
+	for _, stmt := range statements {
+		logging.WithFields("file", stmt.file, "migration", mig.String()).Info("execute statement")
+		if _, err := mig.dbClient.ExecContext(ctx, stmt.query); err != nil {
+			return fmt.Errorf("%s %s: %w", mig.String(), stmt.file, err)
+		}
+	}
+	return nil
+}
+
+func (mig *AddPositionToIndexEsWm) String() string {
+	return "35_add_position_to_index_es_wm"
+}
```

**cmd/setup/35/** (new SQL file)

```diff
@@ -0,0 +1,2 @@
+CREATE INDEX CONCURRENTLY IF NOT EXISTS es_wm_temp
+    ON eventstore.events2 (instance_id, aggregate_id, aggregate_type, event_type, "position");
```

**cmd/setup/35/** (new SQL file)

```diff
@@ -0,0 +1 @@
+DROP INDEX IF EXISTS eventstore.es_wm;
```

**cmd/setup/35/** (new SQL file)

```diff
@@ -0,0 +1 @@
+ALTER INDEX eventstore.es_wm_temp RENAME TO es_wm;
```

---

```diff
@@ -121,6 +121,7 @@ type Steps struct {
 	s32AddAuthSessionID                     *AddAuthSessionID
 	s33SMSConfigs3TwilioAddVerifyServiceSid *SMSConfigs3TwilioAddVerifyServiceSid
 	s34AddCacheSchema                       *AddCacheSchema
+	s35AddPositionToIndexEsWm               *AddPositionToIndexEsWm
 }
 
 func MustNewSteps(v *viper.Viper) *Steps {
```

---

```diff
@@ -5,6 +5,7 @@ import (
 	"embed"
 	_ "embed"
 	"net/http"
+	"path/filepath"
 
 	"github.com/spf13/cobra"
 	"github.com/spf13/viper"
@@ -163,6 +164,7 @@ func Setup(ctx context.Context, config *Config, steps *Steps, masterKey string)
 	steps.s32AddAuthSessionID = &AddAuthSessionID{dbClient: esPusherDBClient}
 	steps.s33SMSConfigs3TwilioAddVerifyServiceSid = &SMSConfigs3TwilioAddVerifyServiceSid{dbClient: esPusherDBClient}
 	steps.s34AddCacheSchema = &AddCacheSchema{dbClient: queryDBClient}
+	steps.s35AddPositionToIndexEsWm = &AddPositionToIndexEsWm{dbClient: esPusherDBClient}
 
 	err = projection.Create(ctx, projectionDBClient, eventstoreClient, config.Projections, nil, nil, nil)
 	logging.OnError(err).Fatal("unable to start projections")
@@ -206,6 +208,7 @@ func Setup(ctx context.Context, config *Config, steps *Steps, masterKey string)
 		steps.s29FillFieldsForProjectGrant,
 		steps.s30FillFieldsForOrgDomainVerified,
 		steps.s34AddCacheSchema,
+		steps.s35AddPositionToIndexEsWm,
 	} {
 		mustExecuteMigration(ctx, eventstoreClient, step, "migration failed")
 	}
@@ -245,11 +248,41 @@ func mustExecuteMigration(ctx context.Context, eventstoreClient *eventstore.Even
 	logging.WithFields("name", step.String()).OnError(err).Fatal(errorMsg)
 }
 
+// readStmt reads a single file from the embedded FS,
+// under the folder/typ/filename path.
+// Typ describes the database dialect and may be omitted if no
+// dialect specific migration is specified.
 func readStmt(fs embed.FS, folder, typ, filename string) (string, error) {
-	stmt, err := fs.ReadFile(folder + "/" + typ + "/" + filename)
+	stmt, err := fs.ReadFile(filepath.Join(folder, typ, filename))
 	return string(stmt), err
 }
 
+type statement struct {
+	file  string
+	query string
+}
+
+// readStatements reads all files from the embedded FS,
+// under the folder/type path.
+// Typ describes the database dialect and may be omitted if no
+// dialect specific migration is specified.
+func readStatements(fs embed.FS, folder, typ string) ([]statement, error) {
+	basePath := filepath.Join(folder, typ)
+	dir, err := fs.ReadDir(basePath)
+	if err != nil {
+		return nil, err
+	}
+	statements := make([]statement, len(dir))
+	for i, file := range dir {
+		statements[i].file = file.Name()
+		statements[i].query, err = readStmt(fs, folder, typ, file.Name())
+		if err != nil {
+			return nil, err
+		}
+	}
+	return statements, nil
+}
+
 func initProjections(
 	ctx context.Context,
 	eventstoreClient *eventstore.Eventstore,
```

---

```diff
@@ -156,6 +156,7 @@ func QueryFromBuilder(builder *eventstore.SearchQueryBuilder) (*SearchQuery, err
 			aggregateIDFilter,
 			eventTypeFilter,
 			eventDataFilter,
+			eventPositionAfterFilter,
 		} {
 			filter := f(q)
 			if filter == nil {
@@ -275,3 +276,10 @@ func eventDataFilter(query *eventstore.SearchQuery) *Filter {
 	}
 	return NewFilter(FieldEventData, query.GetEventData(), OperationJSONContains)
 }
+
+func eventPositionAfterFilter(query *eventstore.SearchQuery) *Filter {
+	if pos := query.GetPositionAfter(); pos != 0 {
+		return NewFilter(FieldPosition, pos, OperationGreater)
+	}
+	return nil
+}
```

---

```diff
@@ -107,6 +107,7 @@ type SearchQuery struct {
 	aggregateIDs  []string
 	eventTypes    []EventType
 	eventData     map[string]interface{}
+	positionAfter float64
 }
 
 func (q SearchQuery) GetAggregateTypes() []AggregateType {
@@ -125,6 +126,10 @@ func (q SearchQuery) GetEventData() map[string]interface{} {
 	return q.eventData
 }
 
+func (q SearchQuery) GetPositionAfter() float64 {
+	return q.positionAfter
+}
+
 // Columns defines which fields of the event are needed for the query
 type Columns int8
 
@@ -344,6 +349,11 @@ func (query *SearchQuery) EventData(data map[string]interface{}) *SearchQuery {
 	return query
 }
 
+func (query *SearchQuery) PositionAfter(position float64) *SearchQuery {
+	query.positionAfter = position
+	return query
+}
+
 // Builder returns the SearchQueryBuilder of the sub query
 func (query *SearchQuery) Builder() *SearchQueryBuilder {
 	return query.builder
```

---

```diff
@@ -144,6 +144,9 @@ func (q *Queries) checkSessionNotTerminatedAfter(ctx context.Context, sessionID,
 	ctx, span := tracing.NewSpan(ctx)
 	defer func() { span.EndWithError(err) }()
 
+	if sessionID == "" && userID == "" && fingerprintID == "" {
+		return nil
+	}
 	model := &sessionTerminatedModel{
 		sessionID: sessionID,
 		position:  position,
@@ -181,33 +184,40 @@ func (s *sessionTerminatedModel) AppendEvents(events ...eventstore.Event) {
 }
 
 func (s *sessionTerminatedModel) Query() *eventstore.SearchQueryBuilder {
-	query := eventstore.NewSearchQueryBuilder(eventstore.ColumnsEvent).
-		PositionAfter(s.position).
-		AddQuery().
-		AggregateTypes(session.AggregateType).
-		AggregateIDs(s.sessionID).
-		EventTypes(
-			session.TerminateType,
-		).
-		Builder()
-	if s.userID == "" {
-		return query
+	builder := eventstore.NewSearchQueryBuilder(eventstore.ColumnsEvent)
+	if s.sessionID != "" {
+		builder = builder.AddQuery().
+			AggregateTypes(session.AggregateType).
+			AggregateIDs(s.sessionID).
+			EventTypes(
+				session.TerminateType,
+			).
+			PositionAfter(s.position).
+			Builder()
 	}
-	return query.
-		AddQuery().
-		AggregateTypes(user.AggregateType).
-		AggregateIDs(s.userID).
-		EventTypes(
-			user.UserDeactivatedType,
-			user.UserLockedType,
-			user.UserRemovedType,
-		).
-		Or(). // for specific logout on v1 sessions from the same user agent
-		AggregateTypes(user.AggregateType).
-		AggregateIDs(s.userID).
-		EventTypes(
-			user.HumanSignedOutType,
-		).
-		EventData(map[string]interface{}{"userAgentID": s.fingerPrintID}).
-		Builder()
+	if s.userID != "" {
+		builder = builder.AddQuery().
+			AggregateTypes(user.AggregateType).
+			AggregateIDs(s.userID).
+			EventTypes(
+				user.UserDeactivatedType,
+				user.UserLockedType,
+				user.UserRemovedType,
+			).
+			PositionAfter(s.position).
+			Builder()
+		if s.fingerPrintID != "" {
+			// for specific logout on v1 sessions from the same user agent
+			builder = builder.AddQuery().
+				AggregateTypes(user.AggregateType).
+				AggregateIDs(s.userID).
+				EventTypes(
+					user.HumanSignedOutType,
+				).
+				EventData(map[string]interface{}{"userAgentID": s.fingerPrintID}).
+				PositionAfter(s.position).
+				Builder()
+		}
+	}
+	return builder
 }
```