fix: scheduling (#3978)

* fix: improve scheduling

* build pre-release

* fix: locker

* fix: user handler and print stack in case of panic in reducer

* chore: remove sentry

* fix: improve handler projection and implement tests

* more tests

* fix: race condition in tests

* Update internal/eventstore/repository/sql/query.go

Co-authored-by: Silvan <silvan.reusser@gmail.com>

* fix: implemented suggested changes

* fix: lock statement

Co-authored-by: Silvan <silvan.reusser@gmail.com>
Livio Spring, 2022-07-22 12:08:39 +02:00 (committed by GitHub)
commit aed7010508, parent 0cc548e3f8
83 changed files with 1494 additions and 1544 deletions
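The heart of the change, as a minimal sketch rather than the exact ZITADEL code: handlers no longer build one sub-query per already-known instance plus a catch-all query with ExcludedInstanceID/ExcludedInstanceIDsFilter. Instead, each scheduling round is handed an explicit set of instance IDs and builds exactly one sub-query per instance, using that instance's stored sequence, or 0 if it has none yet. Using the v2 builder from internal/eventstore (the aggregate type, instance IDs and bulk limit are illustrative values):

// one sub-query per instance that this round is responsible for
query := eventstore.NewSearchQueryBuilder(eventstore.ColumnsEvent).
	AddQuery().
	AggregateTypes("testAgg").
	SequenceGreater(5). // instanceID1 already has a stored sequence of 5
	InstanceID("instanceID1").
	Or().
	AggregateTypes("testAgg").
	SequenceGreater(0). // instanceID2 has no stored sequence yet, so start from 0
	InstanceID("instanceID2").
	Builder().
	Limit(5)

The per-file diffs below apply this pattern to the v1 handlers (via a shared newSearchQuery helper), the crdb statement handler, the locker, and the spooler configuration.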

View File

@@ -73,30 +73,26 @@ func (m *Styling) CurrentSequence(instanceID string) (uint64, error) {
return sequence.CurrentSequence, nil
}
func (m *Styling) EventQuery() (*models.SearchQuery, error) {
sequences, err := m.view.GetLatestStylingSequences()
func (m *Styling) EventQuery(instanceIDs ...string) (*models.SearchQuery, error) {
sequences, err := m.view.GetLatestStylingSequences(instanceIDs...)
if err != nil {
return nil, err
}
query := models.NewSearchQuery()
instances := make([]string, 0)
searchQuery := models.NewSearchQuery()
for _, sequence := range sequences {
for _, instance := range instances {
if sequence.InstanceID == instance {
var seq uint64
for _, instanceID := range instanceIDs {
if sequence.InstanceID == instanceID {
seq = sequence.CurrentSequence
break
}
}
instances = append(instances, sequence.InstanceID)
query.AddQuery().
searchQuery.AddQuery().
AggregateTypeFilter(m.AggregateTypes()...).
LatestSequenceFilter(sequence.CurrentSequence).
LatestSequenceFilter(seq).
InstanceIDFilter(sequence.InstanceID)
}
return query.AddQuery().
AggregateTypeFilter(m.AggregateTypes()...).
LatestSequenceFilter(0).
ExcludedInstanceIDsFilter(instances...).
SearchQuery(), nil
return searchQuery, nil
}
func (m *Styling) Reduce(event *models.Event) (err error) {
@@ -299,7 +295,7 @@ func (m *Styling) generateColorPaletteRGBA255(hex string) map[string]string {
if ok {
palette["500"] = cssRGB(color500.RGB255())
}
color600, ok := colorful.MakeColor(gamut.Darker(defaultColor, 0.06))
if ok {
palette["600"] = cssRGB(color600.RGB255())

View File

@@ -15,15 +15,17 @@ type SpoolerConfig struct {
BulkLimit uint64
FailureCountUntilSkip uint64
ConcurrentWorkers int
ConcurrentInstances int
Handlers handler.Configs
}
func StartSpooler(c SpoolerConfig, es v1.Eventstore, view *view.View, sql *sql.DB, static static.Storage) *spooler.Spooler {
spoolerConfig := spooler.Config{
Eventstore: es,
Locker: &locker{dbClient: sql},
ConcurrentWorkers: c.ConcurrentWorkers,
ViewHandlers: handler.Register(c.Handlers, c.BulkLimit, c.FailureCountUntilSkip, view, es, static),
Eventstore: es,
Locker: &locker{dbClient: sql},
ConcurrentWorkers: c.ConcurrentWorkers,
ConcurrentInstances: c.ConcurrentInstances,
ViewHandlers: handler.Register(c.Handlers, c.BulkLimit, c.FailureCountUntilSkip, view, es, static),
}
spool := spoolerConfig.New()
spool.Start()

View File

@@ -19,8 +19,8 @@ func (v *View) latestSequence(viewName, instanceID string) (*repository.CurrentS
return repository.LatestSequence(v.Db, sequencesTable, viewName, instanceID)
}
func (v *View) latestSequences(viewName string) ([]*repository.CurrentSequence, error) {
return repository.LatestSequences(v.Db, sequencesTable, viewName)
func (v *View) latestSequences(viewName string, instanceIDs ...string) ([]*repository.CurrentSequence, error) {
return repository.LatestSequences(v.Db, sequencesTable, viewName, instanceIDs...)
}
func (v *View) AllCurrentSequences(db string) ([]*repository.CurrentSequence, error) {

View File

@@ -27,8 +27,8 @@ func (v *View) GetLatestStylingSequence(instanceID string) (*global_view.Current
return v.latestSequence(stylingTyble, instanceID)
}
func (v *View) GetLatestStylingSequences() ([]*global_view.CurrentSequence, error) {
return v.latestSequences(stylingTyble)
func (v *View) GetLatestStylingSequences(instanceIDs ...string) ([]*global_view.CurrentSequence, error) {
return v.latestSequences(stylingTyble, instanceIDs...)
}
func (v *View) ProcessedStylingSequence(event *models.Event) error {

View File

@@ -6,7 +6,6 @@ import (
"net/http"
"strings"
sentryhttp "github.com/getsentry/sentry-go/http"
"github.com/gorilla/mux"
"github.com/improbable-eng/grpc-web/go/grpcweb"
"github.com/zitadel/logging"
@@ -67,7 +66,6 @@ func (a *API) RegisterServer(ctx context.Context, grpcServer server.Server) erro
func (a *API) RegisterHandler(prefix string, handler http.Handler) {
prefix = strings.TrimSuffix(prefix, "/")
subRouter := a.router.PathPrefix(prefix).Name(prefix).Subrouter()
subRouter.Use(sentryhttp.New(sentryhttp.Options{}).Handle)
subRouter.PathPrefix("").Handler(http.StripPrefix(prefix, handler))
}

View File

@@ -8,7 +8,6 @@ import (
"strings"
"time"
sentryhttp "github.com/getsentry/sentry-go/http"
"github.com/gorilla/mux"
"github.com/zitadel/logging"
@@ -89,7 +88,7 @@ func NewHandler(commands *command.Commands, verifier *authz.TokenVerifier, authC
verifier.RegisterServer("Assets-API", "assets", AssetsService_AuthMethods)
router := mux.NewRouter()
router.Use(sentryhttp.New(sentryhttp.Options{}).Handle, instanceInterceptor)
router.Use(instanceInterceptor)
RegisterRoutes(router, h)
router.PathPrefix("/{owner}").Methods("GET").HandlerFunc(DownloadHandleFunc(h, h.GetFile()))
return http_util.CopyHeadersToContext(http_mw.CORSInterceptor(router))
@@ -117,6 +116,10 @@ func UploadHandleFunc(s AssetsService, uploader Uploader) func(http.ResponseWrit
ctx := r.Context()
ctxData := authz.GetCtxData(ctx)
err := r.ParseMultipartForm(maxMemory)
if err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
file, handler, err := r.FormFile(paramFile)
if err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)

View File

@@ -1,34 +0,0 @@
package middleware
import (
"context"
"github.com/getsentry/sentry-go"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
)
func SentryHandler() grpc.UnaryServerInterceptor {
return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
return sendErrToSentry(ctx, req, handler)
}
}
func sendErrToSentry(ctx context.Context, req interface{}, handler grpc.UnaryHandler) (interface{}, error) {
resp, err := handler(ctx, req)
code := status.Code(err)
switch code {
case codes.Canceled,
codes.Unknown,
codes.DeadlineExceeded,
codes.ResourceExhausted,
codes.Aborted,
codes.Unimplemented,
codes.Internal,
codes.Unavailable,
codes.DataLoss:
sentry.CaptureException(err)
}
return resp, err
}

View File

@@ -30,7 +30,6 @@ func CreateServer(verifier *authz.TokenVerifier, authConfig authz.Config, querie
grpc_middleware.ChainUnaryServer(
middleware.DefaultTracingServer(),
middleware.MetricsHandler(metricTypes, grpc_api.Probes...),
middleware.SentryHandler(),
middleware.NoCacheInterceptor(),
middleware.ErrorHandler(),
middleware.InstanceInterceptor(queries, hostHeaderName, system_pb.SystemService_MethodPrefix),

View File

@@ -8,8 +8,10 @@ import (
"github.com/zitadel/zitadel/internal/auth/repository/eventsourcing/view"
sd "github.com/zitadel/zitadel/internal/config/systemdefaults"
v1 "github.com/zitadel/zitadel/internal/eventstore/v1"
"github.com/zitadel/zitadel/internal/eventstore/v1/models"
"github.com/zitadel/zitadel/internal/eventstore/v1/query"
query2 "github.com/zitadel/zitadel/internal/query"
"github.com/zitadel/zitadel/internal/view/repository"
)
type Configs map[string]*Config
@@ -75,3 +77,21 @@ func (h *handler) QueryLimit() uint64 {
func withInstanceID(ctx context.Context, instanceID string) context.Context {
return authz.WithInstanceID(ctx, instanceID)
}
func newSearchQuery(sequences []*repository.CurrentSequence, aggregateTypes []models.AggregateType, instanceIDs []string) *models.SearchQuery {
searchQuery := models.NewSearchQuery()
for _, sequence := range sequences {
var seq uint64
for _, instanceID := range instanceIDs {
if sequence.InstanceID == instanceID {
seq = sequence.CurrentSequence
break
}
}
searchQuery.AddQuery().
AggregateTypeFilter(aggregateTypes...).
LatestSequenceFilter(seq).
InstanceIDFilter(sequence.InstanceID)
}
return searchQuery
}

View File

@@ -62,31 +62,12 @@ func (i *IDPConfig) CurrentSequence(instanceID string) (uint64, error) {
return sequence.CurrentSequence, nil
}
func (i *IDPConfig) EventQuery() (*models.SearchQuery, error) {
sequences, err := i.view.GetLatestIDPConfigSequences()
func (i *IDPConfig) EventQuery(instanceIDs ...string) (*models.SearchQuery, error) {
sequences, err := i.view.GetLatestIDPConfigSequences(instanceIDs...)
if err != nil {
return nil, err
}
query := models.NewSearchQuery()
instances := make([]string, 0)
for _, sequence := range sequences {
for _, instance := range instances {
if sequence.InstanceID == instance {
break
}
}
instances = append(instances, sequence.InstanceID)
query.AddQuery().
AggregateTypeFilter(i.AggregateTypes()...).
LatestSequenceFilter(sequence.CurrentSequence).
InstanceIDFilter(sequence.InstanceID)
}
return query.AddQuery().
AggregateTypeFilter(i.AggregateTypes()...).
LatestSequenceFilter(0).
ExcludedInstanceIDsFilter(instances...).
SearchQuery(), nil
return newSearchQuery(sequences, i.AggregateTypes(), instanceIDs), nil
}
func (i *IDPConfig) Reduce(event *models.Event) (err error) {

View File

@@ -76,30 +76,13 @@ func (i *IDPProvider) CurrentSequence(instanceID string) (uint64, error) {
return sequence.CurrentSequence, nil
}
func (i *IDPProvider) EventQuery() (*models.SearchQuery, error) {
sequences, err := i.view.GetLatestIDPProviderSequences()
func (i *IDPProvider) EventQuery(instanceIDs ...string) (*models.SearchQuery, error) {
sequences, err := i.view.GetLatestIDPProviderSequences(instanceIDs...)
if err != nil {
return nil, err
}
query := es_models.NewSearchQuery()
instances := make([]string, 0)
for _, sequence := range sequences {
for _, instance := range instances {
if sequence.InstanceID == instance {
break
}
}
instances = append(instances, sequence.InstanceID)
query.AddQuery().
AggregateTypeFilter(i.AggregateTypes()...).
LatestSequenceFilter(sequence.CurrentSequence).
InstanceIDFilter(sequence.InstanceID)
}
return query.AddQuery().
AggregateTypeFilter(i.AggregateTypes()...).
LatestSequenceFilter(0).
ExcludedInstanceIDsFilter(instances...).
SearchQuery(), nil
return newSearchQuery(sequences, i.AggregateTypes(), instanceIDs), nil
}
func (i *IDPProvider) Reduce(event *models.Event) (err error) {

View File

@@ -62,30 +62,12 @@ func (p *OrgProjectMapping) CurrentSequence(instanceID string) (uint64, error) {
return sequence.CurrentSequence, nil
}
func (p *OrgProjectMapping) EventQuery() (*es_models.SearchQuery, error) {
sequences, err := p.view.GetLatestOrgProjectMappingSequences()
func (p *OrgProjectMapping) EventQuery(instanceIDs ...string) (*es_models.SearchQuery, error) {
sequences, err := p.view.GetLatestOrgProjectMappingSequences(instanceIDs...)
if err != nil {
return nil, err
}
query := es_models.NewSearchQuery()
instances := make([]string, 0)
for _, sequence := range sequences {
for _, instance := range instances {
if sequence.InstanceID == instance {
break
}
}
instances = append(instances, sequence.InstanceID)
query.AddQuery().
AggregateTypeFilter(p.AggregateTypes()...).
LatestSequenceFilter(sequence.CurrentSequence).
InstanceIDFilter(sequence.InstanceID)
}
return query.AddQuery().
AggregateTypeFilter(p.AggregateTypes()...).
LatestSequenceFilter(0).
ExcludedInstanceIDsFilter(instances...).
SearchQuery(), nil
return newSearchQuery(sequences, p.AggregateTypes(), instanceIDs), nil
}
func (p *OrgProjectMapping) Reduce(event *es_models.Event) (err error) {

View File

@@ -66,30 +66,12 @@ func (t *RefreshToken) CurrentSequence(instanceID string) (uint64, error) {
return sequence.CurrentSequence, nil
}
func (t *RefreshToken) EventQuery() (*es_models.SearchQuery, error) {
sequences, err := t.view.GetLatestRefreshTokenSequences()
func (t *RefreshToken) EventQuery(instanceIDs ...string) (*es_models.SearchQuery, error) {
sequences, err := t.view.GetLatestRefreshTokenSequences(instanceIDs...)
if err != nil {
return nil, err
}
query := es_models.NewSearchQuery()
instances := make([]string, 0)
for _, sequence := range sequences {
for _, instance := range instances {
if sequence.InstanceID == instance {
break
}
}
instances = append(instances, sequence.InstanceID)
query.AddQuery().
AggregateTypeFilter(t.AggregateTypes()...).
LatestSequenceFilter(sequence.CurrentSequence).
InstanceIDFilter(sequence.InstanceID)
}
return query.AddQuery().
AggregateTypeFilter(t.AggregateTypes()...).
LatestSequenceFilter(0).
ExcludedInstanceIDsFilter(instances...).
SearchQuery(), nil
return newSearchQuery(sequences, t.AggregateTypes(), instanceIDs), nil
}
func (t *RefreshToken) Reduce(event *es_models.Event) (err error) {

View File

@@ -72,30 +72,12 @@ func (p *Token) CurrentSequence(instanceID string) (uint64, error) {
return sequence.CurrentSequence, nil
}
func (t *Token) EventQuery() (*es_models.SearchQuery, error) {
sequences, err := t.view.GetLatestTokenSequences()
func (t *Token) EventQuery(instanceIDs ...string) (*es_models.SearchQuery, error) {
sequences, err := t.view.GetLatestTokenSequences(instanceIDs...)
if err != nil {
return nil, err
}
query := es_models.NewSearchQuery()
instances := make([]string, 0)
for _, sequence := range sequences {
for _, instance := range instances {
if sequence.InstanceID == instance {
break
}
}
instances = append(instances, sequence.InstanceID)
query.AddQuery().
AggregateTypeFilter(t.AggregateTypes()...).
LatestSequenceFilter(sequence.CurrentSequence).
InstanceIDFilter(sequence.InstanceID)
}
return query.AddQuery().
AggregateTypeFilter(t.AggregateTypes()...).
LatestSequenceFilter(0).
ExcludedInstanceIDsFilter(instances...).
SearchQuery(), nil
return newSearchQuery(sequences, t.AggregateTypes(), instanceIDs), nil
}
func (t *Token) Reduce(event *es_models.Event) (err error) {

View File

@@ -74,30 +74,12 @@ func (u *User) CurrentSequence(instanceID string) (uint64, error) {
return sequence.CurrentSequence, nil
}
func (u *User) EventQuery() (*es_models.SearchQuery, error) {
sequences, err := u.view.GetLatestUserSequences()
func (u *User) EventQuery(instanceIDs ...string) (*es_models.SearchQuery, error) {
sequences, err := u.view.GetLatestUserSequences(instanceIDs...)
if err != nil {
return nil, err
}
query := es_models.NewSearchQuery()
instances := make([]string, 0)
for _, sequence := range sequences {
for _, instance := range instances {
if sequence.InstanceID == instance {
break
}
}
instances = append(instances, sequence.InstanceID)
query.AddQuery().
AggregateTypeFilter(u.AggregateTypes()...).
LatestSequenceFilter(sequence.CurrentSequence).
InstanceIDFilter(sequence.InstanceID)
}
return query.AddQuery().
AggregateTypeFilter(u.AggregateTypes()...).
LatestSequenceFilter(0).
ExcludedInstanceIDsFilter(instances...).
SearchQuery(), nil
return newSearchQuery(sequences, u.AggregateTypes(), instanceIDs), nil
}
func (u *User) Reduce(event *es_models.Event) (err error) {
@@ -176,6 +158,7 @@ func (u *User) ProcessUser(event *es_models.Event) (err error) {
if err != nil {
return err
}
user = &view_model.UserView{}
for _, e := range events {
if err = user.AppendEvent(e); err != nil {
return err
@@ -198,6 +181,7 @@ func (u *User) ProcessUser(event *es_models.Event) (err error) {
if err != nil {
return err
}
user = &view_model.UserView{}
for _, e := range events {
if err = user.AppendEvent(e); err != nil {
return err

View File

@@ -77,30 +77,12 @@ func (i *ExternalIDP) CurrentSequence(instanceID string) (uint64, error) {
return sequence.CurrentSequence, nil
}
func (i *ExternalIDP) EventQuery() (*es_models.SearchQuery, error) {
sequences, err := i.view.GetLatestExternalIDPSequences()
func (i *ExternalIDP) EventQuery(instanceIDs ...string) (*es_models.SearchQuery, error) {
sequences, err := i.view.GetLatestExternalIDPSequences(instanceIDs...)
if err != nil {
return nil, err
}
query := es_models.NewSearchQuery()
instances := make([]string, 0)
for _, sequence := range sequences {
for _, instance := range instances {
if sequence.InstanceID == instance {
break
}
}
instances = append(instances, sequence.InstanceID)
query.AddQuery().
AggregateTypeFilter(i.AggregateTypes()...).
LatestSequenceFilter(sequence.CurrentSequence).
InstanceIDFilter(sequence.InstanceID)
}
return query.AddQuery().
AggregateTypeFilter(i.AggregateTypes()...).
LatestSequenceFilter(0).
ExcludedInstanceIDsFilter(instances...).
SearchQuery(), nil
return newSearchQuery(sequences, i.AggregateTypes(), instanceIDs), nil
}
func (i *ExternalIDP) Reduce(event *es_models.Event) (err error) {

View File

@@ -72,30 +72,12 @@ func (u *UserSession) CurrentSequence(instanceID string) (uint64, error) {
return sequence.CurrentSequence, nil
}
func (u *UserSession) EventQuery() (*models.SearchQuery, error) {
sequences, err := u.view.GetLatestUserSessionSequences()
func (u *UserSession) EventQuery(instanceIDs ...string) (*models.SearchQuery, error) {
sequences, err := u.view.GetLatestUserSessionSequences(instanceIDs...)
if err != nil {
return nil, err
}
query := models.NewSearchQuery()
instances := make([]string, 0)
for _, sequence := range sequences {
for _, instance := range instances {
if sequence.InstanceID == instance {
break
}
}
instances = append(instances, sequence.InstanceID)
query.AddQuery().
AggregateTypeFilter(u.AggregateTypes()...).
LatestSequenceFilter(sequence.CurrentSequence).
InstanceIDFilter(sequence.InstanceID)
}
return query.AddQuery().
AggregateTypeFilter(u.AggregateTypes()...).
LatestSequenceFilter(0).
ExcludedInstanceIDsFilter(instances...).
SearchQuery(), nil
return newSearchQuery(sequences, u.AggregateTypes(), instanceIDs), nil
}
func (u *UserSession) Reduce(event *models.Event) (err error) {

View File

@@ -16,15 +16,17 @@ type SpoolerConfig struct {
BulkLimit uint64
FailureCountUntilSkip uint64
ConcurrentWorkers int
ConcurrentInstances int
Handlers handler.Configs
}
func StartSpooler(c SpoolerConfig, es v1.Eventstore, view *view.View, client *sql.DB, systemDefaults sd.SystemDefaults, queries *query.Queries) *spooler.Spooler {
spoolerConfig := spooler.Config{
Eventstore: es,
Locker: &locker{dbClient: client},
ConcurrentWorkers: c.ConcurrentWorkers,
ViewHandlers: handler.Register(c.Handlers, c.BulkLimit, c.FailureCountUntilSkip, view, es, systemDefaults, queries),
Eventstore: es,
Locker: &locker{dbClient: client},
ConcurrentWorkers: c.ConcurrentWorkers,
ConcurrentInstances: c.ConcurrentInstances,
ViewHandlers: handler.Register(c.Handlers, c.BulkLimit, c.FailureCountUntilSkip, view, es, systemDefaults, queries),
}
spool := spoolerConfig.New()
spool.Start()

View File

@@ -60,8 +60,8 @@ func (v *View) GetLatestExternalIDPSequence(instanceID string) (*global_view.Cur
return v.latestSequence(externalIDPTable, instanceID)
}
func (v *View) GetLatestExternalIDPSequences() ([]*global_view.CurrentSequence, error) {
return v.latestSequences(externalIDPTable)
func (v *View) GetLatestExternalIDPSequences(instanceIDs ...string) ([]*global_view.CurrentSequence, error) {
return v.latestSequences(externalIDPTable, instanceIDs...)
}
func (v *View) ProcessedExternalIDPSequence(event *models.Event) error {

View File

@@ -45,8 +45,8 @@ func (v *View) GetLatestIDPConfigSequence(instanceID string) (*global_view.Curre
return v.latestSequence(idpConfigTable, instanceID)
}
func (v *View) GetLatestIDPConfigSequences() ([]*global_view.CurrentSequence, error) {
return v.latestSequences(idpConfigTable)
func (v *View) GetLatestIDPConfigSequences(instanceIDs ...string) ([]*global_view.CurrentSequence, error) {
return v.latestSequences(idpConfigTable, instanceIDs...)
}
func (v *View) ProcessedIDPConfigSequence(event *models.Event) error {

View File

@@ -65,8 +65,8 @@ func (v *View) GetLatestIDPProviderSequence(instanceID string) (*global_view.Cur
return v.latestSequence(idpProviderTable, instanceID)
}
func (v *View) GetLatestIDPProviderSequences() ([]*global_view.CurrentSequence, error) {
return v.latestSequences(idpProviderTable)
func (v *View) GetLatestIDPProviderSequences(instanceIDs ...string) ([]*global_view.CurrentSequence, error) {
return v.latestSequences(idpProviderTable, instanceIDs...)
}
func (v *View) ProcessedIDPProviderSequence(event *models.Event) error {

View File

@@ -44,8 +44,8 @@ func (v *View) GetLatestOrgProjectMappingSequence(instanceID string) (*repositor
return v.latestSequence(orgPrgojectMappingTable, instanceID)
}
func (v *View) GetLatestOrgProjectMappingSequences() ([]*repository.CurrentSequence, error) {
return v.latestSequences(orgPrgojectMappingTable)
func (v *View) GetLatestOrgProjectMappingSequences(instanceIDs ...string) ([]*repository.CurrentSequence, error) {
return v.latestSequences(orgPrgojectMappingTable, instanceIDs...)
}
func (v *View) ProcessedOrgProjectMappingSequence(event *models.Event) error {

View File

@@ -69,8 +69,8 @@ func (v *View) GetLatestRefreshTokenSequence(instanceID string) (*repository.Cur
return v.latestSequence(refreshTokenTable, instanceID)
}
func (v *View) GetLatestRefreshTokenSequences() ([]*repository.CurrentSequence, error) {
return v.latestSequences(refreshTokenTable)
func (v *View) GetLatestRefreshTokenSequences(instanceIDs ...string) ([]*repository.CurrentSequence, error) {
return v.latestSequences(refreshTokenTable, instanceIDs...)
}
func (v *View) ProcessedRefreshTokenSequence(event *models.Event) error {

View File

@@ -19,8 +19,8 @@ func (v *View) latestSequence(viewName, instanceID string) (*repository.CurrentS
return repository.LatestSequence(v.Db, sequencesTable, viewName, instanceID)
}
func (v *View) latestSequences(viewName string) ([]*repository.CurrentSequence, error) {
return repository.LatestSequences(v.Db, sequencesTable, viewName)
func (v *View) latestSequences(viewName string, instanceIDs ...string) ([]*repository.CurrentSequence, error) {
return repository.LatestSequences(v.Db, sequencesTable, viewName, instanceIDs...)
}
func (v *View) updateSpoolerRunSequence(viewName string) error {

View File

@@ -80,8 +80,8 @@ func (v *View) GetLatestTokenSequence(instanceID string) (*repository.CurrentSeq
return v.latestSequence(tokenTable, instanceID)
}
func (v *View) GetLatestTokenSequences() ([]*repository.CurrentSequence, error) {
return v.latestSequences(tokenTable)
func (v *View) GetLatestTokenSequences(instanceIDs ...string) ([]*repository.CurrentSequence, error) {
return v.latestSequences(tokenTable, instanceIDs...)
}
func (v *View) ProcessedTokenSequence(event *models.Event) error {

View File

@@ -143,8 +143,8 @@ func (v *View) GetLatestUserSequence(instanceID string) (*repository.CurrentSequ
return v.latestSequence(userTable, instanceID)
}
func (v *View) GetLatestUserSequences() ([]*repository.CurrentSequence, error) {
return v.latestSequences(userTable)
func (v *View) GetLatestUserSequences(instanceIDs ...string) ([]*repository.CurrentSequence, error) {
return v.latestSequences(userTable, instanceIDs...)
}
func (v *View) ProcessedUserSequence(event *models.Event) error {

View File

@@ -60,8 +60,8 @@ func (v *View) GetLatestUserSessionSequence(instanceID string) (*repository.Curr
return v.latestSequence(userSessionTable, instanceID)
}
func (v *View) GetLatestUserSessionSequences() ([]*repository.CurrentSequence, error) {
return v.latestSequences(userSessionTable)
func (v *View) GetLatestUserSessionSequences(instanceIDs ...string) ([]*repository.CurrentSequence, error) {
return v.latestSequences(userSessionTable, instanceIDs...)
}
func (v *View) ProcessedUserSessionSequence(event *models.Event) error {

View File

@@ -186,6 +186,15 @@ func (es *Eventstore) LatestSequence(ctx context.Context, queryFactory *SearchQu
return es.repo.LatestSequence(ctx, query)
}
//InstanceIDs returns the instance ids found by the search query
func (es *Eventstore) InstanceIDs(ctx context.Context, queryFactory *SearchQueryBuilder) ([]string, error) {
query, err := queryFactory.build(authz.GetInstance(ctx).InstanceID())
if err != nil {
return nil, err
}
return es.repo.InstanceIDs(ctx, query)
}
type QueryReducer interface {
reducer
//Query returns the SearchQueryFactory for the events needed in reducer
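The new Eventstore.InstanceIDs query above is the scheduling side of this commit: instead of one catch-all query over all instances, handlers are driven per instance, and the new ConcurrentInstances setting bounds how many instances a single run works on. How the scheduler consumes the returned IDs is not part of this hunk, so the batching below is only an assumed sketch, not code from this commit:

// batchInstanceIDs splits the IDs returned by Eventstore.InstanceIDs into
// groups of at most concurrentInstances; each batch would then be passed to
// Lock, SearchQuery and Update for one scheduling round (assumed usage).
func batchInstanceIDs(ids []string, concurrentInstances int) [][]string {
	if concurrentInstances <= 0 {
		concurrentInstances = 1
	}
	batches := make([][]string, 0, (len(ids)+concurrentInstances-1)/concurrentInstances)
	for start := 0; start < len(ids); start += concurrentInstances {
		end := start + concurrentInstances
		if end > len(ids) {
			end = len(ids)
		}
		batches = append(batches, ids[start:end])
	}
	return batches
}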

View File

@@ -688,10 +688,11 @@ func TestEventstore_aggregatesToEvents(t *testing.T) {
}
type testRepo struct {
events []*repository.Event
sequence uint64
err error
t *testing.T
events []*repository.Event
sequence uint64
instances []string
err error
t *testing.T
}
func (repo *testRepo) Health(ctx context.Context) error {
@@ -735,6 +736,13 @@ func (repo *testRepo) LatestSequence(ctx context.Context, queryFactory *reposito
return repo.sequence, nil
}
func (repo *testRepo) InstanceIDs(ctx context.Context, queryFactory *repository.SearchQuery) ([]string, error) {
if repo.err != nil {
return nil, repo.err
}
return repo.instances, nil
}
func TestEventstore_Push(t *testing.T) {
type args struct {
events []Command

View File

@@ -6,12 +6,14 @@ import (
"strconv"
"strings"
"github.com/lib/pq"
"github.com/zitadel/zitadel/internal/errors"
"github.com/zitadel/zitadel/internal/eventstore"
)
const (
currentSequenceStmtFormat = `SELECT current_sequence, aggregate_type, instance_id FROM %s WHERE projection_name = $1 FOR UPDATE`
currentSequenceStmtFormat = `SELECT current_sequence, aggregate_type, instance_id FROM %s WHERE projection_name = $1 AND instance_id = ANY ($2) FOR UPDATE`
updateCurrentSequencesStmtFormat = `UPSERT INTO %s (projection_name, aggregate_type, current_sequence, instance_id, timestamp) VALUES `
)
@@ -22,8 +24,8 @@ type instanceSequence struct {
sequence uint64
}
func (h *StatementHandler) currentSequences(ctx context.Context, query func(context.Context, string, ...interface{}) (*sql.Rows, error)) (currentSequences, error) {
rows, err := query(ctx, h.currentSequenceStmt, h.ProjectionName)
func (h *StatementHandler) currentSequences(ctx context.Context, query func(context.Context, string, ...interface{}) (*sql.Rows, error), instanceIDs []string) (currentSequences, error) {
rows, err := query(ctx, h.currentSequenceStmt, h.ProjectionName, pq.StringArray(instanceIDs))
if err != nil {
return nil, err
}

View File

@@ -8,6 +8,7 @@ import (
"time"
"github.com/DATA-DOG/go-sqlmock"
"github.com/lib/pq"
"github.com/zitadel/zitadel/internal/eventstore"
)
@@ -123,34 +124,40 @@ func expectSavePointRelease() func(sqlmock.Sqlmock) {
}
}
func expectCurrentSequence(tableName, projection string, seq uint64, aggregateType, instanceID string) func(sqlmock.Sqlmock) {
func expectCurrentSequence(tableName, projection string, seq uint64, aggregateType string, instanceIDs []string) func(sqlmock.Sqlmock) {
rows := sqlmock.NewRows([]string{"current_sequence", "aggregate_type", "instance_id"})
for _, instanceID := range instanceIDs {
rows.AddRow(seq, aggregateType, instanceID)
}
return func(m sqlmock.Sqlmock) {
m.ExpectQuery(`SELECT current_sequence, aggregate_type, instance_id FROM ` + tableName + ` WHERE projection_name = \$1 FOR UPDATE`).
m.ExpectQuery(`SELECT current_sequence, aggregate_type, instance_id FROM `+tableName+` WHERE projection_name = \$1 AND instance_id = ANY \(\$2\) FOR UPDATE`).
WithArgs(
projection,
pq.StringArray(instanceIDs),
).
WillReturnRows(
sqlmock.NewRows([]string{"current_sequence", "aggregate_type", "instance_id"}).
AddRow(seq, aggregateType, instanceID),
rows,
)
}
}
func expectCurrentSequenceErr(tableName, projection string, err error) func(sqlmock.Sqlmock) {
func expectCurrentSequenceErr(tableName, projection string, instanceIDs []string, err error) func(sqlmock.Sqlmock) {
return func(m sqlmock.Sqlmock) {
m.ExpectQuery(`SELECT current_sequence, aggregate_type, instance_id FROM ` + tableName + ` WHERE projection_name = \$1 FOR UPDATE`).
m.ExpectQuery(`SELECT current_sequence, aggregate_type, instance_id FROM `+tableName+` WHERE projection_name = \$1 AND instance_id = ANY \(\$2\) FOR UPDATE`).
WithArgs(
projection,
pq.StringArray(instanceIDs),
).
WillReturnError(err)
}
}
func expectCurrentSequenceNoRows(tableName, projection string) func(sqlmock.Sqlmock) {
func expectCurrentSequenceNoRows(tableName, projection string, instanceIDs []string) func(sqlmock.Sqlmock) {
return func(m sqlmock.Sqlmock) {
m.ExpectQuery(`SELECT current_sequence, aggregate_type, instance_id FROM ` + tableName + ` WHERE projection_name = \$1 FOR UPDATE`).
m.ExpectQuery(`SELECT current_sequence, aggregate_type, instance_id FROM `+tableName+` WHERE projection_name = \$1 AND instance_id = ANY \(\$2\) FOR UPDATE`).
WithArgs(
projection,
pq.StringArray(instanceIDs),
).
WillReturnRows(
sqlmock.NewRows([]string{"current_sequence", "aggregate_type", "instance_id"}),
@@ -158,11 +165,12 @@ func expectCurrentSequenceNoRows(tableName, projection string) func(sqlmock.Sqlm
}
}
func expectCurrentSequenceScanErr(tableName, projection string) func(sqlmock.Sqlmock) {
func expectCurrentSequenceScanErr(tableName, projection string, instanceIDs []string) func(sqlmock.Sqlmock) {
return func(m sqlmock.Sqlmock) {
m.ExpectQuery(`SELECT current_sequence, aggregate_type, instance_id FROM ` + tableName + ` WHERE projection_name = \$1 FOR UPDATE`).
m.ExpectQuery(`SELECT current_sequence, aggregate_type, instance_id FROM `+tableName+` WHERE projection_name = \$1 AND instance_id = ANY \(\$2\) FOR UPDATE`).
WithArgs(
projection,
pq.StringArray(instanceIDs),
).
WillReturnRows(
sqlmock.NewRows([]string{"current_sequence", "aggregate_type", "instance_id"}).
@@ -286,12 +294,34 @@ func expectLock(lockTable, workerName string, d time.Duration, instanceID string
` \(locker_id, locked_until, projection_name, instance_id\) VALUES \(\$1, now\(\)\+\$2::INTERVAL, \$3\, \$4\)`+
` ON CONFLICT \(projection_name, instance_id\)`+
` DO UPDATE SET locker_id = \$1, locked_until = now\(\)\+\$2::INTERVAL`+
` WHERE `+lockTable+`\.projection_name = \$3 AND `+lockTable+`\.instance_id = \$4 AND \(`+lockTable+`\.locker_id = \$1 OR `+lockTable+`\.locked_until < now\(\)\)`).
` WHERE `+lockTable+`\.projection_name = \$3 AND `+lockTable+`\.instance_id = ANY \(\$5\) AND \(`+lockTable+`\.locker_id = \$1 OR `+lockTable+`\.locked_until < now\(\)\)`).
WithArgs(
workerName,
float64(d),
projectionName,
instanceID,
pq.StringArray{instanceID},
).
WillReturnResult(
sqlmock.NewResult(1, 1),
)
}
}
func expectLockMultipleInstances(lockTable, workerName string, d time.Duration, instanceID1, instanceID2 string) func(sqlmock.Sqlmock) {
return func(m sqlmock.Sqlmock) {
m.ExpectExec(`INSERT INTO `+lockTable+
` \(locker_id, locked_until, projection_name, instance_id\) VALUES \(\$1, now\(\)\+\$2::INTERVAL, \$3\, \$4\), \(\$1, now\(\)\+\$2::INTERVAL, \$3\, \$5\)`+
` ON CONFLICT \(projection_name, instance_id\)`+
` DO UPDATE SET locker_id = \$1, locked_until = now\(\)\+\$2::INTERVAL`+
` WHERE `+lockTable+`\.projection_name = \$3 AND `+lockTable+`\.instance_id = ANY \(\$6\) AND \(`+lockTable+`\.locker_id = \$1 OR `+lockTable+`\.locked_until < now\(\)\)`).
WithArgs(
workerName,
float64(d),
projectionName,
instanceID1,
instanceID2,
pq.StringArray{instanceID1, instanceID2},
).
WillReturnResult(
sqlmock.NewResult(1, 1),
@@ -305,12 +335,13 @@ func expectLockNoRows(lockTable, workerName string, d time.Duration, instanceID
` \(locker_id, locked_until, projection_name, instance_id\) VALUES \(\$1, now\(\)\+\$2::INTERVAL, \$3\, \$4\)`+
` ON CONFLICT \(projection_name, instance_id\)`+
` DO UPDATE SET locker_id = \$1, locked_until = now\(\)\+\$2::INTERVAL`+
` WHERE `+lockTable+`\.projection_name = \$3 AND `+lockTable+`\.instance_id = \$4 AND \(`+lockTable+`\.locker_id = \$1 OR `+lockTable+`\.locked_until < now\(\)\)`).
` WHERE `+lockTable+`\.projection_name = \$3 AND `+lockTable+`\.instance_id = ANY \(\$5\) AND \(`+lockTable+`\.locker_id = \$1 OR `+lockTable+`\.locked_until < now\(\)\)`).
WithArgs(
workerName,
float64(d),
projectionName,
instanceID,
pq.StringArray{instanceID},
).
WillReturnResult(driver.ResultNoRows)
}
@@ -322,12 +353,13 @@ func expectLockErr(lockTable, workerName string, d time.Duration, instanceID str
` \(locker_id, locked_until, projection_name, instance_id\) VALUES \(\$1, now\(\)\+\$2::INTERVAL, \$3\, \$4\)`+
` ON CONFLICT \(projection_name, instance_id\)`+
` DO UPDATE SET locker_id = \$1, locked_until = now\(\)\+\$2::INTERVAL`+
` WHERE `+lockTable+`\.projection_name = \$3 AND `+lockTable+`\.instance_id = \$4 AND \(`+lockTable+`\.locker_id = \$1 OR `+lockTable+`\.locked_until < now\(\)\)`).
` WHERE `+lockTable+`\.projection_name = \$3 AND `+lockTable+`\.instance_id = ANY \(\$5\) AND \(`+lockTable+`\.locker_id = \$1 OR `+lockTable+`\.locked_until < now\(\)\)`).
WithArgs(
workerName,
float64(d),
projectionName,
instanceID,
pq.StringArray{instanceID},
).
WillReturnError(err)
}

View File

@@ -10,7 +10,6 @@ import (
"github.com/zitadel/zitadel/internal/errors"
"github.com/zitadel/zitadel/internal/eventstore"
"github.com/zitadel/zitadel/internal/eventstore/handler"
"github.com/zitadel/zitadel/internal/telemetry/tracing"
)
var (
@@ -75,84 +74,62 @@ func NewStatementHandler(
bulkLimit: config.BulkLimit,
Locker: NewLocker(config.Client, config.LockTable, config.ProjectionHandlerConfig.ProjectionName),
}
h.ProjectionHandler = handler.NewProjectionHandler(config.ProjectionHandlerConfig, h.reduce, h.Update, h.SearchQuery)
h.ProjectionHandler = handler.NewProjectionHandler(ctx, config.ProjectionHandlerConfig, h.reduce, h.Update, h.SearchQuery, h.Lock, h.Unlock)
err := h.Init(ctx, config.InitCheck)
logging.OnError(err).Fatal("unable to initialize projections")
go h.Process(
ctx,
h.reduce,
h.Update,
h.Lock,
h.Unlock,
h.SearchQuery,
)
h.Subscribe(h.aggregates...)
return h
}
func (h *StatementHandler) TriggerBulk(ctx context.Context) {
ctx, span := tracing.NewSpan(ctx)
var err error
defer span.EndWithError(err)
err = h.ProjectionHandler.TriggerBulk(ctx, h.Lock, h.Unlock)
logging.OnError(err).WithField("projection", h.ProjectionName).Warn("unable to trigger bulk")
}
func (h *StatementHandler) SearchQuery(ctx context.Context) (*eventstore.SearchQueryBuilder, uint64, error) {
sequences, err := h.currentSequences(ctx, h.client.QueryContext)
func (h *StatementHandler) SearchQuery(ctx context.Context, instanceIDs []string) (*eventstore.SearchQueryBuilder, uint64, error) {
sequences, err := h.currentSequences(ctx, h.client.QueryContext, instanceIDs)
if err != nil {
return nil, 0, err
}
queryBuilder := eventstore.NewSearchQueryBuilder(eventstore.ColumnsEvent).Limit(h.bulkLimit)
for _, aggregateType := range h.aggregates {
instances := make([]string, 0)
for _, sequence := range sequences[aggregateType] {
instances = appendToIgnoredInstances(instances, sequence.instanceID)
for _, instanceID := range instanceIDs {
var seq uint64
for _, sequence := range sequences[aggregateType] {
if sequence.instanceID == instanceID {
seq = sequence.sequence
break
}
}
queryBuilder.
AddQuery().
AggregateTypes(aggregateType).
SequenceGreater(sequence.sequence).
InstanceID(sequence.instanceID)
SequenceGreater(seq).
InstanceID(instanceID)
}
queryBuilder.
AddQuery().
AggregateTypes(aggregateType).
SequenceGreater(0).
ExcludedInstanceID(instances...)
}
return queryBuilder, h.bulkLimit, nil
}
func appendToIgnoredInstances(instances []string, id string) []string {
for _, instance := range instances {
if instance == id {
return instances
}
}
return append(instances, id)
}
//Update implements handler.Update
func (h *StatementHandler) Update(ctx context.Context, stmts []*handler.Statement, reduce handler.Reduce) (unexecutedStmts []*handler.Statement, err error) {
func (h *StatementHandler) Update(ctx context.Context, stmts []*handler.Statement, reduce handler.Reduce) (index int, err error) {
if len(stmts) == 0 {
return nil, nil
return -1, nil
}
instanceIDs := make([]string, 0, len(stmts))
for _, stmt := range stmts {
instanceIDs = appendToInstanceIDs(instanceIDs, stmt.InstanceID)
}
tx, err := h.client.BeginTx(ctx, nil)
if err != nil {
return stmts, errors.ThrowInternal(err, "CRDB-e89Gq", "begin failed")
return -1, errors.ThrowInternal(err, "CRDB-e89Gq", "begin failed")
}
sequences, err := h.currentSequences(ctx, tx.QueryContext)
sequences, err := h.currentSequences(ctx, tx.QueryContext, instanceIDs)
if err != nil {
tx.Rollback()
return stmts, err
return -1, err
}
//checks for events between create statement and current sequence
@@ -162,7 +139,7 @@ func (h *StatementHandler) Update(ctx context.Context, stmts []*handler.Statemen
previousStmts, err := h.fetchPreviousStmts(ctx, tx, stmts[0].Sequence, stmts[0].InstanceID, sequences, reduce)
if err != nil {
tx.Rollback()
return stmts, err
return -1, err
}
stmts = append(previousStmts, stmts...)
}
@@ -173,27 +150,19 @@ func (h *StatementHandler) Update(ctx context.Context, stmts []*handler.Statemen
err = h.updateCurrentSequences(tx, sequences)
if err != nil {
tx.Rollback()
return stmts, err
return -1, err
}
}
if err = tx.Commit(); err != nil {
return stmts, err
return -1, err
}
if lastSuccessfulIdx == -1 && len(stmts) > 0 {
return stmts, handler.ErrSomeStmtsFailed
if lastSuccessfulIdx < len(stmts)-1 {
return lastSuccessfulIdx, handler.ErrSomeStmtsFailed
}
unexecutedStmts = make([]*handler.Statement, len(stmts)-(lastSuccessfulIdx+1))
copy(unexecutedStmts, stmts[lastSuccessfulIdx+1:])
stmts = nil
if len(unexecutedStmts) > 0 {
return unexecutedStmts, handler.ErrSomeStmtsFailed
}
return unexecutedStmts, nil
return lastSuccessfulIdx, nil
}
func (h *StatementHandler) fetchPreviousStmts(ctx context.Context, tx *sql.Tx, stmtSeq uint64, instanceID string, sequences currentSequences, reduce handler.Reduce) (previousStmts []*handler.Statement, err error) {
@@ -316,3 +285,12 @@ func updateSequences(sequences currentSequences, stmt *handler.Statement) {
sequence: stmt.Sequence,
})
}
func appendToInstanceIDs(instances []string, id string) []string {
for _, instance := range instances {
if instance == id {
return instances
}
}
return append(instances, id)
}

View File

@@ -4,6 +4,7 @@ import (
"context"
"database/sql"
"errors"
"fmt"
"reflect"
"testing"
"time"
@@ -61,9 +62,13 @@ func TestProjectionHandler_SearchQuery(t *testing.T) {
aggregates []eventstore.AggregateType
bulkLimit uint64
}
type args struct {
instanceIDs []string
}
tests := []struct {
name string
fields fields
args args
want want
}{
{
@@ -74,13 +79,16 @@ func TestProjectionHandler_SearchQuery(t *testing.T) {
aggregates: []eventstore.AggregateType{"testAgg"},
bulkLimit: 5,
},
args: args{
instanceIDs: []string{"instanceID1"},
},
want: want{
limit: 0,
isErr: func(err error) bool {
return errors.Is(err, sql.ErrTxDone)
},
expectations: []mockExpectation{
expectCurrentSequenceErr("my_sequences", "my_projection", sql.ErrTxDone),
expectCurrentSequenceErr("my_sequences", "my_projection", []string{"instanceID1"}, sql.ErrTxDone),
},
SearchQueryBuilder: nil,
},
@@ -93,24 +101,56 @@ func TestProjectionHandler_SearchQuery(t *testing.T) {
aggregates: []eventstore.AggregateType{"testAgg"},
bulkLimit: 5,
},
args: args{
instanceIDs: []string{"instanceID1"},
},
want: want{
limit: 5,
isErr: func(err error) bool {
return err == nil
},
expectations: []mockExpectation{
expectCurrentSequence("my_sequences", "my_projection", 5, "testAgg", "instanceID"),
expectCurrentSequence("my_sequences", "my_projection", 5, "testAgg", []string{"instanceID1"}),
},
SearchQueryBuilder: eventstore.
NewSearchQueryBuilder(eventstore.ColumnsEvent).
AddQuery().
AggregateTypes("testAgg").
SequenceGreater(5).
InstanceID("instanceID").
InstanceID("instanceID1").
Builder().
Limit(5),
},
},
{
name: "multiple instances",
fields: fields{
sequenceTable: "my_sequences",
projectionName: "my_projection",
aggregates: []eventstore.AggregateType{"testAgg"},
bulkLimit: 5,
},
args: args{
instanceIDs: []string{"instanceID1", "instanceID2"},
},
want: want{
limit: 5,
isErr: func(err error) bool {
return err == nil
},
expectations: []mockExpectation{
expectCurrentSequence("my_sequences", "my_projection", 5, "testAgg", []string{"instanceID1", "instanceID2"}),
},
SearchQueryBuilder: eventstore.
NewSearchQueryBuilder(eventstore.ColumnsEvent).
AddQuery().
AggregateTypes("testAgg").
SequenceGreater(5).
InstanceID("instanceID1").
Or().
AggregateTypes("testAgg").
SequenceGreater(0).
ExcludedInstanceID("instanceID").
SequenceGreater(5).
InstanceID("instanceID2").
Builder().
Limit(5),
},
@@ -140,7 +180,7 @@ func TestProjectionHandler_SearchQuery(t *testing.T) {
expectation(mock)
}
query, limit, err := h.SearchQuery(context.Background())
query, limit, err := h.SearchQuery(context.Background(), tt.args.instanceIDs)
if !tt.want.isErr(err) {
t.Errorf("ProjectionHandler.prepareBulkStmts() error = %v", err)
return
@@ -211,13 +251,14 @@ func TestStatementHandler_Update(t *testing.T) {
aggregateType: "agg",
sequence: 6,
previousSequence: 0,
instanceID: "instanceID",
}),
},
},
want: want{
expectations: []mockExpectation{
expectBegin(),
expectCurrentSequenceErr("my_sequences", "my_projection", sql.ErrTxDone),
expectCurrentSequenceErr("my_sequences", "my_projection", []string{"instanceID"}, sql.ErrTxDone),
expectRollback(),
},
isErr: func(err error) bool {
@@ -241,13 +282,14 @@ func TestStatementHandler_Update(t *testing.T) {
aggregateType: "agg",
sequence: 6,
previousSequence: 0,
instanceID: "instanceID",
}),
},
},
want: want{
expectations: []mockExpectation{
expectBegin(),
expectCurrentSequence("my_sequences", "my_projection", 5, "testAgg", "instanceID"),
expectCurrentSequence("my_sequences", "my_projection", 5, "testAgg", []string{"instanceID"}),
expectRollback(),
},
isErr: func(err error) bool {
@@ -272,6 +314,7 @@ func TestStatementHandler_Update(t *testing.T) {
aggregateType: "testAgg",
sequence: 7,
previousSequence: 6,
instanceID: "instanceID",
},
[]handler.Column{
{
@@ -284,7 +327,7 @@ func TestStatementHandler_Update(t *testing.T) {
want: want{
expectations: []mockExpectation{
expectBegin(),
expectCurrentSequence("my_sequences", "my_projection", 5, "testAgg", "instanceID"),
expectCurrentSequence("my_sequences", "my_projection", 5, "testAgg", []string{"instanceID"}),
expectCommit(),
},
isErr: func(err error) bool {
@@ -322,7 +365,7 @@ func TestStatementHandler_Update(t *testing.T) {
want: want{
expectations: []mockExpectation{
expectBegin(),
expectCurrentSequence("my_sequences", "my_projection", 5, "agg", "instanceID"),
expectCurrentSequence("my_sequences", "my_projection", 5, "agg", []string{"instanceID"}),
expectSavePoint(),
expectCreate("my_projection", []string{"col"}, []string{"$1"}),
expectSavePointRelease(),
@@ -364,7 +407,7 @@ func TestStatementHandler_Update(t *testing.T) {
want: want{
expectations: []mockExpectation{
expectBegin(),
expectCurrentSequence("my_sequences", "my_projection", 5, "agg", "instanceID"),
expectCurrentSequence("my_sequences", "my_projection", 5, "agg", []string{"instanceID"}),
expectSavePoint(),
expectCreate("my_projection", []string{"col"}, []string{"$1"}),
expectSavePointRelease(),
@@ -399,7 +442,7 @@ func TestStatementHandler_Update(t *testing.T) {
want: want{
expectations: []mockExpectation{
expectBegin(),
expectCurrentSequence("my_sequences", "my_projection", 5, "testAgg", "instanceID"),
expectCurrentSequence("my_sequences", "my_projection", 5, "testAgg", []string{"instanceID"}),
expectUpdateCurrentSequence("my_sequences", "my_projection", 7, "testAgg", "instanceID"),
expectCommit(),
},
@@ -431,7 +474,7 @@ func TestStatementHandler_Update(t *testing.T) {
want: want{
expectations: []mockExpectation{
expectBegin(),
expectCurrentSequence("my_sequences", "my_projection", 5, "testAgg", "instanceID"),
expectCurrentSequence("my_sequences", "my_projection", 5, "testAgg", []string{"instanceID"}),
expectUpdateCurrentSequence("my_sequences", "my_projection", 7, "testAgg", "instanceID"),
expectCommit(),
},
@@ -470,13 +513,14 @@ func TestStatementHandler_Update(t *testing.T) {
want: want{
expectations: []mockExpectation{
expectBegin(),
expectCurrentSequence("my_sequences", "my_projection", 5, "testAgg", "instanceID"),
expectCurrentSequence("my_sequences", "my_projection", 5, "testAgg", []string{"instanceID"}),
expectUpdateCurrentSequence("my_sequences", "my_projection", 7, "testAgg", "instanceID"),
expectCommit(),
},
isErr: func(err error) bool {
return errors.Is(err, nil)
},
stmtsLen: 1,
},
},
}
@@ -488,17 +532,18 @@ func TestStatementHandler_Update(t *testing.T) {
}
defer client.Close()
h := NewStatementHandler(context.Background(), StatementHandlerConfig{
ProjectionHandlerConfig: handler.ProjectionHandlerConfig{
ProjectionName: "my_projection",
HandlerConfig: handler.HandlerConfig{
h := &StatementHandler{
ProjectionHandler: &handler.ProjectionHandler{
Handler: handler.Handler{
Eventstore: tt.fields.eventstore,
},
RequeueEvery: 0,
ProjectionName: "my_projection",
},
SequenceTable: "my_sequences",
Client: client,
})
sequenceTable: "my_sequences",
currentSequenceStmt: fmt.Sprintf(currentSequenceStmtFormat, "my_sequences"),
updateSequencesBaseStmt: fmt.Sprintf(updateCurrentSequencesStmtFormat, "my_sequences"),
client: client,
}
h.aggregates = tt.fields.aggregates
@@ -506,12 +551,12 @@ func TestStatementHandler_Update(t *testing.T) {
expectation(mock)
}
stmts, err := h.Update(tt.args.ctx, tt.args.stmts, tt.args.reduce)
index, err := h.Update(tt.args.ctx, tt.args.stmts, tt.args.reduce)
if !tt.want.isErr(err) {
t.Errorf("StatementHandler.Update() error = %v", err)
}
if err == nil && tt.want.stmtsLen != len(stmts) {
t.Errorf("wrong stmts length: want: %d got %d", tt.want.stmtsLen, len(stmts))
if err == nil && tt.want.stmtsLen != index {
t.Errorf("wrong stmts length: want: %d got %d", tt.want.stmtsLen, index)
}
mock.MatchExpectationsInOrder(true)
@@ -696,17 +741,12 @@ func TestProjectionHandler_fetchPreviousStmts(t *testing.T) {
h := &StatementHandler{
aggregates: tt.fields.aggregates,
}
h.ProjectionHandler = handler.NewProjectionHandler(handler.ProjectionHandlerConfig{
HandlerConfig: handler.HandlerConfig{
h.ProjectionHandler = &handler.ProjectionHandler{
Handler: handler.Handler{
Eventstore: tt.fields.eventstore,
},
ProjectionName: "my_projection",
RequeueEvery: 0,
},
h.reduce,
h.Update,
h.SearchQuery,
)
}
stmts, err := h.fetchPreviousStmts(tt.args.ctx, nil, tt.args.stmtSeq, "", tt.args.sequences, tt.args.reduce)
if !tt.want.isErr(err) {
t.Errorf("ProjectionHandler.prepareBulkStmts() error = %v", err)
@@ -1311,7 +1351,8 @@ func TestStatementHandler_currentSequence(t *testing.T) {
aggregates []eventstore.AggregateType
}
type args struct {
stmt handler.Statement
stmt handler.Statement
instanceIDs []string
}
type want struct {
expectations []mockExpectation
@@ -1338,7 +1379,7 @@ func TestStatementHandler_currentSequence(t *testing.T) {
return errors.Is(err, sql.ErrConnDone)
},
expectations: []mockExpectation{
expectCurrentSequenceErr("my_table", "my_projection", sql.ErrConnDone),
expectCurrentSequenceErr("my_table", "my_projection", nil, sql.ErrConnDone),
},
},
},
@@ -1350,14 +1391,15 @@ func TestStatementHandler_currentSequence(t *testing.T) {
aggregates: []eventstore.AggregateType{"agg"},
},
args: args{
stmt: handler.Statement{},
stmt: handler.Statement{},
instanceIDs: []string{"instanceID"},
},
want: want{
isErr: func(err error) bool {
return errors.Is(err, nil)
},
expectations: []mockExpectation{
expectCurrentSequenceNoRows("my_table", "my_projection"),
expectCurrentSequenceNoRows("my_table", "my_projection", []string{"instanceID"}),
},
sequences: currentSequences{},
},
@@ -1370,14 +1412,15 @@ func TestStatementHandler_currentSequence(t *testing.T) {
aggregates: []eventstore.AggregateType{"agg"},
},
args: args{
stmt: handler.Statement{},
stmt: handler.Statement{},
instanceIDs: []string{"instanceID"},
},
want: want{
isErr: func(err error) bool {
return errors.Is(err, sql.ErrTxDone)
},
expectations: []mockExpectation{
expectCurrentSequenceScanErr("my_table", "my_projection"),
expectCurrentSequenceScanErr("my_table", "my_projection", []string{"instanceID"}),
},
sequences: currentSequences{},
},
@@ -1390,14 +1433,15 @@ func TestStatementHandler_currentSequence(t *testing.T) {
aggregates: []eventstore.AggregateType{"agg"},
},
args: args{
stmt: handler.Statement{},
stmt: handler.Statement{},
instanceIDs: []string{"instanceID"},
},
want: want{
isErr: func(err error) bool {
return errors.Is(err, nil)
},
expectations: []mockExpectation{
expectCurrentSequence("my_table", "my_projection", 5, "agg", "instanceID"),
expectCurrentSequence("my_table", "my_projection", 5, "agg", []string{"instanceID"}),
},
sequences: currentSequences{
"agg": []*instanceSequence{
@@ -1409,15 +1453,48 @@ func TestStatementHandler_currentSequence(t *testing.T) {
},
},
},
{
name: "multiple found",
fields: fields{
sequenceTable: "my_table",
projectionName: "my_projection",
aggregates: []eventstore.AggregateType{"agg"},
},
args: args{
stmt: handler.Statement{},
instanceIDs: []string{"instanceID1", "instanceID2"},
},
want: want{
isErr: func(err error) bool {
return errors.Is(err, nil)
},
expectations: []mockExpectation{
expectCurrentSequence("my_table", "my_projection", 5, "agg", []string{"instanceID1", "instanceID2"}),
},
sequences: currentSequences{
"agg": []*instanceSequence{
{
sequence: 5,
instanceID: "instanceID1",
},
{
sequence: 5,
instanceID: "instanceID2",
},
},
},
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
h := NewStatementHandler(context.Background(), StatementHandlerConfig{
ProjectionHandlerConfig: handler.ProjectionHandlerConfig{
h := &StatementHandler{
ProjectionHandler: &handler.ProjectionHandler{
ProjectionName: tt.fields.projectionName,
},
SequenceTable: tt.fields.sequenceTable,
})
sequenceTable: tt.fields.sequenceTable,
currentSequenceStmt: fmt.Sprintf(currentSequenceStmtFormat, tt.fields.sequenceTable),
}
h.aggregates = tt.fields.aggregates
@@ -1440,7 +1517,7 @@ func TestStatementHandler_currentSequence(t *testing.T) {
t.Fatalf("unexpected err in begin: %v", err)
}
seq, err := h.currentSequences(context.Background(), tx.QueryContext)
seq, err := h.currentSequences(context.Background(), tx.QueryContext, tt.args.instanceIDs)
if !tt.want.isErr(err) {
t.Errorf("unexpected error: %v", err)
}
@@ -1615,12 +1692,13 @@ func TestStatementHandler_updateCurrentSequence(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
h := NewStatementHandler(context.Background(), StatementHandlerConfig{
ProjectionHandlerConfig: handler.ProjectionHandlerConfig{
h := &StatementHandler{
ProjectionHandler: &handler.ProjectionHandler{
ProjectionName: tt.fields.projectionName,
},
SequenceTable: tt.fields.sequenceTable,
})
sequenceTable: tt.fields.sequenceTable,
updateSequencesBaseStmt: fmt.Sprintf(updateCurrentSequencesStmtFormat, tt.fields.sequenceTable),
}
h.aggregates = tt.fields.aggregates

View File

@@ -4,8 +4,11 @@ import (
"context"
"database/sql"
"fmt"
"strconv"
"strings"
"time"
"github.com/lib/pq"
"github.com/zitadel/logging"
"github.com/zitadel/zitadel/internal/errors"
@@ -14,20 +17,20 @@ import (
const (
lockStmtFormat = "INSERT INTO %[1]s" +
" (locker_id, locked_until, projection_name, instance_id) VALUES ($1, now()+$2::INTERVAL, $3, $4)" +
" (locker_id, locked_until, projection_name, instance_id) VALUES %[2]s" +
" ON CONFLICT (projection_name, instance_id)" +
" DO UPDATE SET locker_id = $1, locked_until = now()+$2::INTERVAL" +
" WHERE %[1]s.projection_name = $3 AND %[1]s.instance_id = $4 AND (%[1]s.locker_id = $1 OR %[1]s.locked_until < now())"
" WHERE %[1]s.projection_name = $3 AND %[1]s.instance_id = ANY ($%[3]d) AND (%[1]s.locker_id = $1 OR %[1]s.locked_until < now())"
)
type Locker interface {
Lock(ctx context.Context, lockDuration time.Duration, instanceID string) <-chan error
Unlock(instanceID string) error
Lock(ctx context.Context, lockDuration time.Duration, instanceIDs ...string) <-chan error
Unlock(instanceIDs ...string) error
}
type locker struct {
client *sql.DB
lockStmt string
lockStmt func(values string, instances int) string
workerName string
projectionName string
}
@@ -36,25 +39,27 @@ func NewLocker(client *sql.DB, lockTable, projectionName string) Locker {
workerName, err := id.SonyFlakeGenerator().Next()
logging.OnError(err).Panic("unable to generate lockID")
return &locker{
client: client,
lockStmt: fmt.Sprintf(lockStmtFormat, lockTable),
client: client,
lockStmt: func(values string, instances int) string {
return fmt.Sprintf(lockStmtFormat, lockTable, values, instances)
},
workerName: workerName,
projectionName: projectionName,
}
}
func (h *locker) Lock(ctx context.Context, lockDuration time.Duration, instanceID string) <-chan error {
func (h *locker) Lock(ctx context.Context, lockDuration time.Duration, instanceIDs ...string) <-chan error {
errs := make(chan error)
go h.handleLock(ctx, errs, lockDuration, instanceID)
go h.handleLock(ctx, errs, lockDuration, instanceIDs...)
return errs
}
func (h *locker) handleLock(ctx context.Context, errs chan error, lockDuration time.Duration, instanceID string) {
func (h *locker) handleLock(ctx context.Context, errs chan error, lockDuration time.Duration, instanceIDs ...string) {
renewLock := time.NewTimer(0)
for {
select {
case <-renewLock.C:
errs <- h.renewLock(ctx, lockDuration, instanceID)
errs <- h.renewLock(ctx, lockDuration, instanceIDs...)
//refresh the lock 500ms before it times out. 500ms should be enough for one transaction
renewLock.Reset(lockDuration - (500 * time.Millisecond))
case <-ctx.Done():
@@ -65,24 +70,38 @@ func (h *locker) handleLock(ctx context.Context, errs chan error, lockDuration t
}
}
func (h *locker) renewLock(ctx context.Context, lockDuration time.Duration, instanceID string) error {
//the unit of crdb interval is seconds (https://www.cockroachlabs.com/docs/stable/interval.html).
res, err := h.client.ExecContext(ctx, h.lockStmt, h.workerName, lockDuration.Seconds(), h.projectionName, instanceID)
func (h *locker) renewLock(ctx context.Context, lockDuration time.Duration, instanceIDs ...string) error {
lockStmt, values := h.lockStatement(lockDuration, instanceIDs)
res, err := h.client.ExecContext(ctx, lockStmt, values...)
if err != nil {
return errors.ThrowInternal(err, "CRDB-uaDoR", "unable to execute lock")
}
if rows, _ := res.RowsAffected(); rows == 0 {
return errors.ThrowAlreadyExists(nil, "CRDB-mmi4J", "projection already locked")
}
return nil
}
func (h *locker) Unlock(instanceID string) error {
_, err := h.client.Exec(h.lockStmt, h.workerName, float64(0), h.projectionName, instanceID)
func (h *locker) Unlock(instanceIDs ...string) error {
lockStmt, values := h.lockStatement(0, instanceIDs)
_, err := h.client.Exec(lockStmt, values...)
if err != nil {
return errors.ThrowUnknown(err, "CRDB-JjfwO", "unlock failed")
}
return nil
}
func (h *locker) lockStatement(lockDuration time.Duration, instanceIDs []string) (string, []interface{}) {
valueQueries := make([]string, len(instanceIDs))
values := make([]interface{}, len(instanceIDs)+4)
values[0] = h.workerName
//the unit of crdb interval is seconds (https://www.cockroachlabs.com/docs/stable/interval.html).
values[1] = lockDuration.Seconds()
values[2] = h.projectionName
for i, instanceID := range instanceIDs {
valueQueries[i] = "($1, now()+$2::INTERVAL, $3, $" + strconv.Itoa(i+4) + ")"
values[i+3] = instanceID
}
values[len(values)-1] = pq.StringArray(instanceIDs)
return h.lockStmt(strings.Join(valueQueries, ", "), len(values)), values
}
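To make the generated statement concrete: for two instance IDs, lockStatement produces the query the tests above expect (the lock table name is illustrative; the real name comes from the handler config, and the statement is a single line in code, wrapped here for readability):

// For instanceIDs = []string{"instanceID1", "instanceID2"} and a lock table
// named projection_locks, lockStatement returns:
//
//   INSERT INTO projection_locks
//     (locker_id, locked_until, projection_name, instance_id)
//     VALUES ($1, now()+$2::INTERVAL, $3, $4), ($1, now()+$2::INTERVAL, $3, $5)
//   ON CONFLICT (projection_name, instance_id)
//   DO UPDATE SET locker_id = $1, locked_until = now()+$2::INTERVAL
//   WHERE projection_locks.projection_name = $3
//     AND projection_locks.instance_id = ANY ($6)
//     AND (projection_locks.locker_id = $1 OR projection_locks.locked_until < now())
//
// with values = [workerName, lockDuration.Seconds(), projectionName,
// "instanceID1", "instanceID2", pq.StringArray{"instanceID1", "instanceID2"}].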

View File

@@ -32,7 +32,7 @@ func TestStatementHandler_handleLock(t *testing.T) {
lockDuration time.Duration
ctx context.Context
errMock *errsMock
instanceID string
instanceIDs []string
}
tests := []struct {
name string
@@ -56,7 +56,7 @@ func TestStatementHandler_handleLock(t *testing.T) {
successfulIters: 2,
shouldErr: true,
},
instanceID: "instanceID",
instanceIDs: []string{"instanceID"},
},
},
{
@@ -74,7 +74,25 @@ func TestStatementHandler_handleLock(t *testing.T) {
errs: make(chan error),
successfulIters: 2,
},
instanceID: "instanceID",
instanceIDs: []string{"instanceID"},
},
},
{
name: "success with multiple",
want: want{
expectations: []mockExpectation{
expectLockMultipleInstances(lockTable, workerName, 2, "instanceID1", "instanceID2"),
expectLockMultipleInstances(lockTable, workerName, 2, "instanceID1", "instanceID2"),
},
},
args: args{
lockDuration: 2 * time.Second,
ctx: context.Background(),
errMock: &errsMock{
errs: make(chan error),
successfulIters: 2,
},
instanceIDs: []string{"instanceID1", "instanceID2"},
},
},
}
@@ -88,7 +106,9 @@ func TestStatementHandler_handleLock(t *testing.T) {
projectionName: projectionName,
client: client,
workerName: workerName,
lockStmt: fmt.Sprintf(lockStmtFormat, lockTable),
lockStmt: func(values string, instances int) string {
return fmt.Sprintf(lockStmtFormat, lockTable, values, instances)
},
}
for _, expectation := range tt.want.expectations {
@@ -99,7 +119,7 @@ func TestStatementHandler_handleLock(t *testing.T) {
go tt.args.errMock.handleErrs(t, cancel)
go h.handleLock(ctx, tt.args.errMock.errs, tt.args.lockDuration, tt.args.instanceID)
go h.handleLock(ctx, tt.args.errMock.errs, tt.args.lockDuration, tt.args.instanceIDs...)
<-ctx.Done()
@@ -118,7 +138,7 @@ func TestStatementHandler_renewLock(t *testing.T) {
}
type args struct {
lockDuration time.Duration
instanceID string
instanceIDs []string
}
tests := []struct {
name string
@@ -137,7 +157,7 @@ func TestStatementHandler_renewLock(t *testing.T) {
},
args: args{
lockDuration: 1 * time.Second,
instanceID: "instanceID",
instanceIDs: []string{"instanceID"},
},
},
{
@@ -152,7 +172,7 @@ func TestStatementHandler_renewLock(t *testing.T) {
},
args: args{
lockDuration: 2 * time.Second,
instanceID: "instanceID",
instanceIDs: []string{"instanceID"},
},
},
{
@@ -167,7 +187,22 @@ func TestStatementHandler_renewLock(t *testing.T) {
},
args: args{
lockDuration: 3 * time.Second,
instanceID: "instanceID",
instanceIDs: []string{"instanceID"},
},
},
{
name: "success with multiple",
want: want{
expectations: []mockExpectation{
expectLockMultipleInstances(lockTable, workerName, 3, "instanceID1", "instanceID2"),
},
isErr: func(err error) bool {
return errors.Is(err, nil)
},
},
args: args{
lockDuration: 3 * time.Second,
instanceIDs: []string{"instanceID1", "instanceID2"},
},
},
}
@@ -181,14 +216,16 @@ func TestStatementHandler_renewLock(t *testing.T) {
projectionName: projectionName,
client: client,
workerName: workerName,
lockStmt: fmt.Sprintf(lockStmtFormat, lockTable),
lockStmt: func(values string, instances int) string {
return fmt.Sprintf(lockStmtFormat, lockTable, values, instances)
},
}
for _, expectation := range tt.want.expectations {
expectation(mock)
}
err = h.renewLock(context.Background(), tt.args.lockDuration, tt.args.instanceID)
err = h.renewLock(context.Background(), tt.args.lockDuration, tt.args.instanceIDs...)
if !tt.want.isErr(err) {
t.Errorf("unexpected error = %v", err)
}
@@ -253,7 +290,9 @@ func TestStatementHandler_Unlock(t *testing.T) {
projectionName: projectionName,
client: client,
workerName: workerName,
lockStmt: fmt.Sprintf(lockStmtFormat, lockTable),
lockStmt: func(values string, instances int) string {
return fmt.Sprintf(lockStmtFormat, lockTable, values, instances)
},
}
for _, expectation := range tt.want.expectations {

View File

@@ -27,3 +27,10 @@ func (h *Handler) Subscribe(aggregates ...eventstore.AggregateType) {
func (h *Handler) SubscribeEvents(types map[eventstore.AggregateType][]eventstore.EventType) {
h.Sub = eventstore.SubscribeEventTypes(h.EventQueue, types)
}
func (h *Handler) Unsubscribe() {
if h.Sub == nil {
return
}
h.Sub.Unsubscribe()
}

View File

@@ -2,13 +2,13 @@ package handler
import (
"context"
"errors"
"runtime/debug"
"sort"
"sync"
"time"
"github.com/zitadel/logging"
"github.com/zitadel/zitadel/internal/api/authz"
"github.com/zitadel/zitadel/internal/eventstore"
)
@@ -16,241 +16,207 @@ const systemID = "system"
type ProjectionHandlerConfig struct {
HandlerConfig
ProjectionName string
RequeueEvery time.Duration
RetryFailedAfter time.Duration
ProjectionName string
RequeueEvery time.Duration
RetryFailedAfter time.Duration
Retries uint
ConcurrentInstances uint
}
//Update updates the projection with the given statements
type Update func(context.Context, []*Statement, Reduce) (unexecutedStmts []*Statement, err error)
type Update func(context.Context, []*Statement, Reduce) (index int, err error)
//Reduce reduces the given event to a statement
//which is used to update the projection
type Reduce func(eventstore.Event) (*Statement, error)
//SearchQuery generates the search query to look up events
type SearchQuery func(ctx context.Context, instanceIDs []string) (query *eventstore.SearchQueryBuilder, queryLimit uint64, err error)
//Lock is used for mutex handling if needed on the projection
type Lock func(context.Context, time.Duration, string) <-chan error
type Lock func(context.Context, time.Duration, ...string) <-chan error
//Unlock releases the mutex of the projection
type Unlock func(string) error
//SearchQuery generates the search query to lookup for events
type SearchQuery func(ctx context.Context) (query *eventstore.SearchQueryBuilder, queryLimit uint64, err error)
type Unlock func(...string) error
type ProjectionHandler struct {
Handler
requeueAfter time.Duration
shouldBulk *time.Timer
bulkMu sync.Mutex
bulkLocked bool
execBulk executeBulk
retryFailedAfter time.Duration
shouldPush *time.Timer
pushSet bool
ProjectionName string
lockMu sync.Mutex
stmts []*Statement
ProjectionName string
reduce Reduce
update Update
searchQuery SearchQuery
triggerProjection *time.Timer
lock Lock
unlock Unlock
requeueAfter time.Duration
retryFailedAfter time.Duration
retries int
concurrentInstances int
}
func NewProjectionHandler(
ctx context.Context,
config ProjectionHandlerConfig,
reduce Reduce,
update Update,
query SearchQuery,
lock Lock,
unlock Unlock,
) *ProjectionHandler {
concurrentInstances := int(config.ConcurrentInstances)
if concurrentInstances < 1 {
concurrentInstances = 1
}
h := &ProjectionHandler{
Handler: NewHandler(config.HandlerConfig),
ProjectionName: config.ProjectionName,
requeueAfter: config.RequeueEvery,
// first bulk is instant on startup
shouldBulk: time.NewTimer(0),
shouldPush: time.NewTimer(0),
retryFailedAfter: config.RetryFailedAfter,
Handler: NewHandler(config.HandlerConfig),
ProjectionName: config.ProjectionName,
reduce: reduce,
update: update,
searchQuery: query,
lock: lock,
unlock: unlock,
requeueAfter: config.RequeueEvery,
triggerProjection: time.NewTimer(0), // first trigger is instant on startup
retryFailedAfter: config.RetryFailedAfter,
retries: int(config.Retries),
concurrentInstances: concurrentInstances,
}
h.execBulk = h.prepareExecuteBulk(query, reduce, update)
go h.subscribe(ctx)
//uninitialized timer
//https://github.com/golang/go/issues/12721
<-h.shouldPush.C
go h.schedule(ctx)
if config.RequeueEvery <= 0 {
if !h.shouldBulk.Stop() {
<-h.shouldBulk.C
}
logging.WithFields("projection", h.ProjectionName).Info("starting handler without requeue")
return h
} else if config.RequeueEvery < 500*time.Millisecond {
logging.WithFields("projection", h.ProjectionName).Fatal("requeue every must be greater 500ms or <= 0")
}
logging.WithFields("projection", h.ProjectionName).Info("starting handler")
return h
}
func (h *ProjectionHandler) ResetShouldBulk() {
if h.requeueAfter > 0 {
h.shouldBulk.Reset(h.requeueAfter)
//Trigger handles all events for the provided instances (or the current instance from the context if none are specified)
//by calling FetchEvents and Process until the amount of events is smaller than the BulkLimit
func (h *ProjectionHandler) Trigger(ctx context.Context, instances ...string) error {
ids := []string{authz.GetInstance(ctx).InstanceID()}
if len(instances) > 0 {
ids = instances
}
}
func (h *ProjectionHandler) triggerShouldPush(after time.Duration) {
if !h.pushSet {
h.pushSet = true
h.shouldPush.Reset(after)
}
}
//Process waits for several conditions:
// if context is canceled the function gracefully shuts down
// if an event occurs it reduces the event
// if the internal timer expires the handler will check
// for unprocessed events on the eventstore
func (h *ProjectionHandler) Process(
ctx context.Context,
reduce Reduce,
update Update,
lock Lock,
unlock Unlock,
query SearchQuery,
) {
//handle panic
defer func() {
cause := recover()
logging.WithFields("projection", h.ProjectionName, "cause", cause, "stack", string(debug.Stack())).Error("projection handler paniced")
}()
for {
select {
case <-ctx.Done():
if h.pushSet {
h.push(context.Background(), update, reduce)
events, hasLimitExceeded, err := h.FetchEvents(ctx, ids...)
if err != nil {
return err
}
if len(events) == 0 {
return nil
}
_, err = h.Process(ctx, events...)
if err != nil {
return err
}
if !hasLimitExceeded {
return nil
}
}
}
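
As a usage recap (both lines appear verbatim in this change), the query layer triggers the instance taken from the request context, while the scheduler passes an explicit, locked batch of instance IDs:

// query layer: refresh the projection for the instance in ctx before reading
projection.OrgProjection.Trigger(ctx)

// scheduler: refresh a locked batch of instances
err = h.Trigger(lockCtx, instances...)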
//Process handles multiple events by reducing them to statements and updating the projection
func (h *ProjectionHandler) Process(ctx context.Context, events ...eventstore.Event) (index int, err error) {
if len(events) == 0 {
return 0, nil
}
index = -1
statements := make([]*Statement, len(events))
for i, event := range events {
statements[i], err = h.reduce(event)
if err != nil {
return index, err
}
}
for retry := 0; retry <= h.retries; retry++ {
index, err = h.update(ctx, statements[index+1:], h.reduce)
if err != nil && !errors.Is(err, ErrSomeStmtsFailed) {
return index, err
}
if err == nil {
return index, nil
}
time.Sleep(h.retryFailedAfter)
}
return index, err
}
//FetchEvents checks the current sequences and filters for newer events
func (h *ProjectionHandler) FetchEvents(ctx context.Context, instances ...string) ([]eventstore.Event, bool, error) {
eventQuery, eventsLimit, err := h.searchQuery(ctx, instances)
if err != nil {
return nil, false, err
}
events, err := h.Eventstore.Filter(ctx, eventQuery)
if err != nil {
return nil, false, err
}
return events, int(eventsLimit) == len(events), err
}
func (h *ProjectionHandler) subscribe(ctx context.Context) {
ctx, cancel := context.WithCancel(ctx)
defer func() {
err := recover()
if err != nil {
h.Handler.Unsubscribe()
logging.WithFields("projection", h.ProjectionName).Errorf("subscription panicked: %v", err)
}
cancel()
}()
for firstEvent := range h.EventQueue {
events := checkAdditionalEvents(h.EventQueue, firstEvent)
index, err := h.Process(ctx, events...)
if err != nil || index < len(events)-1 {
logging.WithFields("projection", h.ProjectionName).WithError(err).Error("unable to process all events from subscription")
}
}
}
func (h *ProjectionHandler) schedule(ctx context.Context) {
ctx, cancel := context.WithCancel(ctx)
defer func() {
err := recover()
if err != nil {
logging.WithFields("projection", h.ProjectionName, "cause", err, "stack", string(debug.Stack())).Error("schedule panicked")
}
cancel()
}()
for range h.triggerProjection.C {
ids, err := h.Eventstore.InstanceIDs(ctx, eventstore.NewSearchQueryBuilder(eventstore.ColumnsInstanceIDs).AddQuery().ExcludedInstanceID("").Builder())
if err != nil {
logging.WithFields("projection", h.ProjectionName).WithError(err).Error("instance ids")
h.triggerProjection.Reset(h.requeueAfter)
continue
}
for i := 0; i < len(ids); i = i + h.concurrentInstances {
max := i + h.concurrentInstances
if max > len(ids) {
max = len(ids)
}
h.shutdown()
return
case event := <-h.EventQueue:
if err := h.processEvent(ctx, event, reduce); err != nil {
logging.WithFields("projection", h.ProjectionName).WithError(err).Warn("process failed")
instances := ids[i:max]
lockCtx, cancelLock := context.WithCancel(ctx)
errs := h.lock(lockCtx, h.requeueAfter, instances...)
//wait until projection is locked
if err, ok := <-errs; err != nil || !ok {
cancelLock()
logging.WithFields("projection", h.ProjectionName).OnError(err).Warn("initial lock failed")
continue
}
h.triggerShouldPush(0)
case <-h.shouldBulk.C:
h.bulkMu.Lock()
h.bulkLocked = true
h.bulk(ctx, lock, unlock)
h.ResetShouldBulk()
h.bulkLocked = false
h.bulkMu.Unlock()
default:
//lower prio select with push
select {
case <-ctx.Done():
if h.pushSet {
h.push(context.Background(), update, reduce)
}
h.shutdown()
return
case event := <-h.EventQueue:
if err := h.processEvent(ctx, event, reduce); err != nil {
logging.WithFields("projection", h.ProjectionName).WithError(err).Warn("process failed")
continue
}
h.triggerShouldPush(0)
case <-h.shouldBulk.C:
h.bulkMu.Lock()
h.bulkLocked = true
h.bulk(ctx, lock, unlock)
h.ResetShouldBulk()
h.bulkLocked = false
h.bulkMu.Unlock()
case <-h.shouldPush.C:
h.push(ctx, update, reduce)
h.ResetShouldBulk()
go h.cancelOnErr(lockCtx, errs, cancelLock)
err = h.Trigger(lockCtx, instances...)
if err != nil {
logging.WithFields("projection", h.ProjectionName, "instanceIDs", instances).WithError(err).Error("trigger failed")
}
cancelLock()
unlockErr := h.unlock(instances...)
logging.WithFields("projection", h.ProjectionName).OnError(unlockErr).Warn("unable to unlock")
}
h.triggerProjection.Reset(h.requeueAfter)
}
}
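
The batching arithmetic above, extracted into a minimal standalone sketch with assumed values, splits the instance IDs into chunks of at most ConcurrentInstances:

package main

import "fmt"

func main() {
	ids := []string{"i1", "i2", "i3", "i4", "i5"}
	concurrentInstances := 2
	for i := 0; i < len(ids); i += concurrentInstances {
		max := i + concurrentInstances
		if max > len(ids) {
			max = len(ids)
		}
		fmt.Println(ids[i:max]) // [i1 i2], then [i3 i4], then [i5]
	}
}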
func (h *ProjectionHandler) processEvent(
ctx context.Context,
event eventstore.Event,
reduce Reduce,
) error {
stmt, err := reduce(event)
if err != nil {
logging.New().WithError(err).Warn("unable to process event")
return err
}
h.lockMu.Lock()
defer h.lockMu.Unlock()
h.stmts = append(h.stmts, stmt)
return nil
}
func (h *ProjectionHandler) TriggerBulk(
ctx context.Context,
lock Lock,
unlock Unlock,
) error {
if !h.shouldBulk.Stop() {
//make sure to flush shouldBulk chan
select {
case <-h.shouldBulk.C:
default:
}
}
defer h.ResetShouldBulk()
h.bulkMu.Lock()
if h.bulkLocked {
logging.WithFields("projection", h.ProjectionName).Debugf("waiting for existing bulk to finish")
h.bulkMu.Unlock()
return nil
}
h.bulkLocked = true
defer func() {
h.bulkLocked = false
h.bulkMu.Unlock()
}()
return h.bulk(ctx, lock, unlock)
}
func (h *ProjectionHandler) bulk(
ctx context.Context,
lock Lock,
unlock Unlock,
) error {
ctx, cancel := context.WithCancel(ctx)
defer cancel()
errs := lock(ctx, h.requeueAfter, systemID)
//wait until projection is locked
if err, ok := <-errs; err != nil || !ok {
logging.WithFields("projection", h.ProjectionName).OnError(err).Warn("initial lock failed")
return err
}
go h.cancelOnErr(ctx, errs, cancel)
execErr := h.execBulk(ctx)
logging.WithFields("projection", h.ProjectionName).OnError(execErr).Warn("unable to execute")
unlockErr := unlock(systemID)
logging.WithFields("projection", h.ProjectionName).OnError(unlockErr).Warn("unable to unlock")
if execErr != nil {
return execErr
}
return unlockErr
}
func (h *ProjectionHandler) cancelOnErr(ctx context.Context, errs <-chan error, cancel func()) {
for {
select {
@@ -268,98 +234,15 @@ func (h *ProjectionHandler) cancelOnErr(ctx context.Context, errs <-chan error,
}
}
type executeBulk func(ctx context.Context) error
func (h *ProjectionHandler) prepareExecuteBulk(
query SearchQuery,
reduce Reduce,
update Update,
) executeBulk {
return func(ctx context.Context) error {
for {
select {
case <-ctx.Done():
return nil
default:
hasLimitExceeded, err := h.fetchBulkStmts(ctx, query, reduce)
if err != nil || len(h.stmts) == 0 {
logging.WithFields("projection", h.ProjectionName).OnError(err).Warn("unable to fetch stmts")
return err
}
if err = h.push(ctx, update, reduce); err != nil {
return err
}
if !hasLimitExceeded {
return nil
}
}
func checkAdditionalEvents(eventQueue chan eventstore.Event, event eventstore.Event) []eventstore.Event {
events := make([]eventstore.Event, 1)
events[0] = event
for {
select {
case event := <-eventQueue:
events = append(events, event)
default:
return events
}
}
}
func (h *ProjectionHandler) fetchBulkStmts(
ctx context.Context,
query SearchQuery,
reduce Reduce,
) (limitExceeded bool, err error) {
eventQuery, eventsLimit, err := query(ctx)
if err != nil {
logging.WithFields("projection", h.ProjectionName).WithError(err).Warn("unable to create event query")
return false, err
}
events, err := h.Eventstore.Filter(ctx, eventQuery)
if err != nil {
logging.WithFields("projection", h.ProjectionName).WithError(err).Info("Unable to bulk fetch events")
return false, err
}
for _, event := range events {
if err = h.processEvent(ctx, event, reduce); err != nil {
logging.WithFields("projection", h.ProjectionName, "sequence", event.Sequence(), "instanceID", event.Aggregate().InstanceID).WithError(err).Warn("unable to process event in bulk")
return false, err
}
}
return len(events) == int(eventsLimit), nil
}
func (h *ProjectionHandler) push(
ctx context.Context,
update Update,
reduce Reduce,
) (err error) {
h.lockMu.Lock()
defer h.lockMu.Unlock()
sort.Slice(h.stmts, func(i, j int) bool {
return h.stmts[i].Sequence < h.stmts[j].Sequence
})
h.stmts, err = update(ctx, h.stmts, reduce)
h.pushSet = len(h.stmts) > 0
if h.pushSet {
h.triggerShouldPush(h.retryFailedAfter)
return nil
}
h.shouldPush.Stop()
return err
}
func (h *ProjectionHandler) shutdown() {
h.lockMu.Lock()
defer h.lockMu.Unlock()
h.Sub.Unsubscribe()
if !h.shouldBulk.Stop() {
<-h.shouldBulk.C
}
if !h.shouldPush.Stop() {
<-h.shouldPush.C
}
logging.New().Info("stop processing")
}

File diff suppressed because it is too large

View File

@@ -8,8 +8,8 @@ import (
context "context"
reflect "reflect"
repository "github.com/zitadel/zitadel/internal/eventstore/repository"
gomock "github.com/golang/mock/gomock"
repository "github.com/zitadel/zitadel/internal/eventstore/repository"
)
// MockRepository is a mock of Repository interface.
@@ -78,6 +78,21 @@ func (mr *MockRepositoryMockRecorder) Health(arg0 interface{}) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Health", reflect.TypeOf((*MockRepository)(nil).Health), arg0)
}
// InstanceIDs mocks base method.
func (m *MockRepository) InstanceIDs(arg0 context.Context, arg1 *repository.SearchQuery) ([]string, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "InstanceIDs", arg0, arg1)
ret0, _ := ret[0].([]string)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// InstanceIDs indicates an expected call of InstanceIDs.
func (mr *MockRepositoryMockRecorder) InstanceIDs(arg0, arg1 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "InstanceIDs", reflect.TypeOf((*MockRepository)(nil).InstanceIDs), arg0, arg1)
}
// LatestSequence mocks base method.
func (m *MockRepository) LatestSequence(arg0 context.Context, arg1 *repository.SearchQuery) (uint64, error) {
m.ctrl.T.Helper()

View File

@@ -29,6 +29,16 @@ func (m *MockRepository) ExpectFilterEventsError(err error) *MockRepository {
return m
}
func (m *MockRepository) ExpectInstanceIDs(instanceIDs ...string) *MockRepository {
m.EXPECT().InstanceIDs(gomock.Any(), gomock.Any()).Return(instanceIDs, nil)
return m
}
func (m *MockRepository) ExpectInstanceIDsError(err error) *MockRepository {
m.EXPECT().InstanceIDs(gomock.Any(), gomock.Any()).Return(nil, err)
return m
}
func (m *MockRepository) ExpectPush(expectedEvents []*repository.Event, expectedUniqueConstraints ...*repository.UniqueConstraint) *MockRepository {
m.EXPECT().Push(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(
func(ctx context.Context, events []*repository.Event, uniqueConstraints ...*repository.UniqueConstraint) error {

View File

@@ -8,14 +8,16 @@ import (
type Repository interface {
//Health checks if the connection to the storage is available
Health(ctx context.Context) error
// PushEvents adds all events of the given aggregates to the eventstreams of the aggregates.
// Push adds all events of the given aggregates to the event streams of the aggregates.
// if unique constraints are pushed, they will be added to the unique table for checking unique constraint violations
// This call is transaction safe. The transaction will be rolled back if one event fails
Push(ctx context.Context, events []*Event, uniqueConstraints ...*UniqueConstraint) error
// Filter returns all events matching the given search query
Filter(ctx context.Context, searchQuery *SearchQuery) (events []*Event, err error)
//LatestSequence returns the latests sequence found by the the search query
//LatestSequence returns the latest sequence found by the search query
LatestSequence(ctx context.Context, queryFactory *SearchQuery) (uint64, error)
//InstanceIDs returns the instance ids found by the search query
InstanceIDs(ctx context.Context, queryFactory *SearchQuery) ([]string, error)
//CreateInstance creates a new sequence for the given instance
CreateInstance(ctx context.Context, instanceID string) error
}
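
A hedged sketch of driving the new lookup one layer up, with the same columns/query-builder combination the scheduler uses (es is an assumed eventstore handle; error handling elided):

ids, err := es.InstanceIDs(ctx,
	eventstore.NewSearchQueryBuilder(eventstore.ColumnsInstanceIDs).
		AddQuery().
		ExcludedInstanceID("").
		Builder())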

View File

@@ -23,6 +23,8 @@ const (
ColumnsEvent = iota + 1
//ColumnsMaxSequence represents the latest sequence of the filtered events
ColumnsMaxSequence
// ColumnsInstanceIDs represents the instance ids of the filtered events
ColumnsInstanceIDs
columnsCount
)

View File

@@ -218,7 +218,7 @@ func (db *CRDB) Filter(ctx context.Context, searchQuery *repository.SearchQuery)
return events, nil
}
//LatestSequence returns the latests sequence found by the the search query
//LatestSequence returns the latest sequence found by the search query
func (db *CRDB) LatestSequence(ctx context.Context, searchQuery *repository.SearchQuery) (uint64, error) {
var seq Sequence
err := query(ctx, db, searchQuery, &seq)
@@ -228,6 +228,16 @@ func (db *CRDB) LatestSequence(ctx context.Context, searchQuery *repository.Sear
return uint64(seq), nil
}
//InstanceIDs returns the instance ids found by the search query
func (db *CRDB) InstanceIDs(ctx context.Context, searchQuery *repository.SearchQuery) ([]string, error) {
var ids []string
err := query(ctx, db, searchQuery, &ids)
if err != nil {
return nil, err
}
return ids, nil
}
func (db *CRDB) db() *sql.DB {
return db.client
}
@@ -262,6 +272,10 @@ func (db *CRDB) maxSequenceQuery() string {
return "SELECT MAX(event_sequence) FROM eventstore.events"
}
func (db *CRDB) instanceIDsQuery() string {
return "SELECT DISTINCT instance_id FROM eventstore.events"
}
func (db *CRDB) columnName(col repository.Field) string {
switch col {
case repository.FieldAggregateID:

View File

@@ -22,6 +22,7 @@ type querier interface {
placeholder(query string) string
eventQuery() string
maxSequenceQuery() string
instanceIDsQuery() string
db() *sql.DB
orderByEventSequence(desc bool) string
}
@@ -36,7 +37,7 @@ func query(ctx context.Context, criteria querier, searchQuery *repository.Search
}
query += where
if searchQuery.Columns != repository.ColumnsMaxSequence {
if searchQuery.Columns == repository.ColumnsEvent {
query += criteria.orderByEventSequence(searchQuery.Desc)
}
@@ -76,6 +77,8 @@ func prepareColumns(criteria querier, columns repository.Columns) (string, func(
switch columns {
case repository.ColumnsMaxSequence:
return criteria.maxSequenceQuery(), maxSequenceScanner
case repository.ColumnsInstanceIDs:
return criteria.instanceIDsQuery(), instanceIDsScanner
case repository.ColumnsEvent:
return criteria.eventQuery(), eventsScanner
default:
@@ -95,6 +98,22 @@ func maxSequenceScanner(row scan, dest interface{}) (err error) {
return z_errors.ThrowInternal(err, "SQL-bN5xg", "something went wrong")
}
func instanceIDsScanner(scanner scan, dest interface{}) (err error) {
ids, ok := dest.(*[]string)
if !ok {
return z_errors.ThrowInvalidArgument(nil, "SQL-Begh2", "type must be an array of string")
}
var id string
err = scanner(&id)
if err != nil {
logging.WithError(err).Warn("unable to scan row")
return z_errors.ThrowInternal(err, "SQL-DEFGe", "unable to scan row")
}
*ids = append(*ids, id)
return nil
}
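
For clarity, a hedged sketch of how this scanner is meant to be driven per row; the rows loop is assumed here, the actual call site is the shared query function:

ids := make([]string, 0)
for rows.Next() {
	if err := instanceIDsScanner(rows.Scan, &ids); err != nil {
		return nil, err // propagate scan error to the caller
	}
}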
func eventsScanner(scanner scan, dest interface{}) (err error) {
events, ok := dest.(*[]*repository.Event)
if !ok {
@@ -157,7 +176,7 @@ func prepareCondition(criteria querier, filters [][]*repository.Filter) (clause
var err error
value, err = json.Marshal(value)
if err != nil {
logging.New().WithError(err).Warn("unable to marshal search value")
logging.WithError(err).Warn("unable to marshal search value")
continue
}
}

View File

@@ -39,6 +39,8 @@ const (
ColumnsEvent Columns = repository.ColumnsEvent
// ColumnsMaxSequence represents the latest sequence of the filtered events
ColumnsMaxSequence Columns = repository.ColumnsMaxSequence
// ColumnsInstanceIDs represents the instance ids of the filtered events
ColumnsInstanceIDs Columns = repository.ColumnsInstanceIDs
)
// AggregateType is the object name
@@ -278,6 +280,9 @@ func (query *SearchQuery) eventTypeFilter() *repository.Filter {
}
func (query *SearchQuery) aggregateTypeFilter() *repository.Filter {
if len(query.aggregateTypes) < 1 {
return nil
}
if len(query.aggregateTypes) == 1 {
return repository.NewFilter(repository.FieldAggregateType, repository.AggregateType(query.aggregateTypes[0]), repository.OperationEquals)
}

View File

@@ -13,6 +13,7 @@ type Eventstore interface {
Health(ctx context.Context) error
FilterEvents(ctx context.Context, searchQuery *models.SearchQuery) (events []*models.Event, err error)
Subscribe(aggregates ...models.AggregateType) *Subscription
InstanceIDs(ctx context.Context, searchQuery *models.SearchQuery) ([]string, error)
}
var _ Eventstore = (*eventstore)(nil)
@@ -37,3 +38,10 @@ func (es *eventstore) FilterEvents(ctx context.Context, searchQuery *models.Sear
func (es *eventstore) Health(ctx context.Context) error {
return es.repo.Health(ctx)
}
func (es *eventstore) InstanceIDs(ctx context.Context, searchQuery *models.SearchQuery) ([]string, error) {
if err := searchQuery.Validate(); err != nil {
return nil, err
}
return es.repo.InstanceIDs(ctx, models.FactoryFromSearchQuery(searchQuery))
}

View File

@@ -11,6 +11,8 @@ type Repository interface {
// Filter returns all events matching the given search query
Filter(ctx context.Context, searchQuery *models.SearchQueryFactory) (events []*models.Event, err error)
//LatestSequence returns the latests sequence found by the the search query
//LatestSequence returns the latest sequence found by the search query
LatestSequence(ctx context.Context, queryFactory *models.SearchQueryFactory) (uint64, error)
//InstanceIDs returns the instance ids found by the search query
InstanceIDs(ctx context.Context, queryFactory *models.SearchQueryFactory) ([]string, error)
}

View File

@@ -5,6 +5,7 @@ import (
"database/sql"
"github.com/zitadel/logging"
"github.com/zitadel/zitadel/internal/errors"
es_models "github.com/zitadel/zitadel/internal/eventstore/v1/models"
"github.com/zitadel/zitadel/internal/telemetry/tracing"
@@ -60,3 +61,31 @@ func (db *SQL) LatestSequence(ctx context.Context, queryFactory *es_models.Searc
}
return uint64(*sequence), nil
}
func (db *SQL) InstanceIDs(ctx context.Context, queryFactory *es_models.SearchQueryFactory) ([]string, error) {
query, _, values, rowScanner := buildQuery(queryFactory)
if query == "" {
return nil, errors.ThrowInvalidArgument(nil, "SQL-Sfwg2", "invalid query factory")
}
rows, err := db.client.Query(query, values...)
if err != nil {
logging.New().WithError(err).Info("query failed")
return nil, errors.ThrowInternal(err, "SQL-Sfg3r", "unable to filter instance ids")
}
defer rows.Close()
ids := make([]string, 0)
for rows.Next() {
var id string
err := rowScanner(rows.Scan, &id)
if err != nil {
return nil, err
}
ids = append(ids, id)
}
return ids, nil
}

View File

@@ -44,7 +44,7 @@ func buildQuery(queryFactory *es_models.SearchQueryFactory) (query string, limit
}
query += where
if searchQuery.Columns != es_models.Columns_Max_Sequence {
if searchQuery.Columns == es_models.Columns_Event {
query += " ORDER BY event_sequence"
if searchQuery.Desc {
query += " DESC"
@@ -104,6 +104,19 @@ func prepareColumns(columns es_models.Columns) (string, func(s scan, dest interf
}
return z_errors.ThrowInternal(err, "SQL-bN5xg", "something went wrong")
}
case es_models.Columns_InstanceIDs:
return "SELECT DISTINCT instance_id FROM eventstore.events", func(row scan, dest interface{}) (err error) {
instanceID, ok := dest.(*string)
if !ok {
return z_errors.ThrowInvalidArgument(nil, "SQL-Fef5h", "type must be *string]")
}
err = row(instanceID)
if err != nil {
logging.New().WithError(err).Warn("unable to scan row")
return z_errors.ThrowInternal(err, "SQL-SFef3", "unable to scan row")
}
return nil
}
case es_models.Columns_Event:
return selectStmt, func(row scan, dest interface{}) (err error) {
event, ok := dest.(*es_models.Event)

View File

@@ -41,6 +41,7 @@ type Columns int32
const (
Columns_Event = iota
Columns_Max_Sequence
Columns_InstanceIDs
//insert new columns-types before this columnsCount because count is needed for validation
columnsCount
)
@@ -48,7 +49,7 @@ const (
//FactoryFromSearchQuery is deprecated because it's for migration purposes. use NewSearchQueryFactory
func FactoryFromSearchQuery(q *SearchQuery) *SearchQueryFactory {
factory := &SearchQueryFactory{
columns: Columns_Event,
columns: q.Columns,
desc: q.Desc,
limit: q.Limit,
queries: make([]*query, len(q.Queries)),
@@ -232,6 +233,9 @@ func (q *query) eventTypeFilter() *Filter {
}
func (q *query) aggregateTypeFilter() *Filter {
if len(q.aggregateTypes) < 1 {
return nil
}
if len(q.aggregateTypes) == 1 {
return NewFilter(Field_AggregateType, q.aggregateTypes[0], Operation_Equals)
}

View File

@@ -8,6 +8,7 @@ import (
//SearchQuery is deprecated. Use SearchQueryFactory
type SearchQuery struct {
Columns Columns
Limit uint64
Desc bool
Filters []*Filter
@@ -27,6 +28,11 @@ func NewSearchQuery() *SearchQuery {
}
}
func (q *SearchQuery) SetColumn(columns Columns) *SearchQuery {
q.Columns = columns
return q
}
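
A usage sketch mirroring the spooler change below: switch the columns to instance IDs before adding filters and hand the query to the v1 eventstore (es is an assumed v1.Eventstore):

ids, err := es.InstanceIDs(ctx,
	models.NewSearchQuery().
		SetColumn(models.Columns_InstanceIDs).
		AddQuery().
		ExcludedInstanceIDsFilter("").
		SearchQuery())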
func (q *SearchQuery) AddQuery() *Query {
query := &Query{
searchQuery: q,

View File

@@ -2,9 +2,9 @@ package query
import (
"context"
"runtime/debug"
"time"
"github.com/getsentry/sentry-go"
"github.com/zitadel/logging"
v1 "github.com/zitadel/zitadel/internal/eventstore/v1"
@@ -17,7 +17,7 @@ const (
type Handler interface {
ViewModel() string
EventQuery() (*models.SearchQuery, error)
EventQuery(instanceIDs ...string) (*models.SearchQuery, error)
Reduce(*models.Event) error
OnError(event *models.Event, err error) error
OnSuccess() error
@@ -37,14 +37,13 @@ func ReduceEvent(handler Handler, event *models.Event) {
err := recover()
if err != nil {
sentry.CurrentHub().Recover(err)
handler.Subscription().Unsubscribe()
logging.WithFields("HANDL-SAFe1").Errorf("reduce panicked: %v", err)
logging.WithFields("cause", err, "stack", string(debug.Stack())).Error("reduce panicked")
}
}()
currentSequence, err := handler.CurrentSequence(event.InstanceID)
if err != nil {
logging.New().WithError(err).Warn("unable to get current sequence")
logging.WithError(err).Warn("unable to get current sequence")
return
}
@@ -58,14 +57,14 @@ func ReduceEvent(handler Handler, event *models.Event) {
unprocessedEvents, err := handler.Eventstore().FilterEvents(context.Background(), searchQuery)
if err != nil {
logging.WithFields("HANDL-L6YH1", "sequence", event.Sequence).Warn("filter failed")
logging.WithFields("sequence", event.Sequence).Warn("filter failed")
return
}
for _, unprocessedEvent := range unprocessedEvents {
currentSequence, err := handler.CurrentSequence(unprocessedEvent.InstanceID)
if err != nil {
logging.Log("HANDL-BmpkC").WithError(err).Warn("unable to get current sequence")
logging.WithError(err).Warn("unable to get current sequence")
return
}
if unprocessedEvent.Sequence < currentSequence {
@@ -78,12 +77,12 @@ func ReduceEvent(handler Handler, event *models.Event) {
}
err = handler.Reduce(unprocessedEvent)
logging.WithFields("HANDL-V42TI", "sequence", unprocessedEvent.Sequence).OnError(err).Warn("reduce failed")
logging.WithFields("sequence", unprocessedEvent.Sequence).OnError(err).Warn("reduce failed")
}
if len(unprocessedEvents) == eventLimit {
logging.WithFields("QUERY-BSqe9", "sequence", event.Sequence).Warn("didnt process event")
logging.WithFields("sequence", event.Sequence).Warn("didnt process event")
return
}
err = handler.Reduce(event)
logging.WithFields("HANDL-wQDL2", "sequence", event.Sequence).OnError(err).Warn("reduce failed")
logging.WithFields("sequence", event.Sequence).OnError(err).Warn("reduce failed")
}

View File

@@ -11,10 +11,11 @@ import (
)
type Config struct {
Eventstore v1.Eventstore
Locker Locker
ViewHandlers []query.Handler
ConcurrentWorkers int
Eventstore v1.Eventstore
Locker Locker
ViewHandlers []query.Handler
ConcurrentWorkers int
ConcurrentInstances int
}
func (c *Config) New() *Spooler {
@@ -27,11 +28,12 @@ func (c *Config) New() *Spooler {
})
return &Spooler{
handlers: c.ViewHandlers,
lockID: lockID,
eventstore: c.Eventstore,
locker: c.Locker,
queue: make(chan *spooledHandler, len(c.ViewHandlers)),
workers: c.ConcurrentWorkers,
handlers: c.ViewHandlers,
lockID: lockID,
eventstore: c.Eventstore,
locker: c.Locker,
queue: make(chan *spooledHandler, len(c.ViewHandlers)),
workers: c.ConcurrentWorkers,
concurrentInstances: c.ConcurrentInstances,
}
}

View File

@@ -2,11 +2,11 @@ package spooler
import (
"context"
"runtime/debug"
"strconv"
"sync"
"time"
"github.com/getsentry/sentry-go"
"github.com/zitadel/logging"
v1 "github.com/zitadel/zitadel/internal/eventstore/v1"
@@ -19,12 +19,13 @@ import (
const systemID = "system"
type Spooler struct {
handlers []query.Handler
locker Locker
lockID string
eventstore v1.Eventstore
workers int
queue chan *spooledHandler
handlers []query.Handler
locker Locker
lockID string
eventstore v1.Eventstore
workers int
queue chan *spooledHandler
concurrentInstances int
}
type Locker interface {
@@ -33,9 +34,10 @@ type Locker interface {
type spooledHandler struct {
query.Handler
locker Locker
queuedAt time.Time
eventstore v1.Eventstore
locker Locker
queuedAt time.Time
eventstore v1.Eventstore
concurrentInstances int
}
func (s *Spooler) Start() {
@@ -55,7 +57,7 @@ func (s *Spooler) Start() {
}
go func() {
for _, handler := range s.handlers {
s.queue <- &spooledHandler{Handler: handler, locker: s.locker, queuedAt: time.Now(), eventstore: s.eventstore}
s.queue <- &spooledHandler{Handler: handler, locker: s.locker, queuedAt: time.Now(), eventstore: s.eventstore, concurrentInstances: s.concurrentInstances}
}
}()
}
@@ -73,7 +75,7 @@ func (s *spooledHandler) load(workerID string) {
err := recover()
if err != nil {
sentry.CurrentHub().Recover(err)
logging.WithFields("cause", err, "stack", string(debug.Stack())).Error("reduce panicked")
}
}()
ctx, cancel := context.WithCancel(context.Background())
@@ -82,29 +84,50 @@ func (s *spooledHandler) load(workerID string) {
if <-hasLocked {
for {
events, err := s.query(ctx)
ids, err := s.eventstore.InstanceIDs(ctx, models.NewSearchQuery().SetColumn(models.Columns_InstanceIDs).AddQuery().ExcludedInstanceIDsFilter("").SearchQuery())
if err != nil {
errs <- err
break
}
err = s.process(ctx, events, workerID)
if err != nil {
errs <- err
break
}
if uint64(len(events)) < s.QueryLimit() {
// no more events to process
// stop chan
if ctx.Err() == nil {
errs <- nil
for i := 0; i < len(ids); i = i + s.concurrentInstances {
max := i + s.concurrentInstances
if max > len(ids) {
max = len(ids)
}
err = s.processInstances(ctx, workerID, ids[i:max]...)
if err != nil {
errs <- err
}
break
}
if ctx.Err() == nil {
errs <- nil
}
break
}
}
<-ctx.Done()
}
func (s *spooledHandler) processInstances(ctx context.Context, workerID string, ids ...string) error {
for {
events, err := s.query(ctx, ids...)
if err != nil {
return err
}
if len(events) == 0 {
return nil
}
err = s.process(ctx, events, workerID)
if err != nil {
return err
}
if uint64(len(events)) < s.QueryLimit() {
// no more events to process
return nil
}
}
}
func (s *spooledHandler) awaitError(cancel func(), errs chan error, workerID string) {
select {
case err := <-errs:
@@ -135,8 +158,8 @@ func (s *spooledHandler) process(ctx context.Context, events []*models.Event, wo
return err
}
func (s *spooledHandler) query(ctx context.Context) ([]*models.Event, error) {
query, err := s.EventQuery()
func (s *spooledHandler) query(ctx context.Context, instanceIDs ...string) ([]*models.Event, error) {
query, err := s.EventQuery(instanceIDs...)
if err != nil {
return nil, err
}

View File

@@ -47,7 +47,7 @@ func (h *testHandler) Subscription() *v1.Subscription {
return nil
}
func (h *testHandler) EventQuery() (*models.SearchQuery, error) {
func (h *testHandler) EventQuery(instanceIDs ...string) (*models.SearchQuery, error) {
if h.queryError != nil {
return nil, h.queryError
}
@@ -111,6 +111,9 @@ func (es *eventstoreStub) PushAggregates(ctx context.Context, in ...*models.Aggr
func (es *eventstoreStub) LatestSequence(ctx context.Context, in *models.SearchQueryFactory) (uint64, error) {
return 0, nil
}
func (es *eventstoreStub) InstanceIDs(ctx context.Context, in *models.SearchQuery) ([]string, error) {
return nil, nil
}
func (es *eventstoreStub) V2() *eventstore.Eventstore {
return nil
}

View File

@@ -208,7 +208,7 @@ var (
func (q *Queries) AppByProjectAndAppID(ctx context.Context, shouldTriggerBulk bool, projectID, appID string) (*App, error) {
if shouldTriggerBulk {
projection.AppProjection.TriggerBulk(ctx)
projection.AppProjection.Trigger(ctx)
}
stmt, scan := prepareAppQuery()

View File

@@ -124,7 +124,7 @@ func (q *Queries) SearchAuthNKeys(ctx context.Context, queries *AuthNKeySearchQu
func (q *Queries) GetAuthNKeyByID(ctx context.Context, shouldTriggerBulk bool, id string, queries ...SearchQuery) (*AuthNKey, error) {
if shouldTriggerBulk {
projection.AuthNKeyProjection.TriggerBulk(ctx)
projection.AuthNKeyProjection.Trigger(ctx)
}
query, scan := prepareAuthNKeyQuery()

View File

@@ -82,7 +82,7 @@ var (
func (q *Queries) DomainPolicyByOrg(ctx context.Context, shouldTriggerBulk bool, orgID string) (*DomainPolicy, error) {
if shouldTriggerBulk {
projection.DomainPolicyProjection.TriggerBulk(ctx)
projection.DomainPolicyProjection.Trigger(ctx)
}
stmt, scan := prepareDomainPolicyQuery()

View File

@@ -115,7 +115,7 @@ func (q *FailedEventSearchQueries) toQuery(query sq.SelectBuilder) sq.SelectBuil
return query
}
func prepareFailedEventQuery() (sq.SelectBuilder, func(*sql.Row) (*FailedEvent, error)) {
func prepareFailedEventQuery(instanceIDs ...string) (sq.SelectBuilder, func(*sql.Row) (*FailedEvent, error)) {
return sq.Select(
FailedEventsColumnProjectionName.identifier(),
FailedEventsColumnFailedSequence.identifier(),

View File

@@ -182,7 +182,7 @@ var (
//IDPByIDAndResourceOwner searches for the requested id in the context of the resource owner and IAM
func (q *Queries) IDPByIDAndResourceOwner(ctx context.Context, shouldTriggerBulk bool, id, resourceOwner string) (*IDP, error) {
if shouldTriggerBulk {
projection.IDPProjection.TriggerBulk(ctx)
projection.IDPProjection.Trigger(ctx)
}
stmt, scan := prepareIDPByIDQuery()

View File

@@ -159,7 +159,7 @@ func (q *Queries) SearchInstances(ctx context.Context, queries *InstanceSearchQu
func (q *Queries) Instance(ctx context.Context, shouldTriggerBulk bool) (*Instance, error) {
if shouldTriggerBulk {
projection.InstanceProjection.TriggerBulk(ctx)
projection.InstanceProjection.Trigger(ctx)
}
stmt, scan := prepareInstanceDomainQuery(authz.GetInstance(ctx).RequestedDomain())

View File

@@ -77,7 +77,7 @@ var (
func (q *Queries) LockoutPolicyByOrg(ctx context.Context, shouldTriggerBulk bool, orgID string) (*LockoutPolicy, error) {
if shouldTriggerBulk {
projection.LockoutPolicyProjection.TriggerBulk(ctx)
projection.LockoutPolicyProjection.Trigger(ctx)
}
stmt, scan := prepareLockoutPolicyQuery()

View File

@@ -141,7 +141,7 @@ var (
func (q *Queries) LoginPolicyByID(ctx context.Context, shouldTriggerBulk bool, orgID string) (*LoginPolicy, error) {
if shouldTriggerBulk {
projection.LoginPolicyProjection.TriggerBulk(ctx)
projection.LoginPolicyProjection.Trigger(ctx)
}
query, scan := prepareLoginPolicyQuery()

View File

@@ -88,7 +88,7 @@ func (q *OrgSearchQueries) toQuery(query sq.SelectBuilder) sq.SelectBuilder {
func (q *Queries) OrgByID(ctx context.Context, shouldTriggerBulk bool, id string) (*Org, error) {
if shouldTriggerBulk {
projection.OrgProjection.TriggerBulk(ctx)
projection.OrgProjection.Trigger(ctx)
}
stmt, scan := prepareOrgQuery()

View File

@@ -76,7 +76,7 @@ var (
func (q *Queries) PasswordAgePolicyByOrg(ctx context.Context, shouldTriggerBulk bool, orgID string) (*PasswordAgePolicy, error) {
if shouldTriggerBulk {
projection.PasswordAgeProjection.TriggerBulk(ctx)
projection.PasswordAgeProjection.Trigger(ctx)
}
stmt, scan := preparePasswordAgePolicyQuery()
@@ -106,7 +106,7 @@ func (q *Queries) PasswordAgePolicyByOrg(ctx context.Context, shouldTriggerBulk
func (q *Queries) DefaultPasswordAgePolicy(ctx context.Context, shouldTriggerBulk bool) (*PasswordAgePolicy, error) {
if shouldTriggerBulk {
projection.PasswordAgeProjection.TriggerBulk(ctx)
projection.PasswordAgeProjection.Trigger(ctx)
}
stmt, scan := preparePasswordAgePolicyQuery()

View File

@@ -33,7 +33,7 @@ type PasswordComplexityPolicy struct {
func (q *Queries) PasswordComplexityPolicyByOrg(ctx context.Context, shouldTriggerBulk bool, orgID string) (*PasswordComplexityPolicy, error) {
if shouldTriggerBulk {
projection.PasswordComplexityProjection.TriggerBulk(ctx)
projection.PasswordComplexityProjection.Trigger(ctx)
}
stmt, scan := preparePasswordComplexityPolicyQuery()
@@ -63,7 +63,7 @@ func (q *Queries) PasswordComplexityPolicyByOrg(ctx context.Context, shouldTrigg
func (q *Queries) DefaultPasswordComplexityPolicy(ctx context.Context, shouldTriggerBulk bool) (*PasswordComplexityPolicy, error) {
if shouldTriggerBulk {
projection.PasswordComplexityProjection.TriggerBulk(ctx)
projection.PasswordComplexityProjection.Trigger(ctx)
}
stmt, scan := preparePasswordComplexityPolicyQuery()

View File

@@ -81,7 +81,7 @@ var (
func (q *Queries) PrivacyPolicyByOrg(ctx context.Context, shouldTriggerBulk bool, orgID string) (*PrivacyPolicy, error) {
if shouldTriggerBulk {
projection.PrivacyPolicyProjection.TriggerBulk(ctx)
projection.PrivacyPolicyProjection.Trigger(ctx)
}
stmt, scan := preparePrivacyPolicyQuery()
@@ -111,7 +111,7 @@ func (q *Queries) PrivacyPolicyByOrg(ctx context.Context, shouldTriggerBulk bool
func (q *Queries) DefaultPrivacyPolicy(ctx context.Context, shouldTriggerBulk bool) (*PrivacyPolicy, error) {
if shouldTriggerBulk {
projection.PrivacyPolicyProjection.TriggerBulk(ctx)
projection.PrivacyPolicyProjection.Trigger(ctx)
}
stmt, scan := preparePrivacyPolicyQuery()

View File

@@ -96,7 +96,7 @@ type ProjectSearchQueries struct {
func (q *Queries) ProjectByID(ctx context.Context, shouldTriggerBulk bool, id string) (*Project, error) {
if shouldTriggerBulk {
projection.ProjectProjection.TriggerBulk(ctx)
projection.ProjectProjection.Trigger(ctx)
}
stmt, scan := prepareProjectQuery()

View File

@@ -103,7 +103,7 @@ type ProjectGrantSearchQueries struct {
func (q *Queries) ProjectGrantByID(ctx context.Context, shouldTriggerBulk bool, id string) (*ProjectGrant, error) {
if shouldTriggerBulk {
projection.ProjectGrantProjection.TriggerBulk(ctx)
projection.ProjectGrantProjection.Trigger(ctx)
}
stmt, scan := prepareProjectGrantQuery()

View File

@@ -78,7 +78,7 @@ type ProjectRoleSearchQueries struct {
func (q *Queries) SearchProjectRoles(ctx context.Context, shouldTriggerBulk bool, queries *ProjectRoleSearchQueries) (projects *ProjectRoles, err error) {
if shouldTriggerBulk {
projection.ProjectRoleProjection.TriggerBulk(ctx)
projection.ProjectRoleProjection.Trigger(ctx)
}
query, scan := prepareProjectRolesQuery()

View File

@@ -5,17 +5,19 @@ import (
)
type Config struct {
RequeueEvery time.Duration
RetryFailedAfter time.Duration
MaxFailureCount uint
BulkLimit uint64
Customizations map[string]CustomConfig
MaxIterators int
RequeueEvery time.Duration
RetryFailedAfter time.Duration
MaxFailureCount uint
ConcurrentInstances uint
BulkLimit uint64
Customizations map[string]CustomConfig
MaxIterators int
}
type CustomConfig struct {
RequeueEvery *time.Duration
RetryFailedAfter *time.Duration
MaxFailureCount *uint
BulkLimit *uint64
RequeueEvery *time.Duration
RetryFailedAfter *time.Duration
MaxFailureCount *uint
ConcurrentInstances *uint
BulkLimit *uint64
}
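
A hedged wiring example for the new field; the values are illustrative only and the remaining fields keep their defaults:

cfg := Config{
	RequeueEvery:        10 * time.Second,
	RetryFailedAfter:    1 * time.Second,
	MaxFailureCount:     5,
	ConcurrentInstances: 2,
	BulkLimit:           200,
}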

View File

@@ -83,8 +83,8 @@ func newKeyProjection(ctx context.Context, config crdb.StatementHandlerConfig, k
crdb.WithForeignKey(crdb.NewForeignKeyOfPublicKeys("fk_public_ref_keys")),
),
)
p.StatementHandler = crdb.NewStatementHandler(ctx, config)
p.encryptionAlgorithm = keyEncryptionAlgorithm
p.StatementHandler = crdb.NewStatementHandler(ctx, config)
return p
}

View File

@@ -68,8 +68,10 @@ func Start(ctx context.Context, sqlClient *sql.DB, es *eventstore.Eventstore, co
HandlerConfig: handler.HandlerConfig{
Eventstore: es,
},
RequeueEvery: config.RequeueEvery,
RetryFailedAfter: config.RetryFailedAfter,
RequeueEvery: config.RequeueEvery,
RetryFailedAfter: config.RetryFailedAfter,
Retries: config.MaxFailureCount,
ConcurrentInstances: config.ConcurrentInstances,
},
Client: sqlClient,
SequenceTable: CurrentSeqTable,

View File

@@ -294,8 +294,8 @@ var (
func (q *Queries) GetUserByID(ctx context.Context, shouldTriggerBulk bool, userID string, queries ...SearchQuery) (*User, error) {
if shouldTriggerBulk {
projection.UserProjection.TriggerBulk(ctx)
projection.LoginNameProjection.TriggerBulk(ctx)
projection.UserProjection.Trigger(ctx)
projection.LoginNameProjection.Trigger(ctx)
}
instanceID := authz.GetInstance(ctx).InstanceID()
@@ -317,8 +317,8 @@ func (q *Queries) GetUserByID(ctx context.Context, shouldTriggerBulk bool, userI
func (q *Queries) GetUser(ctx context.Context, shouldTriggerBulk bool, queries ...SearchQuery) (*User, error) {
if shouldTriggerBulk {
projection.UserProjection.TriggerBulk(ctx)
projection.LoginNameProjection.TriggerBulk(ctx)
projection.UserProjection.Trigger(ctx)
projection.LoginNameProjection.Trigger(ctx)
}
instanceID := authz.GetInstance(ctx).InstanceID()
@@ -390,8 +390,8 @@ func (q *Queries) GetHumanPhone(ctx context.Context, userID string, queries ...S
func (q *Queries) GeNotifyUser(ctx context.Context, shouldTriggered bool, userID string, queries ...SearchQuery) (*NotifyUser, error) {
if shouldTriggered {
projection.UserProjection.TriggerBulk(ctx)
projection.LoginNameProjection.TriggerBulk(ctx)
projection.UserProjection.Trigger(ctx)
projection.LoginNameProjection.Trigger(ctx)
}
instanceID := authz.GetInstance(ctx).InstanceID()

View File

@@ -193,7 +193,7 @@ var (
func (q *Queries) UserGrant(ctx context.Context, shouldTriggerBulk bool, queries ...SearchQuery) (*UserGrant, error) {
if shouldTriggerBulk {
projection.UserGrantProjection.TriggerBulk(ctx)
projection.UserGrantProjection.Trigger(ctx)
}
query, scan := prepareUserGrantQuery()

View File

@@ -73,7 +73,7 @@ var (
func (q *Queries) GetUserMetadataByKey(ctx context.Context, shouldTriggerBulk bool, userID, key string, queries ...SearchQuery) (*UserMetadata, error) {
if shouldTriggerBulk {
projection.UserMetadataProjection.TriggerBulk(ctx)
projection.UserMetadataProjection.Trigger(ctx)
}
query, scan := prepareUserMetadataQuery()
@@ -96,7 +96,7 @@ func (q *Queries) GetUserMetadataByKey(ctx context.Context, shouldTriggerBulk bo
func (q *Queries) SearchUserMetadata(ctx context.Context, shouldTriggerBulk bool, userID string, queries *UserMetadataSearchQueries) (*UserMetadataList, error) {
if shouldTriggerBulk {
projection.UserMetadataProjection.TriggerBulk(ctx)
projection.UserMetadataProjection.Trigger(ctx)
}
query, scan := prepareUserMetadataListQuery()

View File

@@ -82,7 +82,7 @@ type PersonalAccessTokenSearchQueries struct {
func (q *Queries) PersonalAccessTokenByID(ctx context.Context, shouldTriggerBulk bool, id string, queries ...SearchQuery) (*PersonalAccessToken, error) {
if shouldTriggerBulk {
projection.PersonalAccessTokenProjection.TriggerBulk(ctx)
projection.PersonalAccessTokenProjection.Trigger(ctx)
}
query, scan := preparePersonalAccessTokenQuery()

View File

@@ -55,8 +55,9 @@ func (key sequenceSearchKey) ToColumnName() string {
}
type sequenceSearchQuery struct {
key sequenceSearchKey
value string
key sequenceSearchKey
method domain.SearchMethod
value interface{}
}
func (q *sequenceSearchQuery) GetKey() ColumnKey {
@@ -64,7 +65,7 @@ func (q *sequenceSearchQuery) GetKey() ColumnKey {
}
func (q *sequenceSearchQuery) GetMethod() domain.SearchMethod {
return domain.SearchMethodEquals
return q.method
}
func (q *sequenceSearchQuery) GetValue() interface{} {
@@ -94,7 +95,7 @@ func (s *sequenceSearchRequest) GetAsc() bool {
func (s *sequenceSearchRequest) GetQueries() []SearchQuery {
result := make([]SearchQuery, len(s.queries))
for i, q := range s.queries {
result[i] = &sequenceSearchQuery{key: q.key, value: q.value}
result[i] = &sequenceSearchQuery{key: q.key, value: q.value, method: q.method}
}
return result
}
@@ -147,8 +148,8 @@ func UpdateCurrentSequences(db *gorm.DB, table string, currentSequences []*Curre
func LatestSequence(db *gorm.DB, table, viewName, instanceID string) (*CurrentSequence, error) {
searchQueries := []SearchQuery{
&sequenceSearchQuery{key: sequenceSearchKey(SequenceSearchKeyViewName), value: viewName},
&sequenceSearchQuery{key: sequenceSearchKey(SequenceSearchKeyInstanceID), value: instanceID},
&sequenceSearchQuery{key: sequenceSearchKey(SequenceSearchKeyViewName), value: viewName, method: domain.SearchMethodEquals},
&sequenceSearchQuery{key: sequenceSearchKey(SequenceSearchKeyInstanceID), value: instanceID, method: domain.SearchMethodIsOneOf},
}
// ensure highest sequence of view
@@ -168,13 +169,15 @@ func LatestSequence(db *gorm.DB, table, viewName, instanceID string) (*CurrentSe
return nil, caos_errs.ThrowInternalf(err, "VIEW-9LyCB", "unable to get latest sequence of %s", viewName)
}
func LatestSequences(db *gorm.DB, table, viewName string) ([]*CurrentSequence, error) {
searchQueries := make([]SearchQuery, 0, 2)
searchQueries = append(searchQueries)
func LatestSequences(db *gorm.DB, table, viewName string, instanceIDs ...string) ([]*CurrentSequence, error) {
searchQueries := []sequenceSearchQuery{
{key: sequenceSearchKey(SequenceSearchKeyViewName), value: viewName, method: domain.SearchMethodEquals},
}
if len(instanceIDs) > 0 {
searchQueries = append(searchQueries, sequenceSearchQuery{key: sequenceSearchKey(SequenceSearchKeyInstanceID), value: instanceIDs, method: domain.SearchMethodIsOneOf})
}
searchRequest := &sequenceSearchRequest{
queries: []sequenceSearchQuery{
{key: sequenceSearchKey(SequenceSearchKeyViewName), value: viewName},
},
queries: searchQueries,
}
// ensure highest sequence of view