diff --git a/internal/auth/repository/eventsourcing/handler/handler.go b/internal/auth/repository/eventsourcing/handler/handler.go index 55a514867f..fbefb081fb 100644 --- a/internal/auth/repository/eventsourcing/handler/handler.go +++ b/internal/auth/repository/eventsourcing/handler/handler.go @@ -36,7 +36,7 @@ func Register(configs Configs, bulkLimit, errorCount uint64, view *view.View, es newUser( handler{view, bulkLimit, configs.cycleDuration("User"), errorCount, es}, queries), newUserSession( - handler{view, bulkLimit, configs.cycleDuration("UserSession"), errorCount, es}), + handler{view, bulkLimit, configs.cycleDuration("UserSession"), errorCount, es}, queries), newToken( handler{view, bulkLimit, configs.cycleDuration("Token"), errorCount, es}), newIDPConfig( diff --git a/internal/auth/repository/eventsourcing/handler/user_session.go b/internal/auth/repository/eventsourcing/handler/user_session.go index c12e182c81..6e285bee39 100644 --- a/internal/auth/repository/eventsourcing/handler/user_session.go +++ b/internal/auth/repository/eventsourcing/handler/user_session.go @@ -1,6 +1,8 @@ package handler import ( + "context" + "github.com/zitadel/logging" "github.com/zitadel/zitadel/internal/domain" @@ -9,7 +11,13 @@ import ( v1 "github.com/zitadel/zitadel/internal/eventstore/v1" "github.com/zitadel/zitadel/internal/eventstore/v1/models" "github.com/zitadel/zitadel/internal/eventstore/v1/query" + es_sdk "github.com/zitadel/zitadel/internal/eventstore/v1/sdk" "github.com/zitadel/zitadel/internal/eventstore/v1/spooler" + org_model "github.com/zitadel/zitadel/internal/org/model" + org_es_model "github.com/zitadel/zitadel/internal/org/repository/eventsourcing/model" + "github.com/zitadel/zitadel/internal/org/repository/view" + query2 "github.com/zitadel/zitadel/internal/query" + "github.com/zitadel/zitadel/internal/repository/org" "github.com/zitadel/zitadel/internal/repository/user" view_model "github.com/zitadel/zitadel/internal/user/repository/view/model" ) @@ -21,13 +29,13 @@ const ( type UserSession struct { handler subscription *v1.Subscription + queries *query2.Queries } -func newUserSession( - handler handler, -) *UserSession { +func newUserSession(handler handler, queries *query2.Queries) *UserSession { h := &UserSession{ handler: handler, + queries: queries, } h.subscribe() @@ -53,7 +61,7 @@ func (u *UserSession) Subscription() *v1.Subscription { } func (_ *UserSession) AggregateTypes() []models.AggregateType { - return []models.AggregateType{user.AggregateType} + return []models.AggregateType{user.AggregateType, org.AggregateType} } func (u *UserSession) CurrentSequence(instanceID string) (uint64, error) { @@ -154,11 +162,13 @@ func (u *UserSession) Reduce(event *models.Event) (err error) { if err := session.AppendEvent(event); err != nil { return err } - if err := u.fillUserInfo(session, event.AggregateID); err != nil { + if err := u.fillUserInfo(session); err != nil { return err } } return u.view.PutUserSessions(sessions, event) + case org.OrgDomainPrimarySetEventType: + return u.fillLoginNamesOnOrgUsers(event) case user.UserRemovedType: return u.view.DeleteUserSessions(event.AggregateID, event.InstanceID, event) default: @@ -179,14 +189,14 @@ func (u *UserSession) updateSession(session *view_model.UserSessionView, event * if err := session.AppendEvent(event); err != nil { return err } - if err := u.fillUserInfo(session, event.AggregateID); err != nil { + if err := u.fillUserInfo(session); err != nil { return err } return u.view.PutUserSession(session, event) } -func (u *UserSession) fillUserInfo(session *view_model.UserSessionView, id string) error { - user, err := u.view.UserByID(id, session.InstanceID) +func (u *UserSession) fillUserInfo(session *view_model.UserSessionView) error { + user, err := u.view.UserByID(session.UserID, session.InstanceID) if err != nil { return err } @@ -196,3 +206,61 @@ func (u *UserSession) fillUserInfo(session *view_model.UserSessionView, id strin session.AvatarKey = user.AvatarKey return nil } + +func (u *UserSession) fillLoginNamesOnOrgUsers(event *models.Event) error { + sessions, err := u.view.UserSessionsByOrgID(event.ResourceOwner, event.InstanceID) + if err != nil { + return err + } + if len(sessions) == 0 { + return u.view.ProcessedUserSessionSequence(event) + } + userLoginMustBeDomain, primaryDomain, err := u.loginNameInformation(context.Background(), event.ResourceOwner, event.InstanceID) + if err != nil { + return err + } + if !userLoginMustBeDomain { + return nil + } + for _, session := range sessions { + session.LoginName = session.UserName + "@" + primaryDomain + } + return u.view.PutUserSessions(sessions, event) +} + +func (u *UserSession) loginNameInformation(ctx context.Context, orgID string, instanceID string) (userLoginMustBeDomain bool, primaryDomain string, err error) { + org, err := u.getOrgByID(ctx, orgID, instanceID) + if err != nil { + return false, "", err + } + if org.DomainPolicy == nil { + policy, err := u.queries.DefaultDomainPolicy(withInstanceID(ctx, org.InstanceID)) + if err != nil { + return false, "", err + } + userLoginMustBeDomain = policy.UserLoginMustBeDomain + } + return userLoginMustBeDomain, org.GetPrimaryDomain().Domain, nil +} + +func (u *UserSession) getOrgByID(ctx context.Context, orgID, instanceID string) (*org_model.Org, error) { + query, err := view.OrgByIDQuery(orgID, instanceID, 0) + if err != nil { + return nil, err + } + + esOrg := &org_es_model.Org{ + ObjectRoot: models.ObjectRoot{ + AggregateID: orgID, + }, + } + err = es_sdk.Filter(ctx, u.Eventstore().FilterEvents, esOrg.AppendEvents, query) + if err != nil && !errors.IsNotFound(err) { + return nil, err + } + if esOrg.Sequence == 0 { + return nil, errors.ThrowNotFound(nil, "EVENT-3m9vs", "Errors.Org.NotFound") + } + + return org_es_model.OrgToModel(esOrg), nil +} diff --git a/internal/auth/repository/eventsourcing/view/user_session.go b/internal/auth/repository/eventsourcing/view/user_session.go index 3aa2229113..27c4b3023f 100644 --- a/internal/auth/repository/eventsourcing/view/user_session.go +++ b/internal/auth/repository/eventsourcing/view/user_session.go @@ -24,6 +24,10 @@ func (v *View) UserSessionsByAgentID(agentID, instanceID string) ([]*model.UserS return view.UserSessionsByAgentID(v.Db, userSessionTable, agentID, instanceID) } +func (v *View) UserSessionsByOrgID(orgID, instanceID string) ([]*model.UserSessionView, error) { + return view.UserSessionsByOrgID(v.Db, userSessionTable, orgID, instanceID) +} + func (v *View) ActiveUserSessionsCount() (uint64, error) { return view.ActiveUserSessions(v.Db, userSessionTable) } diff --git a/internal/user/repository/view/model/user_session.go b/internal/user/repository/view/model/user_session.go index 22989f28e0..9896691b0e 100644 --- a/internal/user/repository/view/model/user_session.go +++ b/internal/user/repository/view/model/user_session.go @@ -171,6 +171,14 @@ func (v *UserSessionView) AppendEvent(event *models.Event) error { case user.UserIDPLinkRemovedType, user.UserIDPLinkCascadeRemovedType: v.ExternalLoginVerification = time.Time{} v.SelectedIDPConfigID = "" + case user.HumanAvatarAddedType: + key, err := avatarKeyFromEvent(event) + if err != nil { + return err + } + v.AvatarKey = key + case user.HumanAvatarRemovedType: + v.AvatarKey = "" } return nil } @@ -180,3 +188,12 @@ func (v *UserSessionView) setSecondFactorVerification(verificationTime time.Time v.SecondFactorVerificationType = int32(mfaType) v.State = int32(domain.UserSessionStateActive) } + +func avatarKeyFromEvent(event *models.Event) (string, error) { + data := make(map[string]string) + if err := json.Unmarshal(event.Data, &data); err != nil { + logging.Log("EVEN-Sfew2").WithError(err).Error("could not unmarshal event data") + return "", caos_errs.ThrowInternal(err, "MODEL-SFw2q", "could not unmarshal event") + } + return data["storeKey"], nil +} diff --git a/internal/user/repository/view/user_session_view.go b/internal/user/repository/view/user_session_view.go index 1d5d8efda6..4784c1b45d 100644 --- a/internal/user/repository/view/user_session_view.go +++ b/internal/user/repository/view/user_session_view.go @@ -73,6 +73,25 @@ func UserSessionsByAgentID(db *gorm.DB, table, agentID, instanceID string) ([]*m return userSessions, err } +func UserSessionsByOrgID(db *gorm.DB, table, orgID, instanceID string) ([]*model.UserSessionView, error) { + userSessions := make([]*model.UserSessionView, 0) + userAgentQuery := &usr_model.UserSessionSearchQuery{ + Key: usr_model.UserSessionSearchKeyResourceOwner, + Method: domain.SearchMethodEquals, + Value: orgID, + } + instanceIDQuery := &usr_model.UserSessionSearchQuery{ + Key: usr_model.UserSessionSearchKeyInstanceID, + Method: domain.SearchMethodEquals, + Value: instanceID, + } + query := repository.PrepareSearchQuery(table, model.UserSessionSearchRequest{ + Queries: []*usr_model.UserSessionSearchQuery{userAgentQuery, instanceIDQuery}, + }) + _, err := query(db, &userSessions) + return userSessions, err +} + func ActiveUserSessions(db *gorm.DB, table string) (uint64, error) { activeQuery := &usr_model.UserSessionSearchQuery{ Key: usr_model.UserSessionSearchKeyState, diff --git a/internal/view/repository/requests.go b/internal/view/repository/requests.go index 29f28b25a9..4dc22dea76 100644 --- a/internal/view/repository/requests.go +++ b/internal/view/repository/requests.go @@ -39,7 +39,7 @@ func PrepareGetByQuery(table string, queries ...SearchQuery) func(db *gorm.DB, r } } - err := query.Take(res).Error + err := query.Debug().Take(res).Error if err == nil { return nil }