mirror of
https://github.com/zitadel/zitadel.git
synced 2025-06-03 17:39:55 +00:00
perf(oidc): nest position clause for session terminated query (#8738)
# Which Problems Are Solved Optimize the query that checks for terminated sessions in the access token verifier. The verifier is used in auth middleware, userinfo and introspection. # How the Problems Are Solved The previous implementation built a query for certain events and then appended a single `PositionAfter` clause. This caused the postgreSQL planner to use indexes only for the instance ID, aggregate IDs, aggregate types and event types. Followed by an expensive sequential scan for the position. This resulting in internal over-fetching of rows before the final filter was applied.  Furthermore, the query was searching for events which are not always applicable. For example, there was always a session ID search and if there was a user ID, we would also search for a browser fingerprint in event payload (expensive). Even if those argument string would be empty. This PR changes: 1. Nest the position query, so that a full `instance_id, aggregate_id, aggregate_type, event_type, "position"` index can be matched. 2. Redefine the `es_wm` index to include the `position` column. 3. Only search for events for the IDs that actually have a value. Do not search (noop) if none of session ID, user ID or fingerpint ID are set. New query plan:  # Additional Changes - cleanup how we load multi-statement migrations and make that a bit more reusable. # Additional Context - Related to https://github.com/zitadel/zitadel/issues/7639
This commit is contained in:
parent
2bd3f44094
commit
a84b259e8c
@ -25,15 +25,11 @@ type NewEventsTable struct {
|
||||
}
|
||||
|
||||
func (mig *NewEventsTable) Execute(ctx context.Context, _ eventstore.Event) error {
|
||||
migrations, err := newEventsTable.ReadDir("14/" + mig.dbClient.Type())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// if events already exists events2 is created during a setup job
|
||||
var count int
|
||||
err = mig.dbClient.QueryRow(
|
||||
err := mig.dbClient.QueryRowContext(ctx,
|
||||
func(row *sql.Row) error {
|
||||
if err = row.Scan(&count); err != nil {
|
||||
if err := row.Scan(&count); err != nil {
|
||||
return err
|
||||
}
|
||||
return row.Err()
|
||||
@ -43,16 +39,15 @@ func (mig *NewEventsTable) Execute(ctx context.Context, _ eventstore.Event) erro
|
||||
if err != nil || count == 1 {
|
||||
return err
|
||||
}
|
||||
for _, migration := range migrations {
|
||||
stmt, err := readStmt(newEventsTable, "14", mig.dbClient.Type(), migration.Name())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
stmt = strings.ReplaceAll(stmt, "{{.username}}", mig.dbClient.Username())
|
||||
|
||||
logging.WithFields("migration", mig.String(), "file", migration.Name()).Debug("execute statement")
|
||||
|
||||
_, err = mig.dbClient.ExecContext(ctx, stmt)
|
||||
statements, err := readStatements(newEventsTable, "14", mig.dbClient.Type())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
for _, stmt := range statements {
|
||||
stmt.query = strings.ReplaceAll(stmt.query, "{{.username}}", mig.dbClient.Username())
|
||||
logging.WithFields("file", stmt.file, "migration", mig.String()).Info("execute statement")
|
||||
_, err = mig.dbClient.ExecContext(ctx, stmt.query)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -21,19 +21,13 @@ type CurrentProjectionState struct {
|
||||
}
|
||||
|
||||
func (mig *CurrentProjectionState) Execute(ctx context.Context, _ eventstore.Event) error {
|
||||
migrations, err := currentProjectionState.ReadDir("15/" + mig.dbClient.Type())
|
||||
statements, err := readStatements(currentProjectionState, "15", mig.dbClient.Type())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
for _, migration := range migrations {
|
||||
stmt, err := readStmt(currentProjectionState, "15", mig.dbClient.Type(), migration.Name())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
logging.WithFields("file", migration.Name(), "migration", mig.String()).Info("execute statement")
|
||||
|
||||
_, err = mig.dbClient.ExecContext(ctx, stmt)
|
||||
for _, stmt := range statements {
|
||||
logging.WithFields("file", stmt.file, "migration", mig.String()).Info("execute statement")
|
||||
_, err = mig.dbClient.ExecContext(ctx, stmt.query)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
39
cmd/setup/35.go
Normal file
39
cmd/setup/35.go
Normal file
@ -0,0 +1,39 @@
|
||||
package setup
|
||||
|
||||
import (
|
||||
"context"
|
||||
"embed"
|
||||
"fmt"
|
||||
|
||||
"github.com/zitadel/logging"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/database"
|
||||
"github.com/zitadel/zitadel/internal/eventstore"
|
||||
)
|
||||
|
||||
var (
|
||||
//go:embed 35/*.sql
|
||||
addPositionToEsWmIndex embed.FS
|
||||
)
|
||||
|
||||
type AddPositionToIndexEsWm struct {
|
||||
dbClient *database.DB
|
||||
}
|
||||
|
||||
func (mig *AddPositionToIndexEsWm) Execute(ctx context.Context, _ eventstore.Event) error {
|
||||
statements, err := readStatements(addPositionToEsWmIndex, "35", "")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
for _, stmt := range statements {
|
||||
logging.WithFields("file", stmt.file, "migration", mig.String()).Info("execute statement")
|
||||
if _, err := mig.dbClient.ExecContext(ctx, stmt.query); err != nil {
|
||||
return fmt.Errorf("%s %s: %w", mig.String(), stmt.file, err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (mig *AddPositionToIndexEsWm) String() string {
|
||||
return "35_add_position_to_index_es_wm"
|
||||
}
|
2
cmd/setup/35/00_create_index.sql
Normal file
2
cmd/setup/35/00_create_index.sql
Normal file
@ -0,0 +1,2 @@
|
||||
CREATE INDEX CONCURRENTLY IF NOT EXISTS es_wm_temp
|
||||
ON eventstore.events2 (instance_id, aggregate_id, aggregate_type, event_type, "position");
|
1
cmd/setup/35/01_drop_index.sql
Normal file
1
cmd/setup/35/01_drop_index.sql
Normal file
@ -0,0 +1 @@
|
||||
DROP INDEX IF EXISTS eventstore.es_wm;
|
1
cmd/setup/35/02_alter_index.sql
Normal file
1
cmd/setup/35/02_alter_index.sql
Normal file
@ -0,0 +1 @@
|
||||
ALTER INDEX eventstore.es_wm_temp RENAME TO es_wm;
|
@ -121,6 +121,7 @@ type Steps struct {
|
||||
s32AddAuthSessionID *AddAuthSessionID
|
||||
s33SMSConfigs3TwilioAddVerifyServiceSid *SMSConfigs3TwilioAddVerifyServiceSid
|
||||
s34AddCacheSchema *AddCacheSchema
|
||||
s35AddPositionToIndexEsWm *AddPositionToIndexEsWm
|
||||
}
|
||||
|
||||
func MustNewSteps(v *viper.Viper) *Steps {
|
||||
|
@ -5,6 +5,7 @@ import (
|
||||
"embed"
|
||||
_ "embed"
|
||||
"net/http"
|
||||
"path/filepath"
|
||||
|
||||
"github.com/spf13/cobra"
|
||||
"github.com/spf13/viper"
|
||||
@ -163,6 +164,7 @@ func Setup(ctx context.Context, config *Config, steps *Steps, masterKey string)
|
||||
steps.s32AddAuthSessionID = &AddAuthSessionID{dbClient: esPusherDBClient}
|
||||
steps.s33SMSConfigs3TwilioAddVerifyServiceSid = &SMSConfigs3TwilioAddVerifyServiceSid{dbClient: esPusherDBClient}
|
||||
steps.s34AddCacheSchema = &AddCacheSchema{dbClient: queryDBClient}
|
||||
steps.s35AddPositionToIndexEsWm = &AddPositionToIndexEsWm{dbClient: esPusherDBClient}
|
||||
|
||||
err = projection.Create(ctx, projectionDBClient, eventstoreClient, config.Projections, nil, nil, nil)
|
||||
logging.OnError(err).Fatal("unable to start projections")
|
||||
@ -206,6 +208,7 @@ func Setup(ctx context.Context, config *Config, steps *Steps, masterKey string)
|
||||
steps.s29FillFieldsForProjectGrant,
|
||||
steps.s30FillFieldsForOrgDomainVerified,
|
||||
steps.s34AddCacheSchema,
|
||||
steps.s35AddPositionToIndexEsWm,
|
||||
} {
|
||||
mustExecuteMigration(ctx, eventstoreClient, step, "migration failed")
|
||||
}
|
||||
@ -245,11 +248,41 @@ func mustExecuteMigration(ctx context.Context, eventstoreClient *eventstore.Even
|
||||
logging.WithFields("name", step.String()).OnError(err).Fatal(errorMsg)
|
||||
}
|
||||
|
||||
// readStmt reads a single file from the embedded FS,
|
||||
// under the folder/typ/filename path.
|
||||
// Typ describes the database dialect and may be omitted if no
|
||||
// dialect specific migration is specified.
|
||||
func readStmt(fs embed.FS, folder, typ, filename string) (string, error) {
|
||||
stmt, err := fs.ReadFile(folder + "/" + typ + "/" + filename)
|
||||
stmt, err := fs.ReadFile(filepath.Join(folder, typ, filename))
|
||||
return string(stmt), err
|
||||
}
|
||||
|
||||
type statement struct {
|
||||
file string
|
||||
query string
|
||||
}
|
||||
|
||||
// readStatements reads all files from the embedded FS,
|
||||
// under the folder/type path.
|
||||
// Typ describes the database dialect and may be omitted if no
|
||||
// dialect specific migration is specified.
|
||||
func readStatements(fs embed.FS, folder, typ string) ([]statement, error) {
|
||||
basePath := filepath.Join(folder, typ)
|
||||
dir, err := fs.ReadDir(basePath)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
statements := make([]statement, len(dir))
|
||||
for i, file := range dir {
|
||||
statements[i].file = file.Name()
|
||||
statements[i].query, err = readStmt(fs, folder, typ, file.Name())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
return statements, nil
|
||||
}
|
||||
|
||||
func initProjections(
|
||||
ctx context.Context,
|
||||
eventstoreClient *eventstore.Eventstore,
|
||||
|
@ -156,6 +156,7 @@ func QueryFromBuilder(builder *eventstore.SearchQueryBuilder) (*SearchQuery, err
|
||||
aggregateIDFilter,
|
||||
eventTypeFilter,
|
||||
eventDataFilter,
|
||||
eventPositionAfterFilter,
|
||||
} {
|
||||
filter := f(q)
|
||||
if filter == nil {
|
||||
@ -275,3 +276,10 @@ func eventDataFilter(query *eventstore.SearchQuery) *Filter {
|
||||
}
|
||||
return NewFilter(FieldEventData, query.GetEventData(), OperationJSONContains)
|
||||
}
|
||||
|
||||
func eventPositionAfterFilter(query *eventstore.SearchQuery) *Filter {
|
||||
if pos := query.GetPositionAfter(); pos != 0 {
|
||||
return NewFilter(FieldPosition, pos, OperationGreater)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
@ -107,6 +107,7 @@ type SearchQuery struct {
|
||||
aggregateIDs []string
|
||||
eventTypes []EventType
|
||||
eventData map[string]interface{}
|
||||
positionAfter float64
|
||||
}
|
||||
|
||||
func (q SearchQuery) GetAggregateTypes() []AggregateType {
|
||||
@ -125,6 +126,10 @@ func (q SearchQuery) GetEventData() map[string]interface{} {
|
||||
return q.eventData
|
||||
}
|
||||
|
||||
func (q SearchQuery) GetPositionAfter() float64 {
|
||||
return q.positionAfter
|
||||
}
|
||||
|
||||
// Columns defines which fields of the event are needed for the query
|
||||
type Columns int8
|
||||
|
||||
@ -344,6 +349,11 @@ func (query *SearchQuery) EventData(data map[string]interface{}) *SearchQuery {
|
||||
return query
|
||||
}
|
||||
|
||||
func (query *SearchQuery) PositionAfter(position float64) *SearchQuery {
|
||||
query.positionAfter = position
|
||||
return query
|
||||
}
|
||||
|
||||
// Builder returns the SearchQueryBuilder of the sub query
|
||||
func (query *SearchQuery) Builder() *SearchQueryBuilder {
|
||||
return query.builder
|
||||
|
@ -144,6 +144,9 @@ func (q *Queries) checkSessionNotTerminatedAfter(ctx context.Context, sessionID,
|
||||
ctx, span := tracing.NewSpan(ctx)
|
||||
defer func() { span.EndWithError(err) }()
|
||||
|
||||
if sessionID == "" && userID == "" && fingerprintID == "" {
|
||||
return nil
|
||||
}
|
||||
model := &sessionTerminatedModel{
|
||||
sessionID: sessionID,
|
||||
position: position,
|
||||
@ -181,33 +184,40 @@ func (s *sessionTerminatedModel) AppendEvents(events ...eventstore.Event) {
|
||||
}
|
||||
|
||||
func (s *sessionTerminatedModel) Query() *eventstore.SearchQueryBuilder {
|
||||
query := eventstore.NewSearchQueryBuilder(eventstore.ColumnsEvent).
|
||||
PositionAfter(s.position).
|
||||
AddQuery().
|
||||
AggregateTypes(session.AggregateType).
|
||||
AggregateIDs(s.sessionID).
|
||||
EventTypes(
|
||||
session.TerminateType,
|
||||
).
|
||||
Builder()
|
||||
if s.userID == "" {
|
||||
return query
|
||||
builder := eventstore.NewSearchQueryBuilder(eventstore.ColumnsEvent)
|
||||
if s.sessionID != "" {
|
||||
builder = builder.AddQuery().
|
||||
AggregateTypes(session.AggregateType).
|
||||
AggregateIDs(s.sessionID).
|
||||
EventTypes(
|
||||
session.TerminateType,
|
||||
).
|
||||
PositionAfter(s.position).
|
||||
Builder()
|
||||
}
|
||||
return query.
|
||||
AddQuery().
|
||||
AggregateTypes(user.AggregateType).
|
||||
AggregateIDs(s.userID).
|
||||
EventTypes(
|
||||
user.UserDeactivatedType,
|
||||
user.UserLockedType,
|
||||
user.UserRemovedType,
|
||||
).
|
||||
Or(). // for specific logout on v1 sessions from the same user agent
|
||||
AggregateTypes(user.AggregateType).
|
||||
AggregateIDs(s.userID).
|
||||
EventTypes(
|
||||
user.HumanSignedOutType,
|
||||
).
|
||||
EventData(map[string]interface{}{"userAgentID": s.fingerPrintID}).
|
||||
Builder()
|
||||
if s.userID != "" {
|
||||
builder = builder.AddQuery().
|
||||
AggregateTypes(user.AggregateType).
|
||||
AggregateIDs(s.userID).
|
||||
EventTypes(
|
||||
user.UserDeactivatedType,
|
||||
user.UserLockedType,
|
||||
user.UserRemovedType,
|
||||
).
|
||||
PositionAfter(s.position).
|
||||
Builder()
|
||||
if s.fingerPrintID != "" {
|
||||
// for specific logout on v1 sessions from the same user agent
|
||||
builder = builder.AddQuery().
|
||||
AggregateTypes(user.AggregateType).
|
||||
AggregateIDs(s.userID).
|
||||
EventTypes(
|
||||
user.HumanSignedOutType,
|
||||
).
|
||||
EventData(map[string]interface{}{"userAgentID": s.fingerPrintID}).
|
||||
PositionAfter(s.position).
|
||||
Builder()
|
||||
}
|
||||
}
|
||||
return builder
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user