mirror of
https://github.com/zitadel/zitadel.git
synced 2025-05-02 15:00:59 +00:00
fix(auth): efficient user session projection (#7187)
* fix(auth): cache users during session projection * fix(auth.user_sessions): add index for more efficient by user search
This commit is contained in:
parent
039a1e793b
commit
43f1d59649
26
cmd/setup/20.go
Normal file
26
cmd/setup/20.go
Normal file
@ -0,0 +1,26 @@
|
|||||||
|
package setup
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
_ "embed"
|
||||||
|
|
||||||
|
"github.com/zitadel/zitadel/internal/database"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
//go:embed 20.sql
|
||||||
|
addByUserIndexToSession string
|
||||||
|
)
|
||||||
|
|
||||||
|
type AddByUserIndexToSession struct {
|
||||||
|
dbClient *database.DB
|
||||||
|
}
|
||||||
|
|
||||||
|
func (mig *AddByUserIndexToSession) Execute(ctx context.Context) error {
|
||||||
|
_, err := mig.dbClient.ExecContext(ctx, addByUserIndexToSession)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (mig *AddByUserIndexToSession) String() string {
|
||||||
|
return "20_add_by_user_index_on_session"
|
||||||
|
}
|
1
cmd/setup/20.sql
Normal file
1
cmd/setup/20.sql
Normal file
@ -0,0 +1 @@
|
|||||||
|
CREATE INDEX CONCURRENTLY IF NOT EXISTS user_sessions_by_user ON auth.user_sessions (instance_id, user_id);
|
@ -77,6 +77,7 @@ type Steps struct {
|
|||||||
s17AddOffsetToUniqueConstraints *AddOffsetToCurrentStates
|
s17AddOffsetToUniqueConstraints *AddOffsetToCurrentStates
|
||||||
s18AddLowerFieldsToLoginNames *AddLowerFieldsToLoginNames
|
s18AddLowerFieldsToLoginNames *AddLowerFieldsToLoginNames
|
||||||
s19AddCurrentStatesIndex *AddCurrentSequencesIndex
|
s19AddCurrentStatesIndex *AddCurrentSequencesIndex
|
||||||
|
s20AddByUserSessionIndex *AddByUserIndexToSession
|
||||||
}
|
}
|
||||||
|
|
||||||
type encryptionKeyConfig struct {
|
type encryptionKeyConfig struct {
|
||||||
|
@ -110,6 +110,7 @@ func Setup(config *Config, steps *Steps, masterKey string) {
|
|||||||
steps.s17AddOffsetToUniqueConstraints = &AddOffsetToCurrentStates{dbClient: queryDBClient}
|
steps.s17AddOffsetToUniqueConstraints = &AddOffsetToCurrentStates{dbClient: queryDBClient}
|
||||||
steps.s18AddLowerFieldsToLoginNames = &AddLowerFieldsToLoginNames{dbClient: queryDBClient}
|
steps.s18AddLowerFieldsToLoginNames = &AddLowerFieldsToLoginNames{dbClient: queryDBClient}
|
||||||
steps.s19AddCurrentStatesIndex = &AddCurrentSequencesIndex{dbClient: queryDBClient}
|
steps.s19AddCurrentStatesIndex = &AddCurrentSequencesIndex{dbClient: queryDBClient}
|
||||||
|
steps.s20AddByUserSessionIndex = &AddByUserIndexToSession{dbClient: queryDBClient}
|
||||||
|
|
||||||
err = projection.Create(ctx, projectionDBClient, eventstoreClient, config.Projections, nil, nil, nil)
|
err = projection.Create(ctx, projectionDBClient, eventstoreClient, config.Projections, nil, nil, nil)
|
||||||
logging.OnError(err).Fatal("unable to start projections")
|
logging.OnError(err).Fatal("unable to start projections")
|
||||||
@ -156,6 +157,8 @@ func Setup(config *Config, steps *Steps, masterKey string) {
|
|||||||
logging.WithFields("name", steps.s17AddOffsetToUniqueConstraints.String()).OnError(err).Fatal("migration failed")
|
logging.WithFields("name", steps.s17AddOffsetToUniqueConstraints.String()).OnError(err).Fatal("migration failed")
|
||||||
err = migration.Migrate(ctx, eventstoreClient, steps.s19AddCurrentStatesIndex)
|
err = migration.Migrate(ctx, eventstoreClient, steps.s19AddCurrentStatesIndex)
|
||||||
logging.WithFields("name", steps.s19AddCurrentStatesIndex.String()).OnError(err).Fatal("migration failed")
|
logging.WithFields("name", steps.s19AddCurrentStatesIndex.String()).OnError(err).Fatal("migration failed")
|
||||||
|
err = migration.Migrate(ctx, eventstoreClient, steps.s20AddByUserSessionIndex)
|
||||||
|
logging.WithFields("name", steps.s20AddByUserSessionIndex.String()).OnError(err).Fatal("migration failed")
|
||||||
|
|
||||||
for _, repeatableStep := range repeatableSteps {
|
for _, repeatableStep := range repeatableSteps {
|
||||||
err = migration.Migrate(ctx, eventstoreClient, repeatableStep)
|
err = migration.Migrate(ctx, eventstoreClient, repeatableStep)
|
||||||
|
@ -296,11 +296,24 @@ func (u *UserSession) Reduce(event eventstore.Event) (_ *handler.Statement, err
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (u *UserSession) appendEventOnSessions(sessions []*view_model.UserSessionView, event eventstore.Event) error {
|
func (u *UserSession) appendEventOnSessions(sessions []*view_model.UserSessionView, event eventstore.Event) error {
|
||||||
|
users := make(map[string]*view_model.UserView)
|
||||||
|
usersByID := func(userID, instanceID string) (user *view_model.UserView, err error) {
|
||||||
|
user, ok := users[userID+"-"+instanceID]
|
||||||
|
if ok {
|
||||||
|
return user, nil
|
||||||
|
}
|
||||||
|
users[userID+"-"+instanceID], err = u.view.UserByID(userID, instanceID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return users[userID+"-"+instanceID], nil
|
||||||
|
}
|
||||||
for _, session := range sessions {
|
for _, session := range sessions {
|
||||||
if err := session.AppendEvent(event); err != nil {
|
if err := session.AppendEvent(event); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if err := u.fillUserInfo(session); err != nil {
|
if err := u.fillUserInfo(session, usersByID); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -311,7 +324,7 @@ func (u *UserSession) updateSession(session *view_model.UserSessionView, event e
|
|||||||
if err := session.AppendEvent(event); err != nil {
|
if err := session.AppendEvent(event); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if err := u.fillUserInfo(session); err != nil {
|
if err := u.fillUserInfo(session, u.view.UserByID); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if err := u.view.PutUserSession(session); err != nil {
|
if err := u.view.PutUserSession(session); err != nil {
|
||||||
@ -320,8 +333,8 @@ func (u *UserSession) updateSession(session *view_model.UserSessionView, event e
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (u *UserSession) fillUserInfo(session *view_model.UserSessionView) error {
|
func (u *UserSession) fillUserInfo(session *view_model.UserSessionView, getUserByID func(userID, instanceID string) (*view_model.UserView, error)) error {
|
||||||
user, err := u.view.UserByID(session.UserID, session.InstanceID)
|
user, err := getUserByID(session.UserID, session.InstanceID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -0,0 +1,94 @@
|
|||||||
|
package handler
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
view_model "github.com/zitadel/zitadel/internal/user/repository/view/model"
|
||||||
|
)
|
||||||
|
|
||||||
|
// tests the proper working of the cache function
|
||||||
|
func TestUserSession_fillUserInfo(t *testing.T) {
|
||||||
|
type args struct {
|
||||||
|
sessions []*view_model.UserSessionView
|
||||||
|
}
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
args args
|
||||||
|
cacheHits map[string]int
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "one session",
|
||||||
|
args: args{
|
||||||
|
sessions: []*view_model.UserSessionView{
|
||||||
|
{
|
||||||
|
UserID: "user",
|
||||||
|
InstanceID: "instance",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
cacheHits: map[string]int{
|
||||||
|
"user-instance": 1,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "same user",
|
||||||
|
args: args{
|
||||||
|
sessions: []*view_model.UserSessionView{
|
||||||
|
{
|
||||||
|
UserID: "user",
|
||||||
|
InstanceID: "instance",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
UserID: "user",
|
||||||
|
InstanceID: "instance",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
cacheHits: map[string]int{
|
||||||
|
"user-instance": 2,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "different users",
|
||||||
|
args: args{
|
||||||
|
sessions: []*view_model.UserSessionView{
|
||||||
|
{
|
||||||
|
UserID: "user",
|
||||||
|
InstanceID: "instance",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
UserID: "user2",
|
||||||
|
InstanceID: "instance",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
cacheHits: map[string]int{
|
||||||
|
"user-instance": 1,
|
||||||
|
"user2-instance": 1,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
cache := map[string]int{}
|
||||||
|
getUserByID := func(userID, instanceID string) (*view_model.UserView, error) {
|
||||||
|
cache[userID+"-"+instanceID]++
|
||||||
|
return &view_model.UserView{HumanView: &view_model.HumanView{}}, nil
|
||||||
|
}
|
||||||
|
for _, session := range tt.args.sessions {
|
||||||
|
if err := new(UserSession).fillUserInfo(session, getUserByID); err != nil {
|
||||||
|
t.Errorf("UserSession.fillUserInfo() unexpected error = %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if len(cache) != len(tt.cacheHits) {
|
||||||
|
t.Errorf("unexpected length of cache hits: want %d, got %d", len(tt.cacheHits), len(cache))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
for key, count := range tt.cacheHits {
|
||||||
|
if cache[key] != count {
|
||||||
|
t.Errorf("unexpected cache hits on %s: want %d, got %d", key, count, cache[key])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
Loading…
x
Reference in New Issue
Block a user