fix(login): correctly reload policies on auth request (#7839)

This commit is contained in:
Livio Spring 2024-04-24 10:44:55 +02:00 committed by adlerhurst
parent e7904057e0
commit 220c09d941
3 changed files with 26 additions and 11 deletions

View File

@ -656,7 +656,7 @@ func (repo *AuthRequestRepo) fillPolicies(ctx context.Context, request *domain.A
} }
} }
if request.LoginPolicy == nil || len(request.AllowedExternalIDPs) == 0 { if request.LoginPolicy == nil || len(request.AllowedExternalIDPs) == 0 || request.PolicyOrgID() != orgID {
loginPolicy, idpProviders, err := repo.getLoginPolicyAndIDPProviders(ctx, orgID) loginPolicy, idpProviders, err := repo.getLoginPolicyAndIDPProviders(ctx, orgID)
if err != nil { if err != nil {
return err return err
@ -666,21 +666,21 @@ func (repo *AuthRequestRepo) fillPolicies(ctx context.Context, request *domain.A
request.AllowedExternalIDPs = idpProviders request.AllowedExternalIDPs = idpProviders
} }
} }
if request.LockoutPolicy == nil { if request.LockoutPolicy == nil || request.PolicyOrgID() != orgID {
lockoutPolicy, err := repo.getLockoutPolicy(ctx, orgID) lockoutPolicy, err := repo.getLockoutPolicy(ctx, orgID)
if err != nil { if err != nil {
return err return err
} }
request.LockoutPolicy = lockoutPolicyToDomain(lockoutPolicy) request.LockoutPolicy = lockoutPolicyToDomain(lockoutPolicy)
} }
if request.PrivacyPolicy == nil { if request.PrivacyPolicy == nil || request.PolicyOrgID() != orgID {
privacyPolicy, err := repo.GetPrivacyPolicy(ctx, orgID) privacyPolicy, err := repo.GetPrivacyPolicy(ctx, orgID)
if err != nil { if err != nil {
return err return err
} }
request.PrivacyPolicy = privacyPolicy request.PrivacyPolicy = privacyPolicy
} }
if request.LabelPolicy == nil { if request.LabelPolicy == nil || request.PolicyOrgID() != orgID {
labelPolicy, err := repo.getLabelPolicy(ctx, request.PrivateLabelingOrgID(orgID)) labelPolicy, err := repo.getLabelPolicy(ctx, request.PrivateLabelingOrgID(orgID))
if err != nil { if err != nil {
return err return err
@ -694,13 +694,14 @@ func (repo *AuthRequestRepo) fillPolicies(ctx context.Context, request *domain.A
} }
request.DefaultTranslations = defaultLoginTranslations request.DefaultTranslations = defaultLoginTranslations
} }
if len(request.OrgTranslations) == 0 { if len(request.OrgTranslations) == 0 || request.PolicyOrgID() != orgID {
orgLoginTranslations, err := repo.getLoginTexts(ctx, orgID) orgLoginTranslations, err := repo.getLoginTexts(ctx, orgID)
if err != nil { if err != nil {
return err return err
} }
request.OrgTranslations = orgLoginTranslations request.OrgTranslations = orgLoginTranslations
} }
request.SetPolicyOrgID(orgID)
repo.AuthRequests.CacheAuthRequest(ctx, request) repo.AuthRequests.CacheAuthRequest(ctx, request)
return nil return nil
} }
@ -887,7 +888,7 @@ func (repo *AuthRequestRepo) checkLoginNameInputForResourceOwner(ctx context.Con
} }
func (repo *AuthRequestRepo) checkLoginPolicyWithResourceOwner(ctx context.Context, request *domain.AuthRequest, resourceOwner string) (err error) { func (repo *AuthRequestRepo) checkLoginPolicyWithResourceOwner(ctx context.Context, request *domain.AuthRequest, resourceOwner string) (err error) {
if request.LoginPolicy == nil { if request.LoginPolicy == nil || request.PolicyOrgID() != resourceOwner {
loginPolicy, idps, err := repo.getLoginPolicyAndIDPProviders(ctx, resourceOwner) loginPolicy, idps, err := repo.getLoginPolicyAndIDPProviders(ctx, resourceOwner)
if err != nil { if err != nil {
return err return err

View File

@ -24,16 +24,20 @@ type AuthRequestCache struct {
} }
func Start(dbClient *database.DB, amountOfCachedAuthRequests uint16) *AuthRequestCache { func Start(dbClient *database.DB, amountOfCachedAuthRequests uint16) *AuthRequestCache {
cache := &AuthRequestCache{
client: dbClient,
}
idCache, err := lru.New[string, *domain.AuthRequest](int(amountOfCachedAuthRequests)) idCache, err := lru.New[string, *domain.AuthRequest](int(amountOfCachedAuthRequests))
logging.OnError(err).Info("auth request cache disabled") logging.OnError(err).Info("auth request cache disabled")
if err == nil {
cache.idCache = idCache
}
codeCache, err := lru.New[string, *domain.AuthRequest](int(amountOfCachedAuthRequests)) codeCache, err := lru.New[string, *domain.AuthRequest](int(amountOfCachedAuthRequests))
logging.OnError(err).Info("auth request cache disabled") logging.OnError(err).Info("auth request cache disabled")
if err == nil {
return &AuthRequestCache{ cache.codeCache = codeCache
client: dbClient,
idCache: idCache,
codeCache: codeCache,
} }
return cache
} }
func (c *AuthRequestCache) Health(ctx context.Context) error { func (c *AuthRequestCache) Health(ctx context.Context) error {

View File

@ -56,6 +56,16 @@ type AuthRequest struct {
DefaultTranslations []*CustomText DefaultTranslations []*CustomText
OrgTranslations []*CustomText OrgTranslations []*CustomText
SAMLRequestID string SAMLRequestID string
// orgID the policies were last loaded with
policyOrgID string
}
func (a *AuthRequest) SetPolicyOrgID(id string) {
a.policyOrgID = id
}
func (a *AuthRequest) PolicyOrgID() string {
return a.policyOrgID
} }
type ExternalUser struct { type ExternalUser struct {