fix: use current sequence for refetching of events (#5772)

* fix: use current sequence for refetching of events

* fix: use client ids
This commit is contained in:
Livio Spring
2023-04-28 16:28:13 +02:00
committed by GitHub
parent c8c5cf3c5f
commit 458a383de2
28 changed files with 273 additions and 107 deletions

View File

@@ -23,6 +23,7 @@ import (
"github.com/zitadel/zitadel/internal/telemetry/tracing"
user_model "github.com/zitadel/zitadel/internal/user/model"
user_view_model "github.com/zitadel/zitadel/internal/user/repository/view/model"
"github.com/zitadel/zitadel/internal/view/repository"
)
const unknownUserID = "UNKNOWN"
@@ -64,7 +65,9 @@ type privacyPolicyProvider interface {
type userSessionViewProvider interface {
UserSessionByIDs(string, string, string) (*user_view_model.UserSessionView, error)
UserSessionsByAgentID(string, string) ([]*user_view_model.UserSessionView, error)
GetLatestUserSessionSequence(ctx context.Context, instanceID string) (*repository.CurrentSequence, error)
}
type userViewProvider interface {
UserByID(string, string) (*user_view_model.UserView, error)
}
@@ -654,7 +657,7 @@ func (repo *AuthRequestRepo) checkLoginName(ctx context.Context, request *domain
preferredLoginName += "@" + request.RequestedPrimaryDomain
}
}
user, err = repo.checkLoginNameInputForResourceOwner(request, preferredLoginName)
user, err = repo.checkLoginNameInputForResourceOwner(ctx, request, preferredLoginName)
} else {
user, err = repo.checkLoginNameInput(ctx, request, preferredLoginName)
}
@@ -729,12 +732,12 @@ func (repo *AuthRequestRepo) checkDomainDiscovery(ctx context.Context, request *
func (repo *AuthRequestRepo) checkLoginNameInput(ctx context.Context, request *domain.AuthRequest, loginNameInput string) (*user_view_model.UserView, error) {
// always check the loginname first
user, err := repo.View.UserByLoginName(loginNameInput, request.InstanceID)
user, err := repo.View.UserByLoginName(ctx, loginNameInput, request.InstanceID)
if err == nil {
// and take the user regardless if there would be a user with that email or phone
return user, repo.checkLoginPolicyWithResourceOwner(ctx, request, user.ResourceOwner)
}
user, emailErr := repo.View.UserByEmail(loginNameInput, request.InstanceID)
user, emailErr := repo.View.UserByEmail(ctx, loginNameInput, request.InstanceID)
if emailErr == nil {
// if there was a single user with the specified email
// load and check the login policy
@@ -747,7 +750,7 @@ func (repo *AuthRequestRepo) checkLoginNameInput(ctx context.Context, request *d
return user, nil
}
}
user, phoneErr := repo.View.UserByPhone(loginNameInput, request.InstanceID)
user, phoneErr := repo.View.UserByPhone(ctx, loginNameInput, request.InstanceID)
if phoneErr == nil {
// if there was a single user with the specified phone
// load and check the login policy
@@ -765,9 +768,9 @@ func (repo *AuthRequestRepo) checkLoginNameInput(ctx context.Context, request *d
return nil, err
}
func (repo *AuthRequestRepo) checkLoginNameInputForResourceOwner(request *domain.AuthRequest, loginNameInput string) (*user_view_model.UserView, error) {
func (repo *AuthRequestRepo) checkLoginNameInputForResourceOwner(ctx context.Context, request *domain.AuthRequest, loginNameInput string) (*user_view_model.UserView, error) {
// always check the loginname first
user, err := repo.View.UserByLoginNameAndResourceOwner(loginNameInput, request.RequestedOrgID, request.InstanceID)
user, err := repo.View.UserByLoginNameAndResourceOwner(ctx, loginNameInput, request.RequestedOrgID, request.InstanceID)
if err == nil {
// and take the user regardless if there would be a user with that email or phone
return user, nil
@@ -775,7 +778,7 @@ func (repo *AuthRequestRepo) checkLoginNameInputForResourceOwner(request *domain
if request.LoginPolicy != nil && !request.LoginPolicy.DisableLoginWithEmail {
// if login by email is allowed and there was a single user with the specified email
// take that user (and ignore possible phone number matches)
user, emailErr := repo.View.UserByEmailAndResourceOwner(loginNameInput, request.RequestedOrgID, request.InstanceID)
user, emailErr := repo.View.UserByEmailAndResourceOwner(ctx, loginNameInput, request.RequestedOrgID, request.InstanceID)
if emailErr == nil {
return user, nil
}
@@ -783,7 +786,7 @@ func (repo *AuthRequestRepo) checkLoginNameInputForResourceOwner(request *domain
if request.LoginPolicy != nil && !request.LoginPolicy.DisableLoginWithPhone {
// if login by phone is allowed and there was a single user with the specified phone
// take that user
user, phoneErr := repo.View.UserByPhoneAndResourceOwner(loginNameInput, request.RequestedOrgID, request.InstanceID)
user, phoneErr := repo.View.UserByPhoneAndResourceOwner(ctx, loginNameInput, request.RequestedOrgID, request.InstanceID)
if phoneErr == nil {
return user, nil
}
@@ -1298,12 +1301,20 @@ func userSessionsByUserAgentID(provider userSessionViewProvider, agentID, instan
}
func userSessionByIDs(ctx context.Context, provider userSessionViewProvider, eventProvider userEventProvider, agentID string, user *user_model.UserView) (*user_model.UserSessionView, error) {
session, err := provider.UserSessionByIDs(agentID, user.ID, authz.GetInstance(ctx).InstanceID())
instanceID := authz.GetInstance(ctx).InstanceID()
session, err := provider.UserSessionByIDs(agentID, user.ID, instanceID)
if err != nil {
if !errors.IsNotFound(err) {
return nil, err
}
sequence, err := provider.GetLatestUserSessionSequence(ctx, instanceID)
logging.WithFields("instanceID", instanceID, "userID", user.ID).
OnError(err).
Errorf("could not get current sequence for userSessionByIDs")
session = &user_view_model.UserSessionView{UserAgentID: agentID, UserID: user.ID}
if sequence != nil {
session.Sequence = sequence.CurrentSequence
}
}
events, err := eventProvider.UserEventsByID(ctx, user.ID, session.Sequence)
if err != nil {

View File

@@ -19,6 +19,7 @@ import (
user_model "github.com/zitadel/zitadel/internal/user/model"
user_es_model "github.com/zitadel/zitadel/internal/user/repository/eventsourcing/model"
user_view_model "github.com/zitadel/zitadel/internal/user/repository/view/model"
"github.com/zitadel/zitadel/internal/view/repository"
)
var (
@@ -35,6 +36,10 @@ func (m *mockViewNoUserSession) UserSessionsByAgentID(string, string) ([]*user_v
return nil, nil
}
func (m *mockViewNoUserSession) GetLatestUserSessionSequence(ctx context.Context, instanceID string) (*repository.CurrentSequence, error) {
return &repository.CurrentSequence{}, nil
}
type mockViewErrUserSession struct{}
func (m *mockViewErrUserSession) UserSessionByIDs(string, string, string) (*user_view_model.UserSessionView, error) {
@@ -45,6 +50,10 @@ func (m *mockViewErrUserSession) UserSessionsByAgentID(string, string) ([]*user_
return nil, errors.ThrowInternal(nil, "id", "internal error")
}
func (m *mockViewErrUserSession) GetLatestUserSessionSequence(ctx context.Context, instanceID string) (*repository.CurrentSequence, error) {
return &repository.CurrentSequence{}, nil
}
type mockViewUserSession struct {
ExternalLoginVerification time.Time
PasswordlessVerification time.Time
@@ -82,6 +91,10 @@ func (m *mockViewUserSession) UserSessionsByAgentID(string, string) ([]*user_vie
return sessions, nil
}
func (m *mockViewUserSession) GetLatestUserSessionSequence(ctx context.Context, instanceID string) (*repository.CurrentSequence, error) {
return &repository.CurrentSequence{}, nil
}
type mockViewNoUser struct{}
func (m *mockViewNoUser) UserByID(string, string) (*user_view_model.UserView, error) {

View File

@@ -42,15 +42,24 @@ func (r *RefreshTokenRepo) RefreshTokenByToken(ctx context.Context, refreshToken
}
func (r *RefreshTokenRepo) RefreshTokenByID(ctx context.Context, tokenID, userID string) (*usr_model.RefreshTokenView, error) {
tokenView, viewErr := r.View.RefreshTokenByID(tokenID, authz.GetInstance(ctx).InstanceID())
instanceID := authz.GetInstance(ctx).InstanceID()
tokenView, viewErr := r.View.RefreshTokenByID(tokenID, instanceID)
if viewErr != nil && !errors.IsNotFound(viewErr) {
return nil, viewErr
}
if errors.IsNotFound(viewErr) {
sequence, err := r.View.GetLatestRefreshTokenSequence(ctx, instanceID)
logging.WithFields("instanceID", instanceID, "userID", userID, "tokenID", tokenID).
OnError(err).
Errorf("could not get current sequence for RefreshTokenByID")
tokenView = new(model.RefreshTokenView)
tokenView.ID = tokenID
tokenView.UserID = userID
tokenView.InstanceID = authz.GetInstance(ctx).InstanceID()
tokenView.InstanceID = instanceID
if sequence != nil {
tokenView.Sequence = sequence.CurrentSequence
}
}
events, esErr := r.getUserEvents(ctx, userID, tokenView.InstanceID, tokenView.Sequence)
@@ -80,7 +89,7 @@ func (r *RefreshTokenRepo) SearchMyRefreshTokens(ctx context.Context, userID str
if err != nil {
return nil, err
}
sequence, err := r.View.GetLatestRefreshTokenSequence(authz.GetInstance(ctx).InstanceID())
sequence, err := r.View.GetLatestRefreshTokenSequence(ctx, authz.GetInstance(ctx).InstanceID())
logging.Log("EVENT-GBdn4").OnError(err).WithField("traceID", tracing.TraceIDFromCtx(ctx)).Warn("could not read latest refresh token sequence")
request.Queries = append(request.Queries, &usr_model.RefreshTokenSearchQuery{Key: usr_model.RefreshTokenSearchKeyUserID, Method: domain.SearchMethodEquals, Value: userID})
tokens, count, err := r.View.SearchRefreshTokens(request)

View File

@@ -34,15 +34,25 @@ func (repo *TokenRepo) IsTokenValid(ctx context.Context, userID, tokenID string)
}
func (repo *TokenRepo) TokenByIDs(ctx context.Context, userID, tokenID string) (*usr_model.TokenView, error) {
token, viewErr := repo.View.TokenByIDs(tokenID, userID, authz.GetInstance(ctx).InstanceID())
instanceID := authz.GetInstance(ctx).InstanceID()
token, viewErr := repo.View.TokenByIDs(tokenID, userID, instanceID)
if viewErr != nil && !errors.IsNotFound(viewErr) {
return nil, viewErr
}
if errors.IsNotFound(viewErr) {
sequence, err := repo.View.GetLatestTokenSequence(ctx, instanceID)
logging.WithFields("instanceID", instanceID, "userID", userID, "tokenID", tokenID).
OnError(err).
Errorf("could not get current sequence for TokenByIDs")
token = new(model.TokenView)
token.ID = tokenID
token.UserID = userID
token.InstanceID = authz.GetInstance(ctx).InstanceID()
token.InstanceID = instanceID
if sequence != nil {
token.Sequence = sequence.CurrentSequence
}
}
events, esErr := repo.getUserEvents(ctx, userID, token.InstanceID, token.Sequence)