feat: handle instanceID in projections (#3442)

* feat: handle instanceID in projections

* rename functions

* fix key lock

* fix import
This commit is contained in:
Livio Amstutz
2022-04-19 08:26:12 +02:00
committed by GitHub
parent c25d853820
commit 1305c14e49
120 changed files with 2078 additions and 1209 deletions

View File

@@ -61,12 +61,12 @@ type privacyPolicyProvider interface {
}
type userSessionViewProvider interface {
UserSessionByIDs(string, string) (*user_view_model.UserSessionView, error)
UserSessionsByAgentID(string) ([]*user_view_model.UserSessionView, error)
UserSessionByIDs(string, string, string) (*user_view_model.UserSessionView, error)
UserSessionsByAgentID(string, string) ([]*user_view_model.UserSessionView, error)
PrefixAvatarURL() string
}
type userViewProvider interface {
UserByID(string) (*user_view_model.UserView, error)
UserByID(string, string) (*user_view_model.UserView, error)
PrefixAvatarURL() string
}
@@ -79,7 +79,7 @@ type lockoutPolicyViewProvider interface {
}
type idpProviderViewProvider interface {
IDPProvidersByAggregateIDAndState(string, iam_model.IDPConfigState) ([]*iam_view_model.IDPProviderView, error)
IDPProvidersByAggregateIDAndState(string, string, iam_model.IDPConfigState) ([]*iam_view_model.IDPProviderView, error)
}
type userEventProvider interface {
@@ -102,7 +102,7 @@ type userGrantProvider interface {
type projectProvider interface {
ProjectByOIDCClientID(context.Context, string) (*query.Project, error)
OrgProjectMappingByIDs(orgID, projectID string) (*project_view_model.OrgProjectMapping, error)
OrgProjectMappingByIDs(orgID, projectID, instanceID string) (*project_view_model.OrgProjectMapping, error)
}
type applicationProvider interface {
@@ -596,7 +596,7 @@ func (repo *AuthRequestRepo) fillPolicies(ctx context.Context, request *domain.A
}
func (repo *AuthRequestRepo) tryUsingOnlyUserSession(request *domain.AuthRequest) error {
userSessions, err := userSessionsByUserAgentID(repo.UserSessionViewProvider, request.AgentID)
userSessions, err := userSessionsByUserAgentID(repo.UserSessionViewProvider, request.AgentID, request.InstanceID)
if err != nil {
return err
}
@@ -618,9 +618,9 @@ func (repo *AuthRequestRepo) checkLoginName(ctx context.Context, request *domain
if request.RequestedOrgID != "" {
preferredLoginName += "@" + request.RequestedPrimaryDomain
}
user, err = repo.View.UserByLoginNameAndResourceOwner(preferredLoginName, request.RequestedOrgID)
user, err = repo.View.UserByLoginNameAndResourceOwner(preferredLoginName, request.RequestedOrgID, request.InstanceID)
} else {
user, err = repo.View.UserByLoginName(loginName)
user, err = repo.View.UserByLoginName(loginName, request.InstanceID)
if err == nil {
err = repo.checkLoginPolicyWithResourceOwner(ctx, request, user)
if err != nil {
@@ -696,9 +696,9 @@ func (repo *AuthRequestRepo) checkSelectedExternalIDP(request *domain.AuthReques
func (repo *AuthRequestRepo) checkExternalUserLogin(ctx context.Context, request *domain.AuthRequest, idpConfigID, externalUserID string) (err error) {
externalIDP := new(user_view_model.ExternalIDPView)
if request.RequestedOrgID != "" {
externalIDP, err = repo.View.ExternalIDPByExternalUserIDAndIDPConfigIDAndResourceOwner(externalUserID, idpConfigID, request.RequestedOrgID)
externalIDP, err = repo.View.ExternalIDPByExternalUserIDAndIDPConfigIDAndResourceOwner(externalUserID, idpConfigID, request.RequestedOrgID, request.InstanceID)
} else {
externalIDP, err = repo.View.ExternalIDPByExternalUserIDAndIDPConfigID(externalUserID, idpConfigID)
externalIDP, err = repo.View.ExternalIDPByExternalUserIDAndIDPConfigID(externalUserID, idpConfigID, request.InstanceID)
}
if err != nil {
return err
@@ -828,7 +828,7 @@ func (repo *AuthRequestRepo) nextSteps(ctx context.Context, request *domain.Auth
}
func (repo *AuthRequestRepo) usersForUserSelection(request *domain.AuthRequest) ([]domain.UserSelection, error) {
userSessions, err := userSessionsByUserAgentID(repo.UserSessionViewProvider, request.AgentID)
userSessions, err := userSessionsByUserAgentID(repo.UserSessionViewProvider, request.AgentID, request.InstanceID)
if err != nil {
return nil, err
}
@@ -1044,13 +1044,13 @@ func setOrgID(orgViewProvider orgViewProvider, request *domain.AuthRequest) erro
func getLoginPolicyIDPProviders(provider idpProviderViewProvider, iamID, orgID string, defaultPolicy bool) ([]*iam_model.IDPProviderView, error) {
if defaultPolicy {
idpProviders, err := provider.IDPProvidersByAggregateIDAndState(iamID, iam_model.IDPConfigStateActive)
idpProviders, err := provider.IDPProvidersByAggregateIDAndState(iamID, iamID, iam_model.IDPConfigStateActive)
if err != nil {
return nil, err
}
return iam_view_model.IDPProviderViewsToModel(idpProviders), nil
}
idpProviders, err := provider.IDPProvidersByAggregateIDAndState(orgID, iam_model.IDPConfigStateActive)
idpProviders, err := provider.IDPProvidersByAggregateIDAndState(orgID, iamID, iam_model.IDPConfigStateActive)
if err != nil {
return nil, err
}
@@ -1071,8 +1071,8 @@ func checkVerificationTime(verificationTime time.Time, lifetime time.Duration) b
return verificationTime.Add(lifetime).After(time.Now().UTC())
}
func userSessionsByUserAgentID(provider userSessionViewProvider, agentID string) ([]*user_model.UserSessionView, error) {
session, err := provider.UserSessionsByAgentID(agentID)
func userSessionsByUserAgentID(provider userSessionViewProvider, agentID, instanceID string) ([]*user_model.UserSessionView, error) {
session, err := provider.UserSessionsByAgentID(agentID, instanceID)
if err != nil {
return nil, err
}
@@ -1080,7 +1080,7 @@ func userSessionsByUserAgentID(provider userSessionViewProvider, agentID string)
}
func userSessionByIDs(ctx context.Context, provider userSessionViewProvider, eventProvider userEventProvider, agentID string, user *user_model.UserView) (*user_model.UserSessionView, error) {
session, err := provider.UserSessionByIDs(agentID, user.ID)
session, err := provider.UserSessionByIDs(agentID, user.ID, authz.GetInstance(ctx).InstanceID())
if err != nil {
if !errors.IsNotFound(err) {
return nil, err
@@ -1156,7 +1156,7 @@ func activeUserByID(ctx context.Context, userViewProvider userViewProvider, user
}
func userByID(ctx context.Context, viewProvider userViewProvider, eventProvider userEventProvider, userID string) (*user_model.UserView, error) {
user, viewErr := viewProvider.UserByID(userID)
user, viewErr := viewProvider.UserByID(userID, authz.GetInstance(ctx).InstanceID())
if viewErr != nil && !errors.IsNotFound(viewErr) {
return nil, viewErr
} else if user == nil {
@@ -1254,7 +1254,7 @@ func projectRequired(ctx context.Context, request *domain.AuthRequest, projectPr
if !project.HasProjectCheck {
return false, nil
}
_, err = projectProvider.OrgProjectMappingByIDs(request.UserOrgID, project.ID)
_, err = projectProvider.OrgProjectMappingByIDs(request.UserOrgID, project.ID, request.InstanceID)
if errors.IsNotFound(err) {
return true, nil
}

View File

@@ -24,11 +24,11 @@ import (
type mockViewNoUserSession struct{}
func (m *mockViewNoUserSession) UserSessionByIDs(string, string) (*user_view_model.UserSessionView, error) {
func (m *mockViewNoUserSession) UserSessionByIDs(string, string, string) (*user_view_model.UserSessionView, error) {
return nil, errors.ThrowNotFound(nil, "id", "user session not found")
}
func (m *mockViewNoUserSession) UserSessionsByAgentID(string) ([]*user_view_model.UserSessionView, error) {
func (m *mockViewNoUserSession) UserSessionsByAgentID(string, string) ([]*user_view_model.UserSessionView, error) {
return nil, nil
}
@@ -38,11 +38,11 @@ func (m *mockViewNoUserSession) PrefixAvatarURL() string {
type mockViewErrUserSession struct{}
func (m *mockViewErrUserSession) UserSessionByIDs(string, string) (*user_view_model.UserSessionView, error) {
func (m *mockViewErrUserSession) UserSessionByIDs(string, string, string) (*user_view_model.UserSessionView, error) {
return nil, errors.ThrowInternal(nil, "id", "internal error")
}
func (m *mockViewErrUserSession) UserSessionsByAgentID(string) ([]*user_view_model.UserSessionView, error) {
func (m *mockViewErrUserSession) UserSessionsByAgentID(string, string) ([]*user_view_model.UserSessionView, error) {
return nil, errors.ThrowInternal(nil, "id", "internal error")
}
@@ -65,7 +65,7 @@ type mockUser struct {
ResourceOwner string
}
func (m *mockViewUserSession) UserSessionByIDs(string, string) (*user_view_model.UserSessionView, error) {
func (m *mockViewUserSession) UserSessionByIDs(string, string, string) (*user_view_model.UserSessionView, error) {
return &user_view_model.UserSessionView{
ExternalLoginVerification: m.ExternalLoginVerification,
PasswordlessVerification: m.PasswordlessVerification,
@@ -75,7 +75,7 @@ func (m *mockViewUserSession) UserSessionByIDs(string, string) (*user_view_model
}, nil
}
func (m *mockViewUserSession) UserSessionsByAgentID(string) ([]*user_view_model.UserSessionView, error) {
func (m *mockViewUserSession) UserSessionsByAgentID(string, string) ([]*user_view_model.UserSessionView, error) {
sessions := make([]*user_view_model.UserSessionView, len(m.Users))
for i, user := range m.Users {
sessions[i] = &user_view_model.UserSessionView{
@@ -93,7 +93,7 @@ func (m *mockViewUserSession) PrefixAvatarURL() string {
type mockViewNoUser struct{}
func (m *mockViewNoUser) UserByID(string) (*user_view_model.UserView, error) {
func (m *mockViewNoUser) UserByID(string, string) (*user_view_model.UserView, error) {
return nil, errors.ThrowNotFound(nil, "id", "user not found")
}
@@ -156,7 +156,7 @@ func (m *mockLockoutPolicy) LockoutPolicyByOrg(context.Context, string) (*query.
return m.policy, nil
}
func (m *mockViewUser) UserByID(string) (*user_view_model.UserView, error) {
func (m *mockViewUser) UserByID(string, string) (*user_view_model.UserView, error) {
return &user_view_model.UserView{
State: int32(user_model.UserStateActive),
UserName: "UserName",
@@ -232,7 +232,7 @@ func (m *mockProject) ProjectByOIDCClientID(ctx context.Context, s string) (*que
return &query.Project{HasProjectCheck: m.projectCheck}, nil
}
func (m *mockProject) OrgProjectMappingByIDs(orgID, projectID string) (*proj_view_model.OrgProjectMapping, error) {
func (m *mockProject) OrgProjectMappingByIDs(orgID, projectID, instanceID string) (*proj_view_model.OrgProjectMapping, error) {
if m.hasProject {
return &proj_view_model.OrgProjectMapping{OrgID: orgID, ProjectID: projectID}, nil
}

View File

@@ -23,7 +23,7 @@ type OrgRepository struct {
}
func (repo *OrgRepository) GetIDPConfigByID(ctx context.Context, idpConfigID string) (*iam_model.IDPConfigView, error) {
idpConfig, err := repo.View.IDPConfigByID(idpConfigID)
idpConfig, err := repo.View.IDPConfigByID(idpConfigID, authz.GetInstance(ctx).InstanceID())
if err != nil {
return nil, err
}

View File

@@ -6,16 +6,16 @@ import (
"github.com/caos/logging"
"github.com/caos/zitadel/internal/api/authz"
"github.com/caos/zitadel/internal/auth/repository/eventsourcing/view"
"github.com/caos/zitadel/internal/crypto"
"github.com/caos/zitadel/internal/domain"
"github.com/caos/zitadel/internal/errors"
v1 "github.com/caos/zitadel/internal/eventstore/v1"
"github.com/caos/zitadel/internal/eventstore/v1/models"
usr_view "github.com/caos/zitadel/internal/user/repository/view"
"github.com/caos/zitadel/internal/auth/repository/eventsourcing/view"
"github.com/caos/zitadel/internal/errors"
"github.com/caos/zitadel/internal/telemetry/tracing"
usr_model "github.com/caos/zitadel/internal/user/model"
usr_view "github.com/caos/zitadel/internal/user/repository/view"
"github.com/caos/zitadel/internal/user/repository/view/model"
)
@@ -31,7 +31,7 @@ func (r *RefreshTokenRepo) RefreshTokenByID(ctx context.Context, refreshToken st
if err != nil {
return nil, err
}
tokenView, viewErr := r.View.RefreshTokenByID(tokenID)
tokenView, viewErr := r.View.RefreshTokenByID(tokenID, authz.GetInstance(ctx).InstanceID())
if viewErr != nil && !errors.IsNotFound(viewErr) {
return nil, viewErr
}
@@ -41,7 +41,7 @@ func (r *RefreshTokenRepo) RefreshTokenByID(ctx context.Context, refreshToken st
tokenView.UserID = userID
}
events, esErr := r.getUserEvents(ctx, userID, tokenView.Sequence)
events, esErr := r.getUserEvents(ctx, userID, tokenView.InstanceID, tokenView.Sequence)
if errors.IsNotFound(viewErr) && len(events) == 0 {
return nil, errors.ThrowNotFound(nil, "EVENT-BHB52", "Errors.User.RefreshToken.Invalid")
}
@@ -68,7 +68,7 @@ func (r *RefreshTokenRepo) SearchMyRefreshTokens(ctx context.Context, userID str
if err != nil {
return nil, err
}
sequence, err := r.View.GetLatestRefreshTokenSequence()
sequence, err := r.View.GetLatestRefreshTokenSequence(authz.GetInstance(ctx).InstanceID())
logging.Log("EVENT-GBdn4").OnError(err).WithField("traceID", tracing.TraceIDFromCtx(ctx)).Warn("could not read latest refresh token sequence")
request.Queries = append(request.Queries, &usr_model.RefreshTokenSearchQuery{Key: usr_model.RefreshTokenSearchKeyUserID, Method: domain.SearchMethodEquals, Value: userID})
tokens, count, err := r.View.SearchRefreshTokens(request)
@@ -85,8 +85,8 @@ func (r *RefreshTokenRepo) SearchMyRefreshTokens(ctx context.Context, userID str
}, nil
}
func (r *RefreshTokenRepo) getUserEvents(ctx context.Context, userID string, sequence uint64) ([]*models.Event, error) {
query, err := usr_view.UserByIDQuery(userID, sequence)
func (r *RefreshTokenRepo) getUserEvents(ctx context.Context, userID, instanceID string, sequence uint64) ([]*models.Event, error) {
query, err := usr_view.UserByIDQuery(userID, instanceID, sequence)
if err != nil {
return nil, err
}

View File

@@ -2,18 +2,18 @@ package eventstore
import (
"context"
"github.com/caos/zitadel/internal/eventstore/v1"
"time"
"github.com/caos/zitadel/internal/eventstore/v1/models"
usr_view "github.com/caos/zitadel/internal/user/repository/view"
"github.com/caos/logging"
"github.com/caos/zitadel/internal/api/authz"
"github.com/caos/zitadel/internal/auth/repository/eventsourcing/view"
"github.com/caos/zitadel/internal/errors"
v1 "github.com/caos/zitadel/internal/eventstore/v1"
"github.com/caos/zitadel/internal/eventstore/v1/models"
"github.com/caos/zitadel/internal/telemetry/tracing"
usr_model "github.com/caos/zitadel/internal/user/model"
usr_view "github.com/caos/zitadel/internal/user/repository/view"
"github.com/caos/zitadel/internal/user/repository/view/model"
)
@@ -34,7 +34,7 @@ func (repo *TokenRepo) IsTokenValid(ctx context.Context, userID, tokenID string)
}
func (repo *TokenRepo) TokenByID(ctx context.Context, userID, tokenID string) (*usr_model.TokenView, error) {
token, viewErr := repo.View.TokenByID(tokenID)
token, viewErr := repo.View.TokenByID(tokenID, authz.GetInstance(ctx).InstanceID())
if viewErr != nil && !errors.IsNotFound(viewErr) {
return nil, viewErr
}
@@ -44,7 +44,7 @@ func (repo *TokenRepo) TokenByID(ctx context.Context, userID, tokenID string) (*
token.UserID = userID
}
events, esErr := repo.getUserEvents(ctx, userID, token.Sequence)
events, esErr := repo.getUserEvents(ctx, userID, token.InstanceID, token.Sequence)
if errors.IsNotFound(viewErr) && len(events) == 0 {
return nil, errors.ThrowNotFound(nil, "EVENT-4T90g", "Errors.Token.NotFound")
}
@@ -66,8 +66,8 @@ func (repo *TokenRepo) TokenByID(ctx context.Context, userID, tokenID string) (*
return model.TokenViewToModel(token), nil
}
func (r *TokenRepo) getUserEvents(ctx context.Context, userID string, sequence uint64) ([]*models.Event, error) {
query, err := usr_view.UserByIDQuery(userID, sequence)
func (r *TokenRepo) getUserEvents(ctx context.Context, userID, instanceID string, sequence uint64) ([]*models.Event, error) {
query, err := usr_view.UserByIDQuery(userID, instanceID, sequence)
if err != nil {
return nil, err
}

View File

@@ -3,6 +3,7 @@ package eventstore
import (
"context"
"github.com/caos/zitadel/internal/api/authz"
"github.com/caos/zitadel/internal/auth/repository/eventsourcing/view"
"github.com/caos/zitadel/internal/config/systemdefaults"
"github.com/caos/zitadel/internal/domain"
@@ -26,7 +27,7 @@ func (repo *UserRepo) Health(ctx context.Context) error {
}
func (repo *UserRepo) UserSessionUserIDsByAgentID(ctx context.Context, agentID string) ([]string, error) {
userSessions, err := repo.View.UserSessionsByAgentID(agentID)
userSessions, err := repo.View.UserSessionsByAgentID(agentID, authz.GetInstance(ctx).InstanceID())
if err != nil {
return nil, err
}
@@ -44,7 +45,7 @@ func (repo *UserRepo) UserEventsByID(ctx context.Context, id string, sequence ui
}
func (r *UserRepo) getUserEvents(ctx context.Context, userID string, sequence uint64) ([]*models.Event, error) {
query, err := usr_view.UserByIDQuery(userID, sequence)
query, err := usr_view.UserByIDQuery(userID, authz.GetInstance(ctx).InstanceID(), sequence)
if err != nil {
return nil, err
}

View File

@@ -14,7 +14,7 @@ type UserSessionRepo struct {
}
func (repo *UserSessionRepo) GetMyUserSessions(ctx context.Context) ([]*usr_model.UserSessionView, error) {
userSessions, err := repo.View.UserSessionsByAgentID(authz.GetCtxData(ctx).AgentID)
userSessions, err := repo.View.UserSessionsByAgentID(authz.GetCtxData(ctx).AgentID, authz.GetInstance(ctx).InstanceID())
if err != nil {
return nil, err
}