mirror of
https://github.com/zitadel/zitadel.git
synced 2025-08-11 21:47:32 +00:00
fix: update user sessions after avatar or primary domain change (#3768)
This commit is contained in:
@@ -36,7 +36,7 @@ func Register(configs Configs, bulkLimit, errorCount uint64, view *view.View, es
|
||||
newUser(
|
||||
handler{view, bulkLimit, configs.cycleDuration("User"), errorCount, es}, queries),
|
||||
newUserSession(
|
||||
handler{view, bulkLimit, configs.cycleDuration("UserSession"), errorCount, es}),
|
||||
handler{view, bulkLimit, configs.cycleDuration("UserSession"), errorCount, es}, queries),
|
||||
newToken(
|
||||
handler{view, bulkLimit, configs.cycleDuration("Token"), errorCount, es}),
|
||||
newIDPConfig(
|
||||
|
@@ -1,6 +1,8 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/zitadel/logging"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/domain"
|
||||
@@ -9,7 +11,13 @@ import (
|
||||
v1 "github.com/zitadel/zitadel/internal/eventstore/v1"
|
||||
"github.com/zitadel/zitadel/internal/eventstore/v1/models"
|
||||
"github.com/zitadel/zitadel/internal/eventstore/v1/query"
|
||||
es_sdk "github.com/zitadel/zitadel/internal/eventstore/v1/sdk"
|
||||
"github.com/zitadel/zitadel/internal/eventstore/v1/spooler"
|
||||
org_model "github.com/zitadel/zitadel/internal/org/model"
|
||||
org_es_model "github.com/zitadel/zitadel/internal/org/repository/eventsourcing/model"
|
||||
"github.com/zitadel/zitadel/internal/org/repository/view"
|
||||
query2 "github.com/zitadel/zitadel/internal/query"
|
||||
"github.com/zitadel/zitadel/internal/repository/org"
|
||||
"github.com/zitadel/zitadel/internal/repository/user"
|
||||
view_model "github.com/zitadel/zitadel/internal/user/repository/view/model"
|
||||
)
|
||||
@@ -21,13 +29,13 @@ const (
|
||||
type UserSession struct {
|
||||
handler
|
||||
subscription *v1.Subscription
|
||||
queries *query2.Queries
|
||||
}
|
||||
|
||||
func newUserSession(
|
||||
handler handler,
|
||||
) *UserSession {
|
||||
func newUserSession(handler handler, queries *query2.Queries) *UserSession {
|
||||
h := &UserSession{
|
||||
handler: handler,
|
||||
queries: queries,
|
||||
}
|
||||
|
||||
h.subscribe()
|
||||
@@ -53,7 +61,7 @@ func (u *UserSession) Subscription() *v1.Subscription {
|
||||
}
|
||||
|
||||
func (_ *UserSession) AggregateTypes() []models.AggregateType {
|
||||
return []models.AggregateType{user.AggregateType}
|
||||
return []models.AggregateType{user.AggregateType, org.AggregateType}
|
||||
}
|
||||
|
||||
func (u *UserSession) CurrentSequence(instanceID string) (uint64, error) {
|
||||
@@ -154,11 +162,13 @@ func (u *UserSession) Reduce(event *models.Event) (err error) {
|
||||
if err := session.AppendEvent(event); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := u.fillUserInfo(session, event.AggregateID); err != nil {
|
||||
if err := u.fillUserInfo(session); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return u.view.PutUserSessions(sessions, event)
|
||||
case org.OrgDomainPrimarySetEventType:
|
||||
return u.fillLoginNamesOnOrgUsers(event)
|
||||
case user.UserRemovedType:
|
||||
return u.view.DeleteUserSessions(event.AggregateID, event.InstanceID, event)
|
||||
default:
|
||||
@@ -179,14 +189,14 @@ func (u *UserSession) updateSession(session *view_model.UserSessionView, event *
|
||||
if err := session.AppendEvent(event); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := u.fillUserInfo(session, event.AggregateID); err != nil {
|
||||
if err := u.fillUserInfo(session); err != nil {
|
||||
return err
|
||||
}
|
||||
return u.view.PutUserSession(session, event)
|
||||
}
|
||||
|
||||
func (u *UserSession) fillUserInfo(session *view_model.UserSessionView, id string) error {
|
||||
user, err := u.view.UserByID(id, session.InstanceID)
|
||||
func (u *UserSession) fillUserInfo(session *view_model.UserSessionView) error {
|
||||
user, err := u.view.UserByID(session.UserID, session.InstanceID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -196,3 +206,61 @@ func (u *UserSession) fillUserInfo(session *view_model.UserSessionView, id strin
|
||||
session.AvatarKey = user.AvatarKey
|
||||
return nil
|
||||
}
|
||||
|
||||
func (u *UserSession) fillLoginNamesOnOrgUsers(event *models.Event) error {
|
||||
sessions, err := u.view.UserSessionsByOrgID(event.ResourceOwner, event.InstanceID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if len(sessions) == 0 {
|
||||
return u.view.ProcessedUserSessionSequence(event)
|
||||
}
|
||||
userLoginMustBeDomain, primaryDomain, err := u.loginNameInformation(context.Background(), event.ResourceOwner, event.InstanceID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if !userLoginMustBeDomain {
|
||||
return nil
|
||||
}
|
||||
for _, session := range sessions {
|
||||
session.LoginName = session.UserName + "@" + primaryDomain
|
||||
}
|
||||
return u.view.PutUserSessions(sessions, event)
|
||||
}
|
||||
|
||||
func (u *UserSession) loginNameInformation(ctx context.Context, orgID string, instanceID string) (userLoginMustBeDomain bool, primaryDomain string, err error) {
|
||||
org, err := u.getOrgByID(ctx, orgID, instanceID)
|
||||
if err != nil {
|
||||
return false, "", err
|
||||
}
|
||||
if org.DomainPolicy == nil {
|
||||
policy, err := u.queries.DefaultDomainPolicy(withInstanceID(ctx, org.InstanceID))
|
||||
if err != nil {
|
||||
return false, "", err
|
||||
}
|
||||
userLoginMustBeDomain = policy.UserLoginMustBeDomain
|
||||
}
|
||||
return userLoginMustBeDomain, org.GetPrimaryDomain().Domain, nil
|
||||
}
|
||||
|
||||
func (u *UserSession) getOrgByID(ctx context.Context, orgID, instanceID string) (*org_model.Org, error) {
|
||||
query, err := view.OrgByIDQuery(orgID, instanceID, 0)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
esOrg := &org_es_model.Org{
|
||||
ObjectRoot: models.ObjectRoot{
|
||||
AggregateID: orgID,
|
||||
},
|
||||
}
|
||||
err = es_sdk.Filter(ctx, u.Eventstore().FilterEvents, esOrg.AppendEvents, query)
|
||||
if err != nil && !errors.IsNotFound(err) {
|
||||
return nil, err
|
||||
}
|
||||
if esOrg.Sequence == 0 {
|
||||
return nil, errors.ThrowNotFound(nil, "EVENT-3m9vs", "Errors.Org.NotFound")
|
||||
}
|
||||
|
||||
return org_es_model.OrgToModel(esOrg), nil
|
||||
}
|
||||
|
@@ -24,6 +24,10 @@ func (v *View) UserSessionsByAgentID(agentID, instanceID string) ([]*model.UserS
|
||||
return view.UserSessionsByAgentID(v.Db, userSessionTable, agentID, instanceID)
|
||||
}
|
||||
|
||||
func (v *View) UserSessionsByOrgID(orgID, instanceID string) ([]*model.UserSessionView, error) {
|
||||
return view.UserSessionsByOrgID(v.Db, userSessionTable, orgID, instanceID)
|
||||
}
|
||||
|
||||
func (v *View) ActiveUserSessionsCount() (uint64, error) {
|
||||
return view.ActiveUserSessions(v.Db, userSessionTable)
|
||||
}
|
||||
|
Reference in New Issue
Block a user