From aa407c3c3ecab31014f75551d39ef66ecbc9d3f5 Mon Sep 17 00:00:00 2001 From: Silvan Date: Tue, 30 Jan 2024 16:17:54 +0100 Subject: [PATCH] fix(auth): optimise user sessions (#7199) * fix(auth): start optimise user sessions * reduce and query user sessions directly without gorm statements * cleanup * cleanup * fix requested changes --------- Co-authored-by: Livio Spring --- .../repository/eventsourcing/handler/token.go | 3 +- .../eventsourcing/handler/user_session.go | 304 +++++++----------- .../handler/user_session_test.go | 94 ------ .../eventsourcing/view/user_session.go | 20 +- .../repository/view/model/user_session.go | 43 +-- .../repository/view/user_session_by_id.sql | 29 ++ .../user/repository/view/user_session_view.go | 209 ++++++------ .../view/user_sessions_by_user_agent.sql | 27 ++ 8 files changed, 292 insertions(+), 437 deletions(-) delete mode 100644 internal/auth/repository/eventsourcing/handler/user_session_test.go create mode 100644 internal/user/repository/view/user_session_by_id.sql create mode 100644 internal/user/repository/view/user_sessions_by_user_agent.sql diff --git a/internal/auth/repository/eventsourcing/handler/token.go b/internal/auth/repository/eventsourcing/handler/token.go index 14641cacdc..4b0a7194de 100644 --- a/internal/auth/repository/eventsourcing/handler/token.go +++ b/internal/auth/repository/eventsourcing/handler/token.go @@ -244,7 +244,8 @@ func agentIDFromSession(event eventstore.Event) (string, error) { logging.WithError(err).Error("could not unmarshal event data") return "", zerrors.ThrowInternal(nil, "MODEL-sd325", "could not unmarshal data") } - return session["userAgentID"].(string), nil + agentID, _ := session["userAgentID"].(string) + return agentID, nil } func applicationFromSession(event eventstore.Event) (*project_es_model.Application, error) { diff --git a/internal/auth/repository/eventsourcing/handler/user_session.go b/internal/auth/repository/eventsourcing/handler/user_session.go index dfc2774c89..12cb98f9db 100644 --- a/internal/auth/repository/eventsourcing/handler/user_session.go +++ b/internal/auth/repository/eventsourcing/handler/user_session.go @@ -2,16 +2,12 @@ package handler import ( "context" + "time" - "github.com/zitadel/zitadel/internal/api/authz" auth_view "github.com/zitadel/zitadel/internal/auth/repository/eventsourcing/view" "github.com/zitadel/zitadel/internal/domain" "github.com/zitadel/zitadel/internal/eventstore" "github.com/zitadel/zitadel/internal/eventstore/handler/v2" - "github.com/zitadel/zitadel/internal/eventstore/v1/models" - org_model "github.com/zitadel/zitadel/internal/org/model" - org_es_model "github.com/zitadel/zitadel/internal/org/repository/eventsourcing/model" - org_view "github.com/zitadel/zitadel/internal/org/repository/view" query2 "github.com/zitadel/zitadel/internal/query" "github.com/zitadel/zitadel/internal/repository/instance" "github.com/zitadel/zitadel/internal/repository/org" @@ -128,10 +124,6 @@ func (s *UserSession) Reducers() []handler.AggregateReducer { Event: user.UserV1MFAOTPRemovedType, Reduce: s.Reduce, }, - { - Event: user.UserV1ProfileChangedType, - Reduce: s.Reduce, - }, { Event: user.UserLockedType, Reduce: s.Reduce, @@ -148,26 +140,6 @@ func (s *UserSession) Reducers() []handler.AggregateReducer { Event: user.HumanMFAOTPRemovedType, Reduce: s.Reduce, }, - { - Event: user.HumanProfileChangedType, - Reduce: s.Reduce, - }, - { - Event: user.HumanAvatarAddedType, - Reduce: s.Reduce, - }, - { - Event: user.HumanAvatarRemovedType, - Reduce: s.Reduce, - }, - { - Event: user.UserDomainClaimedType, - Reduce: s.Reduce, - }, - { - Event: user.UserUserNameChangedType, - Reduce: s.Reduce, - }, { Event: user.UserIDPLinkRemovedType, Reduce: s.Reduce, @@ -193,10 +165,6 @@ func (s *UserSession) Reducers() []handler.AggregateReducer { { Aggregate: org.AggregateType, EventReducers: []handler.EventReducer{ - { - Event: org.OrgDomainPrimarySetEventType, - Reduce: s.Reduce, - }, { Event: org.OrgRemovedEventType, Reduce: s.Reduce, @@ -216,25 +184,24 @@ func (s *UserSession) Reducers() []handler.AggregateReducer { } func (u *UserSession) Reduce(event eventstore.Event) (_ *handler.Statement, err error) { - return handler.NewStatement(event, func(ex handler.Executer, projectionName string) error { - var session *view_model.UserSessionView - switch event.Type() { - case user.UserV1PasswordCheckSucceededType, - user.UserV1PasswordCheckFailedType, - user.UserV1MFAOTPCheckSucceededType, - user.UserV1MFAOTPCheckFailedType, - user.UserV1SignedOutType, - user.HumanPasswordCheckSucceededType, - user.HumanPasswordCheckFailedType, - user.UserIDPLoginCheckSucceededType, - user.HumanMFAOTPCheckSucceededType, - user.HumanMFAOTPCheckFailedType, - user.HumanU2FTokenCheckSucceededType, - user.HumanU2FTokenCheckFailedType, - user.HumanPasswordlessTokenCheckSucceededType, - user.HumanPasswordlessTokenCheckFailedType, - user.HumanSignedOutType: - + switch event.Type() { + case user.UserV1PasswordCheckSucceededType, + user.UserV1PasswordCheckFailedType, + user.UserV1MFAOTPCheckSucceededType, + user.UserV1MFAOTPCheckFailedType, + user.UserV1SignedOutType, + user.HumanPasswordCheckSucceededType, + user.HumanPasswordCheckFailedType, + user.UserIDPLoginCheckSucceededType, + user.HumanMFAOTPCheckSucceededType, + user.HumanMFAOTPCheckFailedType, + user.HumanU2FTokenCheckSucceededType, + user.HumanU2FTokenCheckFailedType, + user.HumanPasswordlessTokenCheckSucceededType, + user.HumanPasswordlessTokenCheckFailedType, + user.HumanSignedOutType: + return handler.NewStatement(event, func(ex handler.Executer, projectionName string) error { + var session *view_model.UserSessionView eventData, err := view_model.UserSessionFromEvent(event) if err != nil { return err @@ -254,158 +221,103 @@ func (u *UserSession) Reduce(event eventstore.Event) (_ *handler.Statement, err } } return u.updateSession(session, event) - case user.UserV1PasswordChangedType, - user.UserV1MFAOTPRemovedType, - user.UserV1ProfileChangedType, - user.UserLockedType, - user.UserDeactivatedType, - user.HumanPasswordChangedType, - user.HumanMFAOTPRemovedType, - user.HumanProfileChangedType, - user.HumanAvatarAddedType, - user.HumanAvatarRemovedType, - user.UserDomainClaimedType, - user.UserUserNameChangedType, - user.UserIDPLinkRemovedType, - user.UserIDPLinkCascadeRemovedType, - user.HumanPasswordlessTokenRemovedType, - user.HumanU2FTokenRemovedType: - sessions, err := u.view.UserSessionsByUserID(event.Aggregate().ID, event.Aggregate().InstanceID) - if err != nil || len(sessions) == 0 { - return err - } - if err = u.appendEventOnSessions(sessions, event); err != nil { - return err - } - if err = u.view.PutUserSessions(sessions); err != nil { - return err - } - return nil - case org.OrgDomainPrimarySetEventType: - return u.fillLoginNamesOnOrgUsers(event) - case user.UserRemovedType: - return u.view.DeleteUserSessions(event.Aggregate().ID, event.Aggregate().InstanceID) - case instance.InstanceRemovedEventType: - return u.view.DeleteInstanceUserSessions(event.Aggregate().InstanceID) - case org.OrgRemovedEventType: - return u.view.DeleteOrgUserSessions(event) - default: - return nil - } - }), nil -} - -func (u *UserSession) appendEventOnSessions(sessions []*view_model.UserSessionView, event eventstore.Event) error { - users := make(map[string]*view_model.UserView) - usersByID := func(userID, instanceID string) (user *view_model.UserView, err error) { - user, ok := users[userID+"-"+instanceID] - if ok { - return user, nil - } - users[userID+"-"+instanceID], err = u.view.UserByID(userID, instanceID) + }), nil + case user.UserLockedType, + user.UserDeactivatedType: + return handler.NewUpdateStatement(event, + []handler.Column{ + handler.NewCol("passwordless_verification", time.Time{}), + handler.NewCol("password_verification", time.Time{}), + handler.NewCol("second_factor_verification", time.Time{}), + handler.NewCol("second_factor_verification_type", domain.MFALevelNotSetUp), + handler.NewCol("multi_factor_verification", time.Time{}), + handler.NewCol("multi_factor_verification_type", domain.MFALevelNotSetUp), + handler.NewCol("external_login_verification", time.Time{}), + handler.NewCol("state", domain.UserSessionStateTerminated), + }, + []handler.Condition{ + handler.NewCond("instance_id", event.Aggregate().InstanceID), + handler.NewCond("user_id", event.Aggregate().ID), + handler.Not(handler.NewCond("state", domain.UserSessionStateTerminated)), + }, + ), nil + case user.UserV1PasswordChangedType, + user.HumanPasswordChangedType: + userAgent, err := agentIDFromSession(event) if err != nil { return nil, err } - - return users[userID+"-"+instanceID], nil + return handler.NewUpdateStatement(event, + []handler.Column{ + handler.NewCol("password_verification", time.Time{}), + }, + []handler.Condition{ + handler.NewCond("instance_id", event.Aggregate().InstanceID), + handler.NewCond("user_id", event.Aggregate().ID), + handler.Not(handler.NewCond("user_agent_id", userAgent)), + handler.Not(handler.NewCond("state", domain.UserSessionStateTerminated)), + }, + ), nil + case user.UserV1MFAOTPRemovedType, + user.HumanMFAOTPRemovedType, + user.HumanU2FTokenRemovedType: + return handler.NewUpdateStatement(event, + []handler.Column{ + handler.NewCol("second_factor_verification", time.Time{}), + }, + []handler.Condition{ + handler.NewCond("instance_id", event.Aggregate().InstanceID), + handler.NewCond("user_id", event.Aggregate().ID), + handler.Not(handler.NewCond("state", domain.UserSessionStateTerminated)), + }, + ), nil + case user.UserIDPLinkRemovedType, + user.UserIDPLinkCascadeRemovedType: + return handler.NewUpdateStatement(event, + []handler.Column{ + handler.NewCol("external_login_verification", time.Time{}), + handler.NewCol("selected_idp_config_id", ""), + }, + []handler.Condition{ + handler.NewCond("instance_id", event.Aggregate().InstanceID), + handler.NewCond("user_id", event.Aggregate().ID), + handler.Not(handler.NewCond("selected_idp_config_id", "")), + }, + ), nil + case user.HumanPasswordlessTokenRemovedType: + return handler.NewUpdateStatement(event, + []handler.Column{ + handler.NewCol("passwordless_verification", time.Time{}), + handler.NewCol("multi_factor_verification", time.Time{}), + }, + []handler.Condition{ + handler.NewCond("instance_id", event.Aggregate().InstanceID), + handler.NewCond("user_id", event.Aggregate().ID), + handler.Not(handler.NewCond("state", domain.UserSessionStateTerminated)), + }, + ), nil + case user.UserRemovedType: + return handler.NewStatement(event, func(ex handler.Executer, projectionName string) error { + return u.view.DeleteUserSessions(event.Aggregate().ID, event.Aggregate().InstanceID) + }), nil + case instance.InstanceRemovedEventType: + return handler.NewStatement(event, func(ex handler.Executer, projectionName string) error { + return u.view.DeleteInstanceUserSessions(event.Aggregate().InstanceID) + }), nil + case org.OrgRemovedEventType: + return handler.NewStatement(event, func(ex handler.Executer, projectionName string) error { + return u.view.DeleteOrgUserSessions(event) + }), nil + default: + return handler.NewStatement(event, func(ex handler.Executer, projectionName string) error { + return nil + }), nil } - for _, session := range sessions { - if err := session.AppendEvent(event); err != nil { - return err - } - if err := u.fillUserInfo(session, usersByID); err != nil { - return err - } - } - return nil } func (u *UserSession) updateSession(session *view_model.UserSessionView, event eventstore.Event) error { if err := session.AppendEvent(event); err != nil { return err } - if err := u.fillUserInfo(session, u.view.UserByID); err != nil { - return err - } - if err := u.view.PutUserSession(session); err != nil { - return err - } - return nil -} - -func (u *UserSession) fillUserInfo(session *view_model.UserSessionView, getUserByID func(userID, instanceID string) (*view_model.UserView, error)) error { - user, err := getUserByID(session.UserID, session.InstanceID) - if err != nil { - return err - } - session.UserName = user.UserName - session.LoginName = user.PreferredLoginName - session.DisplayName = user.DisplayName - session.AvatarKey = user.AvatarKey - return nil -} - -func (u *UserSession) fillLoginNamesOnOrgUsers(event eventstore.Event) error { - sessions, err := u.view.UserSessionsByOrgID(event.Aggregate().ResourceOwner, event.Aggregate().InstanceID) - if err != nil { - return err - } - if len(sessions) == 0 { - return nil - } - userLoginMustBeDomain, primaryDomain, err := u.loginNameInformation(context.Background(), event.Aggregate().ResourceOwner, event.Aggregate().InstanceID) - if err != nil { - return err - } - if !userLoginMustBeDomain { - return nil - } - for _, session := range sessions { - session.LoginName = session.UserName + "@" + primaryDomain - } - return u.view.PutUserSessions(sessions) -} - -func (u *UserSession) loginNameInformation(ctx context.Context, orgID string, instanceID string) (userLoginMustBeDomain bool, primaryDomain string, err error) { - org, err := u.getOrgByID(ctx, orgID, instanceID) - if err != nil { - return false, "", err - } - primaryDomain, err = org.GetPrimaryDomain() - if err != nil { - return false, "", err - } - if org.DomainPolicy != nil { - return org.DomainPolicy.UserLoginMustBeDomain, primaryDomain, nil - } - policy, err := u.queries.DefaultDomainPolicy(authz.WithInstanceID(ctx, org.InstanceID)) - if err != nil { - return false, "", err - } - return policy.UserLoginMustBeDomain, primaryDomain, nil -} - -func (u *UserSession) getOrgByID(ctx context.Context, orgID, instanceID string) (*org_model.Org, error) { - query, err := org_view.OrgByIDQuery(orgID, instanceID, 0) - if err != nil { - return nil, err - } - - esOrg := &org_es_model.Org{ - ObjectRoot: models.ObjectRoot{ - AggregateID: orgID, - }, - } - events, err := u.es.Filter(ctx, query) - if err != nil { - return nil, err - } - if err = esOrg.AppendEvents(events...); err != nil { - return nil, err - } - - if esOrg.Sequence == 0 { - return nil, zerrors.ThrowNotFound(nil, "EVENT-3m9vs", "Errors.Org.NotFound") - } - return org_es_model.OrgToModel(esOrg), nil + return u.view.PutUserSession(session) } diff --git a/internal/auth/repository/eventsourcing/handler/user_session_test.go b/internal/auth/repository/eventsourcing/handler/user_session_test.go deleted file mode 100644 index cafb3b4718..0000000000 --- a/internal/auth/repository/eventsourcing/handler/user_session_test.go +++ /dev/null @@ -1,94 +0,0 @@ -package handler - -import ( - "testing" - - view_model "github.com/zitadel/zitadel/internal/user/repository/view/model" -) - -// tests the proper working of the cache function -func TestUserSession_fillUserInfo(t *testing.T) { - type args struct { - sessions []*view_model.UserSessionView - } - tests := []struct { - name string - args args - cacheHits map[string]int - }{ - { - name: "one session", - args: args{ - sessions: []*view_model.UserSessionView{ - { - UserID: "user", - InstanceID: "instance", - }, - }, - }, - cacheHits: map[string]int{ - "user-instance": 1, - }, - }, - { - name: "same user", - args: args{ - sessions: []*view_model.UserSessionView{ - { - UserID: "user", - InstanceID: "instance", - }, - { - UserID: "user", - InstanceID: "instance", - }, - }, - }, - cacheHits: map[string]int{ - "user-instance": 2, - }, - }, - { - name: "different users", - args: args{ - sessions: []*view_model.UserSessionView{ - { - UserID: "user", - InstanceID: "instance", - }, - { - UserID: "user2", - InstanceID: "instance", - }, - }, - }, - cacheHits: map[string]int{ - "user-instance": 1, - "user2-instance": 1, - }, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - cache := map[string]int{} - getUserByID := func(userID, instanceID string) (*view_model.UserView, error) { - cache[userID+"-"+instanceID]++ - return &view_model.UserView{HumanView: &view_model.HumanView{}}, nil - } - for _, session := range tt.args.sessions { - if err := new(UserSession).fillUserInfo(session, getUserByID); err != nil { - t.Errorf("UserSession.fillUserInfo() unexpected error = %v", err) - } - } - if len(cache) != len(tt.cacheHits) { - t.Errorf("unexpected length of cache hits: want %d, got %d", len(tt.cacheHits), len(cache)) - return - } - for key, count := range tt.cacheHits { - if cache[key] != count { - t.Errorf("unexpected cache hits on %s: want %d, got %d", key, count, cache[key]) - } - } - }) - } -} diff --git a/internal/auth/repository/eventsourcing/view/user_session.go b/internal/auth/repository/eventsourcing/view/user_session.go index e2e4938d62..59d2a967ab 100644 --- a/internal/auth/repository/eventsourcing/view/user_session.go +++ b/internal/auth/repository/eventsourcing/view/user_session.go @@ -15,33 +15,17 @@ const ( ) func (v *View) UserSessionByIDs(agentID, userID, instanceID string) (*model.UserSessionView, error) { - return view.UserSessionByIDs(v.Db, userSessionTable, agentID, userID, instanceID) -} - -func (v *View) UserSessionsByUserID(userID, instanceID string) ([]*model.UserSessionView, error) { - return view.UserSessionsByUserID(v.Db, userSessionTable, userID, instanceID) + return view.UserSessionByIDs(v.client, agentID, userID, instanceID) } func (v *View) UserSessionsByAgentID(agentID, instanceID string) ([]*model.UserSessionView, error) { - return view.UserSessionsByAgentID(v.Db, userSessionTable, agentID, instanceID) -} - -func (v *View) UserSessionsByOrgID(orgID, instanceID string) ([]*model.UserSessionView, error) { - return view.UserSessionsByOrgID(v.Db, userSessionTable, orgID, instanceID) -} - -func (v *View) ActiveUserSessionsCount() (uint64, error) { - return view.ActiveUserSessions(v.Db, userSessionTable) + return view.UserSessionsByAgentID(v.client, agentID, instanceID) } func (v *View) PutUserSession(userSession *model.UserSessionView) error { return view.PutUserSession(v.Db, userSessionTable, userSession) } -func (v *View) PutUserSessions(userSession []*model.UserSessionView) error { - return view.PutUserSessions(v.Db, userSessionTable, userSession...) -} - func (v *View) DeleteUserSessions(userID, instanceID string) error { err := view.DeleteUserSessions(v.Db, userSessionTable, userID, instanceID) if err != nil && !zerrors.IsNotFound(err) { diff --git a/internal/user/repository/view/model/user_session.go b/internal/user/repository/view/model/user_session.go index 3b21e877f2..f4f4b98621 100644 --- a/internal/user/repository/view/model/user_session.go +++ b/internal/user/repository/view/model/user_session.go @@ -23,16 +23,20 @@ const ( ) type UserSessionView struct { - CreationDate time.Time `json:"-" gorm:"column:creation_date"` - ChangeDate time.Time `json:"-" gorm:"column:change_date"` - ResourceOwner string `json:"-" gorm:"column:resource_owner"` - State int32 `json:"-" gorm:"column:state"` - UserAgentID string `json:"userAgentID" gorm:"column:user_agent_id;primary_key"` - UserID string `json:"userID" gorm:"column:user_id;primary_key"` - UserName string `json:"-" gorm:"column:user_name"` - LoginName string `json:"-" gorm:"column:login_name"` - DisplayName string `json:"-" gorm:"column:user_display_name"` - AvatarKey string `json:"-" gorm:"column:avatar_key"` + CreationDate time.Time `json:"-" gorm:"column:creation_date"` + ChangeDate time.Time `json:"-" gorm:"column:change_date"` + ResourceOwner string `json:"-" gorm:"column:resource_owner"` + State int32 `json:"-" gorm:"column:state"` + UserAgentID string `json:"userAgentID" gorm:"column:user_agent_id;primary_key"` + UserID string `json:"userID" gorm:"column:user_id;primary_key"` + // As of https://github.com/zitadel/zitadel/pull/7199 the following 4 attributes + // are not projected in the user session handler anymore + // and are therefore annotated with a `gorm:"-"`. + // They will be read from the corresponding projection directly. + UserName string `json:"-" gorm:"-"` + LoginName string `json:"-" gorm:"-"` + DisplayName string `json:"-" gorm:"-"` + AvatarKey string `json:"-" gorm:"-"` SelectedIDPConfigID string `json:"selectedIDPConfigID" gorm:"column:selected_idp_config_id"` PasswordVerification time.Time `json:"-" gorm:"column:password_verification"` PasswordlessVerification time.Time `json:"-" gorm:"column:passwordless_verification"` @@ -190,14 +194,6 @@ func (v *UserSessionView) AppendEvent(event eventstore.Event) error { case user.UserIDPLinkRemovedType, user.UserIDPLinkCascadeRemovedType: v.ExternalLoginVerification = time.Time{} v.SelectedIDPConfigID = "" - case user.HumanAvatarAddedType: - key, err := avatarKeyFromEvent(event) - if err != nil { - return err - } - v.AvatarKey = key - case user.HumanAvatarRemovedType: - v.AvatarKey = "" } return nil } @@ -208,15 +204,6 @@ func (v *UserSessionView) setSecondFactorVerification(verificationTime time.Time v.State = int32(domain.UserSessionStateActive) } -func avatarKeyFromEvent(event eventstore.Event) (string, error) { - data := make(map[string]string) - if err := event.Unmarshal(&data); err != nil { - logging.Log("EVEN-Sfew2").WithError(err).Error("could not unmarshal event data") - return "", zerrors.ThrowInternal(err, "MODEL-SFw2q", "could not unmarshal event") - } - return data["storeKey"], nil -} - func (v *UserSessionView) EventTypes() []eventstore.EventType { return []eventstore.EventType{ user.UserV1PasswordCheckSucceededType, @@ -250,7 +237,5 @@ func (v *UserSessionView) EventTypes() []eventstore.EventType { user.UserDeactivatedType, user.UserIDPLinkRemovedType, user.UserIDPLinkCascadeRemovedType, - user.HumanAvatarAddedType, - user.HumanAvatarRemovedType, } } diff --git a/internal/user/repository/view/user_session_by_id.sql b/internal/user/repository/view/user_session_by_id.sql new file mode 100644 index 0000000000..334369125b --- /dev/null +++ b/internal/user/repository/view/user_session_by_id.sql @@ -0,0 +1,29 @@ +SELECT s.creation_date, + s.change_date, + s.resource_owner, + s.state, + s.user_agent_id, + s.user_id, + u.username, + l.login_name, + h.display_name, + h.avatar_key, + s.selected_idp_config_id, + s.password_verification, + s.passwordless_verification, + s.external_login_verification, + s.second_factor_verification, + s.second_factor_verification_type, + s.multi_factor_verification, + s.multi_factor_verification_type, + s.sequence, + s.instance_id +FROM auth.user_sessions s + LEFT JOIN projections.users10 u ON s.user_id = u.id AND s.instance_id = u.instance_id + LEFT JOIN projections.users10_humans h ON s.user_id = h.user_id AND s.instance_id = h.instance_id + LEFT JOIN projections.login_names3 l ON s.user_id = l.user_id AND s.instance_id = l.instance_id AND l.is_primary = true +WHERE (s.user_agent_id = $1) + AND (s.user_id = $2) + AND (s.instance_id = $3) +LIMIT 1 +; \ No newline at end of file diff --git a/internal/user/repository/view/user_session_view.go b/internal/user/repository/view/user_session_view.go index 487857c1ba..3afd3c16b2 100644 --- a/internal/user/repository/view/user_session_view.go +++ b/internal/user/repository/view/user_session_view.go @@ -1,123 +1,56 @@ package view import ( + "database/sql" + _ "embed" + "errors" + "github.com/jinzhu/gorm" - "github.com/zitadel/zitadel/internal/domain" + "github.com/zitadel/zitadel/internal/database" usr_model "github.com/zitadel/zitadel/internal/user/model" "github.com/zitadel/zitadel/internal/user/repository/view/model" "github.com/zitadel/zitadel/internal/view/repository" "github.com/zitadel/zitadel/internal/zerrors" ) -func UserSessionByIDs(db *gorm.DB, table, agentID, userID, instanceID string) (*model.UserSessionView, error) { - userSession := new(model.UserSessionView) - userAgentQuery := model.UserSessionSearchQuery{ - Key: usr_model.UserSessionSearchKeyUserAgentID, - Method: domain.SearchMethodEquals, - Value: agentID, - } - userQuery := model.UserSessionSearchQuery{ - Key: usr_model.UserSessionSearchKeyUserID, - Method: domain.SearchMethodEquals, - Value: userID, - } - instanceIDQuery := &model.UserSessionSearchQuery{ - Key: usr_model.UserSessionSearchKeyInstanceID, - Method: domain.SearchMethodEquals, - Value: instanceID, - } - query := repository.PrepareGetByQuery(table, userAgentQuery, userQuery, instanceIDQuery) - err := query(db, userSession) - if zerrors.IsNotFound(err) { - return nil, zerrors.ThrowNotFound(nil, "VIEW-NGBs1", "Errors.UserSession.NotFound") - } +//go:embed user_session_by_id.sql +var userSessionByIDQuery string + +//go:embed user_sessions_by_user_agent.sql +var userSessionsByUserAgentQuery string + +func UserSessionByIDs(db *database.DB, agentID, userID, instanceID string) (userSession *model.UserSessionView, err error) { + err = db.QueryRow( + func(row *sql.Row) error { + userSession, err = scanUserSession(row) + return err + }, + userSessionByIDQuery, + agentID, + userID, + instanceID, + ) return userSession, err } - -func UserSessionsByUserID(db *gorm.DB, table, userID, instanceID string) ([]*model.UserSessionView, error) { - userSessions := make([]*model.UserSessionView, 0) - userAgentQuery := &usr_model.UserSessionSearchQuery{ - Key: usr_model.UserSessionSearchKeyUserID, - Method: domain.SearchMethodEquals, - Value: userID, - } - instanceIDQuery := &usr_model.UserSessionSearchQuery{ - Key: usr_model.UserSessionSearchKeyInstanceID, - Method: domain.SearchMethodEquals, - Value: instanceID, - } - query := repository.PrepareSearchQuery(table, model.UserSessionSearchRequest{ - Queries: []*usr_model.UserSessionSearchQuery{userAgentQuery, instanceIDQuery}, - }) - _, err := query(db, &userSessions) +func UserSessionsByAgentID(db *database.DB, agentID, instanceID string) (userSessions []*model.UserSessionView, err error) { + err = db.Query( + func(rows *sql.Rows) error { + userSessions, err = scanUserSessions(rows) + return err + }, + userSessionsByUserAgentQuery, + agentID, + instanceID, + ) return userSessions, err } -func UserSessionsByAgentID(db *gorm.DB, table, agentID, instanceID string) ([]*model.UserSessionView, error) { - userSessions := make([]*model.UserSessionView, 0) - userAgentQuery := &usr_model.UserSessionSearchQuery{ - Key: usr_model.UserSessionSearchKeyUserAgentID, - Method: domain.SearchMethodEquals, - Value: agentID, - } - instanceIDQuery := &usr_model.UserSessionSearchQuery{ - Key: usr_model.UserSessionSearchKeyInstanceID, - Method: domain.SearchMethodEquals, - Value: instanceID, - } - query := repository.PrepareSearchQuery(table, model.UserSessionSearchRequest{ - Queries: []*usr_model.UserSessionSearchQuery{userAgentQuery, instanceIDQuery}, - }) - _, err := query(db, &userSessions) - return userSessions, err -} - -func UserSessionsByOrgID(db *gorm.DB, table, orgID, instanceID string) ([]*model.UserSessionView, error) { - userSessions := make([]*model.UserSessionView, 0) - userAgentQuery := &usr_model.UserSessionSearchQuery{ - Key: usr_model.UserSessionSearchKeyResourceOwner, - Method: domain.SearchMethodEquals, - Value: orgID, - } - instanceIDQuery := &usr_model.UserSessionSearchQuery{ - Key: usr_model.UserSessionSearchKeyInstanceID, - Method: domain.SearchMethodEquals, - Value: instanceID, - } - query := repository.PrepareSearchQuery(table, model.UserSessionSearchRequest{ - Queries: []*usr_model.UserSessionSearchQuery{userAgentQuery, instanceIDQuery}, - }) - _, err := query(db, &userSessions) - return userSessions, err -} - -func ActiveUserSessions(db *gorm.DB, table string) (uint64, error) { - activeQuery := &usr_model.UserSessionSearchQuery{ - Key: usr_model.UserSessionSearchKeyState, - Method: domain.SearchMethodEquals, - Value: domain.UserSessionStateActive, - } - query := repository.PrepareSearchQuery(table, model.UserSessionSearchRequest{ - Queries: []*usr_model.UserSessionSearchQuery{activeQuery}, - }) - return query(db, nil) -} - func PutUserSession(db *gorm.DB, table string, session *model.UserSessionView) error { save := repository.PrepareSave(table) return save(db, session) } -func PutUserSessions(db *gorm.DB, table string, sessions ...*model.UserSessionView) error { - save := repository.PrepareBulkSave(table) - s := make([]interface{}, len(sessions)) - for i, session := range sessions { - s[i] = session - } - return save(db, s...) -} - func DeleteUserSessions(db *gorm.DB, table, userID, instanceID string) error { delete := repository.PrepareDeleteByKeys(table, repository.Key{Key: model.UserSessionSearchKey(usr_model.UserSessionSearchKeyUserID), Value: userID}, @@ -141,3 +74,81 @@ func DeleteOrgUserSessions(db *gorm.DB, table, instanceID, orgID string) error { ) return delete(db) } + +func scanUserSession(row *sql.Row) (*model.UserSessionView, error) { + session := new(model.UserSessionView) + var userName, loginName, displayName, avatarKey sql.NullString + err := row.Scan( + &session.CreationDate, + &session.ChangeDate, + &session.ResourceOwner, + &session.State, + &session.UserAgentID, + &session.UserID, + &userName, + &loginName, + &displayName, + &avatarKey, + &session.SelectedIDPConfigID, + &session.PasswordVerification, + &session.PasswordlessVerification, + &session.ExternalLoginVerification, + &session.SecondFactorVerification, + &session.SecondFactorVerificationType, + &session.MultiFactorVerification, + &session.MultiFactorVerificationType, + &session.Sequence, + &session.InstanceID, + ) + if errors.Is(err, sql.ErrNoRows) { + return nil, zerrors.ThrowNotFound(nil, "VIEW-NGBs1", "Errors.UserSession.NotFound") + } + session.UserName = userName.String + session.LoginName = loginName.String + session.DisplayName = displayName.String + session.AvatarKey = avatarKey.String + return session, err +} + +func scanUserSessions(rows *sql.Rows) ([]*model.UserSessionView, error) { + sessions := make([]*model.UserSessionView, 0) + for rows.Next() { + session := new(model.UserSessionView) + var userName, loginName, displayName, avatarKey sql.NullString + err := rows.Scan( + &session.CreationDate, + &session.ChangeDate, + &session.ResourceOwner, + &session.State, + &session.UserAgentID, + &session.UserID, + &userName, + &loginName, + &displayName, + &avatarKey, + &session.SelectedIDPConfigID, + &session.PasswordVerification, + &session.PasswordlessVerification, + &session.ExternalLoginVerification, + &session.SecondFactorVerification, + &session.SecondFactorVerificationType, + &session.MultiFactorVerification, + &session.MultiFactorVerificationType, + &session.Sequence, + &session.InstanceID, + ) + if err != nil { + return nil, err + } + session.UserName = userName.String + session.LoginName = loginName.String + session.DisplayName = displayName.String + session.AvatarKey = avatarKey.String + sessions = append(sessions, session) + } + + if err := rows.Close(); err != nil { + return nil, zerrors.ThrowInternal(err, "VIEW-FSF3g", "Errors.Query.CloseRows") + } + return sessions, nil +} diff --git a/internal/user/repository/view/user_sessions_by_user_agent.sql b/internal/user/repository/view/user_sessions_by_user_agent.sql new file mode 100644 index 0000000000..f3fc06b377 --- /dev/null +++ b/internal/user/repository/view/user_sessions_by_user_agent.sql @@ -0,0 +1,27 @@ +SELECT s.creation_date, + s.change_date, + s.resource_owner, + s.state, + s.user_agent_id, + s.user_id, + u.username, + l.login_name, + h.display_name, + h.avatar_key, + s.selected_idp_config_id, + s.password_verification, + s.passwordless_verification, + s.external_login_verification, + s.second_factor_verification, + s.second_factor_verification_type, + s.multi_factor_verification, + s.multi_factor_verification_type, + s.sequence, + s.instance_id +FROM auth.user_sessions s + LEFT JOIN projections.users10 u ON s.user_id = u.id AND s.instance_id = u.instance_id + LEFT JOIN projections.users10_humans h ON s.user_id = h.user_id AND s.instance_id = h.instance_id + LEFT JOIN projections.login_names3 l ON s.user_id = l.user_id AND s.instance_id = l.instance_id AND l.is_primary = true +WHERE (s.user_agent_id = $1) + AND (s.instance_id = $2) +; \ No newline at end of file