mirror of
https://github.com/zitadel/zitadel.git
synced 2025-01-12 23:43:41 +00:00
fix: update user sessions after avatar or primary domain change (#3768)
This commit is contained in:
parent
0baaaf8a05
commit
da1f74fde0
@ -36,7 +36,7 @@ func Register(configs Configs, bulkLimit, errorCount uint64, view *view.View, es
|
||||
newUser(
|
||||
handler{view, bulkLimit, configs.cycleDuration("User"), errorCount, es}, queries),
|
||||
newUserSession(
|
||||
handler{view, bulkLimit, configs.cycleDuration("UserSession"), errorCount, es}),
|
||||
handler{view, bulkLimit, configs.cycleDuration("UserSession"), errorCount, es}, queries),
|
||||
newToken(
|
||||
handler{view, bulkLimit, configs.cycleDuration("Token"), errorCount, es}),
|
||||
newIDPConfig(
|
||||
|
@ -1,6 +1,8 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/zitadel/logging"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/domain"
|
||||
@ -9,7 +11,13 @@ import (
|
||||
v1 "github.com/zitadel/zitadel/internal/eventstore/v1"
|
||||
"github.com/zitadel/zitadel/internal/eventstore/v1/models"
|
||||
"github.com/zitadel/zitadel/internal/eventstore/v1/query"
|
||||
es_sdk "github.com/zitadel/zitadel/internal/eventstore/v1/sdk"
|
||||
"github.com/zitadel/zitadel/internal/eventstore/v1/spooler"
|
||||
org_model "github.com/zitadel/zitadel/internal/org/model"
|
||||
org_es_model "github.com/zitadel/zitadel/internal/org/repository/eventsourcing/model"
|
||||
"github.com/zitadel/zitadel/internal/org/repository/view"
|
||||
query2 "github.com/zitadel/zitadel/internal/query"
|
||||
"github.com/zitadel/zitadel/internal/repository/org"
|
||||
"github.com/zitadel/zitadel/internal/repository/user"
|
||||
view_model "github.com/zitadel/zitadel/internal/user/repository/view/model"
|
||||
)
|
||||
@ -21,13 +29,13 @@ const (
|
||||
type UserSession struct {
|
||||
handler
|
||||
subscription *v1.Subscription
|
||||
queries *query2.Queries
|
||||
}
|
||||
|
||||
func newUserSession(
|
||||
handler handler,
|
||||
) *UserSession {
|
||||
func newUserSession(handler handler, queries *query2.Queries) *UserSession {
|
||||
h := &UserSession{
|
||||
handler: handler,
|
||||
queries: queries,
|
||||
}
|
||||
|
||||
h.subscribe()
|
||||
@ -53,7 +61,7 @@ func (u *UserSession) Subscription() *v1.Subscription {
|
||||
}
|
||||
|
||||
func (_ *UserSession) AggregateTypes() []models.AggregateType {
|
||||
return []models.AggregateType{user.AggregateType}
|
||||
return []models.AggregateType{user.AggregateType, org.AggregateType}
|
||||
}
|
||||
|
||||
func (u *UserSession) CurrentSequence(instanceID string) (uint64, error) {
|
||||
@ -154,11 +162,13 @@ func (u *UserSession) Reduce(event *models.Event) (err error) {
|
||||
if err := session.AppendEvent(event); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := u.fillUserInfo(session, event.AggregateID); err != nil {
|
||||
if err := u.fillUserInfo(session); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return u.view.PutUserSessions(sessions, event)
|
||||
case org.OrgDomainPrimarySetEventType:
|
||||
return u.fillLoginNamesOnOrgUsers(event)
|
||||
case user.UserRemovedType:
|
||||
return u.view.DeleteUserSessions(event.AggregateID, event.InstanceID, event)
|
||||
default:
|
||||
@ -179,14 +189,14 @@ func (u *UserSession) updateSession(session *view_model.UserSessionView, event *
|
||||
if err := session.AppendEvent(event); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := u.fillUserInfo(session, event.AggregateID); err != nil {
|
||||
if err := u.fillUserInfo(session); err != nil {
|
||||
return err
|
||||
}
|
||||
return u.view.PutUserSession(session, event)
|
||||
}
|
||||
|
||||
func (u *UserSession) fillUserInfo(session *view_model.UserSessionView, id string) error {
|
||||
user, err := u.view.UserByID(id, session.InstanceID)
|
||||
func (u *UserSession) fillUserInfo(session *view_model.UserSessionView) error {
|
||||
user, err := u.view.UserByID(session.UserID, session.InstanceID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@ -196,3 +206,61 @@ func (u *UserSession) fillUserInfo(session *view_model.UserSessionView, id strin
|
||||
session.AvatarKey = user.AvatarKey
|
||||
return nil
|
||||
}
|
||||
|
||||
func (u *UserSession) fillLoginNamesOnOrgUsers(event *models.Event) error {
|
||||
sessions, err := u.view.UserSessionsByOrgID(event.ResourceOwner, event.InstanceID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if len(sessions) == 0 {
|
||||
return u.view.ProcessedUserSessionSequence(event)
|
||||
}
|
||||
userLoginMustBeDomain, primaryDomain, err := u.loginNameInformation(context.Background(), event.ResourceOwner, event.InstanceID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if !userLoginMustBeDomain {
|
||||
return nil
|
||||
}
|
||||
for _, session := range sessions {
|
||||
session.LoginName = session.UserName + "@" + primaryDomain
|
||||
}
|
||||
return u.view.PutUserSessions(sessions, event)
|
||||
}
|
||||
|
||||
func (u *UserSession) loginNameInformation(ctx context.Context, orgID string, instanceID string) (userLoginMustBeDomain bool, primaryDomain string, err error) {
|
||||
org, err := u.getOrgByID(ctx, orgID, instanceID)
|
||||
if err != nil {
|
||||
return false, "", err
|
||||
}
|
||||
if org.DomainPolicy == nil {
|
||||
policy, err := u.queries.DefaultDomainPolicy(withInstanceID(ctx, org.InstanceID))
|
||||
if err != nil {
|
||||
return false, "", err
|
||||
}
|
||||
userLoginMustBeDomain = policy.UserLoginMustBeDomain
|
||||
}
|
||||
return userLoginMustBeDomain, org.GetPrimaryDomain().Domain, nil
|
||||
}
|
||||
|
||||
func (u *UserSession) getOrgByID(ctx context.Context, orgID, instanceID string) (*org_model.Org, error) {
|
||||
query, err := view.OrgByIDQuery(orgID, instanceID, 0)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
esOrg := &org_es_model.Org{
|
||||
ObjectRoot: models.ObjectRoot{
|
||||
AggregateID: orgID,
|
||||
},
|
||||
}
|
||||
err = es_sdk.Filter(ctx, u.Eventstore().FilterEvents, esOrg.AppendEvents, query)
|
||||
if err != nil && !errors.IsNotFound(err) {
|
||||
return nil, err
|
||||
}
|
||||
if esOrg.Sequence == 0 {
|
||||
return nil, errors.ThrowNotFound(nil, "EVENT-3m9vs", "Errors.Org.NotFound")
|
||||
}
|
||||
|
||||
return org_es_model.OrgToModel(esOrg), nil
|
||||
}
|
||||
|
@ -24,6 +24,10 @@ func (v *View) UserSessionsByAgentID(agentID, instanceID string) ([]*model.UserS
|
||||
return view.UserSessionsByAgentID(v.Db, userSessionTable, agentID, instanceID)
|
||||
}
|
||||
|
||||
func (v *View) UserSessionsByOrgID(orgID, instanceID string) ([]*model.UserSessionView, error) {
|
||||
return view.UserSessionsByOrgID(v.Db, userSessionTable, orgID, instanceID)
|
||||
}
|
||||
|
||||
func (v *View) ActiveUserSessionsCount() (uint64, error) {
|
||||
return view.ActiveUserSessions(v.Db, userSessionTable)
|
||||
}
|
||||
|
@ -171,6 +171,14 @@ func (v *UserSessionView) AppendEvent(event *models.Event) error {
|
||||
case user.UserIDPLinkRemovedType, user.UserIDPLinkCascadeRemovedType:
|
||||
v.ExternalLoginVerification = time.Time{}
|
||||
v.SelectedIDPConfigID = ""
|
||||
case user.HumanAvatarAddedType:
|
||||
key, err := avatarKeyFromEvent(event)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
v.AvatarKey = key
|
||||
case user.HumanAvatarRemovedType:
|
||||
v.AvatarKey = ""
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@ -180,3 +188,12 @@ func (v *UserSessionView) setSecondFactorVerification(verificationTime time.Time
|
||||
v.SecondFactorVerificationType = int32(mfaType)
|
||||
v.State = int32(domain.UserSessionStateActive)
|
||||
}
|
||||
|
||||
func avatarKeyFromEvent(event *models.Event) (string, error) {
|
||||
data := make(map[string]string)
|
||||
if err := json.Unmarshal(event.Data, &data); err != nil {
|
||||
logging.Log("EVEN-Sfew2").WithError(err).Error("could not unmarshal event data")
|
||||
return "", caos_errs.ThrowInternal(err, "MODEL-SFw2q", "could not unmarshal event")
|
||||
}
|
||||
return data["storeKey"], nil
|
||||
}
|
||||
|
@ -73,6 +73,25 @@ func UserSessionsByAgentID(db *gorm.DB, table, agentID, instanceID string) ([]*m
|
||||
return userSessions, err
|
||||
}
|
||||
|
||||
func UserSessionsByOrgID(db *gorm.DB, table, orgID, instanceID string) ([]*model.UserSessionView, error) {
|
||||
userSessions := make([]*model.UserSessionView, 0)
|
||||
userAgentQuery := &usr_model.UserSessionSearchQuery{
|
||||
Key: usr_model.UserSessionSearchKeyResourceOwner,
|
||||
Method: domain.SearchMethodEquals,
|
||||
Value: orgID,
|
||||
}
|
||||
instanceIDQuery := &usr_model.UserSessionSearchQuery{
|
||||
Key: usr_model.UserSessionSearchKeyInstanceID,
|
||||
Method: domain.SearchMethodEquals,
|
||||
Value: instanceID,
|
||||
}
|
||||
query := repository.PrepareSearchQuery(table, model.UserSessionSearchRequest{
|
||||
Queries: []*usr_model.UserSessionSearchQuery{userAgentQuery, instanceIDQuery},
|
||||
})
|
||||
_, err := query(db, &userSessions)
|
||||
return userSessions, err
|
||||
}
|
||||
|
||||
func ActiveUserSessions(db *gorm.DB, table string) (uint64, error) {
|
||||
activeQuery := &usr_model.UserSessionSearchQuery{
|
||||
Key: usr_model.UserSessionSearchKeyState,
|
||||
|
@ -39,7 +39,7 @@ func PrepareGetByQuery(table string, queries ...SearchQuery) func(db *gorm.DB, r
|
||||
}
|
||||
}
|
||||
|
||||
err := query.Take(res).Error
|
||||
err := query.Debug().Take(res).Error
|
||||
if err == nil {
|
||||
return nil
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user