Files
zitadel/internal/eventstore/handler/v2/handler.go
Tim Möhlmann 7ef95c9194 fix(db): wrap BeginTx in spans to get acquire metrics (#7689)
feat(db): wrap BeginTx in spans to get acquire metrics

This changes adds a span around most db.BeginTx calls so we can get tracings about the connection pool acquire process.
This might help us pinpoint why sometimes some query package traces show longer execution times, while this was not reflected on database side execution times.

Co-authored-by: Silvan <silvan.reusser@gmail.com>
(cherry picked from commit 093dd57a78)
2024-04-09 14:03:03 +02:00

677 lines
18 KiB
Go

package handler
import (
"context"
"database/sql"
"errors"
"math"
"math/rand"
"slices"
"sync"
"time"
"github.com/jackc/pgx/v5/pgconn"
"github.com/zitadel/logging"
"github.com/zitadel/zitadel/internal/api/authz"
"github.com/zitadel/zitadel/internal/api/call"
"github.com/zitadel/zitadel/internal/database"
"github.com/zitadel/zitadel/internal/eventstore"
"github.com/zitadel/zitadel/internal/migration"
"github.com/zitadel/zitadel/internal/repository/instance"
"github.com/zitadel/zitadel/internal/repository/pseudo"
"github.com/zitadel/zitadel/internal/telemetry/tracing"
)
type EventStore interface {
InstanceIDs(ctx context.Context, maxAge time.Duration, forceLoad bool, query *eventstore.SearchQueryBuilder) ([]string, error)
FilterToQueryReducer(ctx context.Context, reducer eventstore.QueryReducer) error
Filter(ctx context.Context, queryFactory *eventstore.SearchQueryBuilder) ([]eventstore.Event, error)
Push(ctx context.Context, cmds ...eventstore.Command) ([]eventstore.Event, error)
}
type Config struct {
Client *database.DB
Eventstore EventStore
BulkLimit uint16
RequeueEvery time.Duration
RetryFailedAfter time.Duration
HandleActiveInstances time.Duration
TransactionDuration time.Duration
MaxFailureCount uint8
TriggerWithoutEvents Reduce
}
type Handler struct {
client *database.DB
projection Projection
es EventStore
bulkLimit uint16
eventTypes map[eventstore.AggregateType][]eventstore.EventType
maxFailureCount uint8
retryFailedAfter time.Duration
requeueEvery time.Duration
handleActiveInstances time.Duration
txDuration time.Duration
now nowFunc
triggeredInstancesSync sync.Map
triggerWithoutEvents Reduce
}
var _ migration.Migration = (*Handler)(nil)
// Execute implements migration.Migration.
func (h *Handler) Execute(ctx context.Context, startedEvent eventstore.Event) error {
start := time.Now()
logging.WithFields("projection", h.ProjectionName()).Info("projection starts prefilling")
logTicker := time.NewTicker(30 * time.Second)
go func() {
for range logTicker.C {
logging.WithFields("projection", h.ProjectionName()).Info("projection is prefilling")
}
}()
instanceIDs, err := h.existingInstances(ctx)
if err != nil {
return err
}
// default amount of workers is 10
workerCount := 10
if h.client.DB.Stats().MaxOpenConnections > 0 {
workerCount = h.client.DB.Stats().MaxOpenConnections / 4
}
// ensure that at least one worker is active
if workerCount == 0 {
workerCount = 1
}
// spawn less workers if not all workers needed
if workerCount > len(instanceIDs) {
workerCount = len(instanceIDs)
}
instances := make(chan string, workerCount)
var wg sync.WaitGroup
wg.Add(workerCount)
for i := 0; i < workerCount; i++ {
go h.executeInstances(ctx, instances, startedEvent, &wg)
}
for _, instance := range instanceIDs {
instances <- instance
}
close(instances)
wg.Wait()
logTicker.Stop()
logging.WithFields("projection", h.ProjectionName(), "took", time.Since(start)).Info("projections ended prefilling")
return nil
}
func (h *Handler) executeInstances(ctx context.Context, instances <-chan string, startedEvent eventstore.Event, wg *sync.WaitGroup) {
for instance := range instances {
h.triggerInstances(ctx, []string{instance}, WithMaxPosition(startedEvent.Position()))
}
wg.Done()
}
// String implements migration.Migration.
func (h *Handler) String() string {
return h.ProjectionName()
}
// nowFunc makes [time.Now] mockable
type nowFunc func() time.Time
type Projection interface {
Name() string
Reducers() []AggregateReducer
}
func NewHandler(
ctx context.Context,
config *Config,
projection Projection,
) *Handler {
aggregates := make(map[eventstore.AggregateType][]eventstore.EventType, len(projection.Reducers()))
for _, reducer := range projection.Reducers() {
eventTypes := make([]eventstore.EventType, len(reducer.EventReducers))
for i, eventReducer := range reducer.EventReducers {
eventTypes[i] = eventReducer.Event
}
if _, ok := aggregates[reducer.Aggregate]; ok {
aggregates[reducer.Aggregate] = append(aggregates[reducer.Aggregate], eventTypes...)
continue
}
aggregates[reducer.Aggregate] = eventTypes
}
handler := &Handler{
projection: projection,
client: config.Client,
es: config.Eventstore,
bulkLimit: config.BulkLimit,
eventTypes: aggregates,
requeueEvery: config.RequeueEvery,
handleActiveInstances: config.HandleActiveInstances,
now: time.Now,
maxFailureCount: config.MaxFailureCount,
retryFailedAfter: config.RetryFailedAfter,
triggeredInstancesSync: sync.Map{},
triggerWithoutEvents: config.TriggerWithoutEvents,
txDuration: config.TransactionDuration,
}
return handler
}
func (h *Handler) Start(ctx context.Context) {
go h.schedule(ctx)
if h.triggerWithoutEvents != nil {
return
}
go h.subscribe(ctx)
}
type checkInit struct {
didInit bool
projectionName string
}
// AppendEvents implements eventstore.QueryReducer.
func (ci *checkInit) AppendEvents(...eventstore.Event) {
ci.didInit = true
}
// Query implements eventstore.QueryReducer.
func (ci *checkInit) Query() *eventstore.SearchQueryBuilder {
return eventstore.NewSearchQueryBuilder(eventstore.ColumnsEvent).
Limit(1).
InstanceID("").
AddQuery().
AggregateTypes(migration.SystemAggregate).
AggregateIDs(migration.SystemAggregateID).
EventTypes(migration.DoneType).
EventData(map[string]interface{}{
"name": ci.projectionName,
}).
Builder()
}
// Reduce implements eventstore.QueryReducer.
func (*checkInit) Reduce() error {
return nil
}
var _ eventstore.QueryReducer = (*checkInit)(nil)
func (h *Handler) didInitialize(ctx context.Context) bool {
initiated := checkInit{
projectionName: h.ProjectionName(),
}
err := h.es.FilterToQueryReducer(ctx, &initiated)
if err != nil {
return false
}
return initiated.didInit
}
func (h *Handler) schedule(ctx context.Context) {
// start the projection and its configured `RequeueEvery`
reset := randomizeStart(0, h.requeueEvery.Seconds())
if !h.didInitialize(ctx) {
reset = randomizeStart(0, 0.5)
}
t := time.NewTimer(reset)
for {
select {
case <-ctx.Done():
t.Stop()
return
case <-t.C:
instances, err := h.queryInstances(ctx)
h.log().OnError(err).Debug("unable to query instances")
h.triggerInstances(call.WithTimestamp(ctx), instances)
t.Reset(h.requeueEvery)
}
}
}
func (h *Handler) triggerInstances(ctx context.Context, instances []string, triggerOpts ...TriggerOpt) {
for _, instance := range instances {
instanceCtx := authz.WithInstanceID(ctx, instance)
// simple implementation of do while
_, err := h.Trigger(instanceCtx, triggerOpts...)
h.log().WithField("instance", instance).OnError(err).Debug("trigger failed")
time.Sleep(h.retryFailedAfter)
// retry if trigger failed
for ; err != nil; _, err = h.Trigger(instanceCtx, triggerOpts...) {
time.Sleep(h.retryFailedAfter)
h.log().WithField("instance", instance).OnError(err).Debug("trigger failed")
if err == nil {
break
}
}
}
}
func randomizeStart(min, maxSeconds float64) time.Duration {
d := min + rand.Float64()*(maxSeconds-min)
return time.Duration(d*1000) * time.Millisecond
}
func (h *Handler) subscribe(ctx context.Context) {
queue := make(chan eventstore.Event, 100)
subscription := eventstore.SubscribeEventTypes(queue, h.eventTypes)
for {
select {
case <-ctx.Done():
subscription.Unsubscribe()
h.log().Debug("shutdown")
return
case event := <-queue:
events := checkAdditionalEvents(queue, event)
solvedInstances := make([]string, 0, len(events))
queueCtx := call.WithTimestamp(ctx)
for _, e := range events {
if instanceSolved(solvedInstances, e.Aggregate().InstanceID) {
continue
}
queueCtx = authz.WithInstanceID(queueCtx, e.Aggregate().InstanceID)
_, err := h.Trigger(queueCtx)
h.log().OnError(err).Debug("trigger of queued event failed")
if err == nil {
solvedInstances = append(solvedInstances, e.Aggregate().InstanceID)
}
}
}
}
}
func instanceSolved(solvedInstances []string, instanceID string) bool {
for _, solvedInstance := range solvedInstances {
if solvedInstance == instanceID {
return true
}
}
return false
}
func checkAdditionalEvents(eventQueue chan eventstore.Event, event eventstore.Event) []eventstore.Event {
events := make([]eventstore.Event, 1)
events[0] = event
for {
wait := time.NewTimer(1 * time.Millisecond)
select {
case event := <-eventQueue:
events = append(events, event)
case <-wait.C:
return events
}
}
}
type existingInstances []string
// AppendEvents implements eventstore.QueryReducer.
func (ai *existingInstances) AppendEvents(events ...eventstore.Event) {
for _, event := range events {
switch event.Type() {
case instance.InstanceAddedEventType:
*ai = append(*ai, event.Aggregate().InstanceID)
case instance.InstanceRemovedEventType:
*ai = slices.DeleteFunc(*ai, func(s string) bool {
return s == event.Aggregate().InstanceID
})
}
}
}
// Query implements eventstore.QueryReducer.
func (*existingInstances) Query() *eventstore.SearchQueryBuilder {
return eventstore.NewSearchQueryBuilder(eventstore.ColumnsEvent).
AddQuery().
AggregateTypes(instance.AggregateType).
EventTypes(
instance.InstanceAddedEventType,
instance.InstanceRemovedEventType,
).
Builder()
}
// Reduce implements eventstore.QueryReducer.
// reduce is not used as events are reduced during AppendEvents
func (*existingInstances) Reduce() error {
return nil
}
var _ eventstore.QueryReducer = (*existingInstances)(nil)
func (h *Handler) queryInstances(ctx context.Context) ([]string, error) {
if h.handleActiveInstances == 0 {
return h.existingInstances(ctx)
}
query := eventstore.NewSearchQueryBuilder(eventstore.ColumnsInstanceIDs).
AwaitOpenTransactions().
AllowTimeTravel().
CreationDateAfter(h.now().Add(-1 * h.handleActiveInstances))
return h.es.InstanceIDs(ctx, h.requeueEvery, false, query)
}
func (h *Handler) existingInstances(ctx context.Context) ([]string, error) {
ai := existingInstances{}
if err := h.es.FilterToQueryReducer(ctx, &ai); err != nil {
return nil, err
}
return ai, nil
}
type triggerConfig struct {
awaitRunning bool
maxPosition float64
}
type TriggerOpt func(conf *triggerConfig)
func WithAwaitRunning() TriggerOpt {
return func(conf *triggerConfig) {
conf.awaitRunning = true
}
}
func WithMaxPosition(position float64) TriggerOpt {
return func(conf *triggerConfig) {
conf.maxPosition = position
}
}
func (h *Handler) Trigger(ctx context.Context, opts ...TriggerOpt) (_ context.Context, err error) {
config := new(triggerConfig)
for _, opt := range opts {
opt(config)
}
cancel := h.lockInstance(ctx, config)
if cancel == nil {
return call.ResetTimestamp(ctx), nil
}
defer cancel()
for i := 0; ; i++ {
additionalIteration, err := h.processEvents(ctx, config)
h.log().OnError(err).Info("process events failed")
h.log().WithField("iteration", i).Debug("trigger iteration")
if !additionalIteration || err != nil {
return call.ResetTimestamp(ctx), err
}
}
}
// lockInstances tries to lock the instance.
// If the instance is already locked from another process no cancel function is returned
// the instance can be skipped then
// If the instance is locked, an unlock deferrable function is returned
func (h *Handler) lockInstance(ctx context.Context, config *triggerConfig) func() {
instanceID := authz.GetInstance(ctx).InstanceID()
// Check that the instance has a lock
instanceLock, _ := h.triggeredInstancesSync.LoadOrStore(instanceID, make(chan bool, 1))
// in case we don't want to wait for a running trigger / lock (e.g. spooler),
// we can directly return if we cannot lock
if !config.awaitRunning {
select {
case instanceLock.(chan bool) <- true:
return func() {
<-instanceLock.(chan bool)
}
default:
return nil
}
}
// in case we want to wait for a running trigger / lock (e.g. query),
// we try to lock as long as the context is not cancelled
select {
case instanceLock.(chan bool) <- true:
return func() {
<-instanceLock.(chan bool)
}
case <-ctx.Done():
return nil
}
}
func (h *Handler) processEvents(ctx context.Context, config *triggerConfig) (additionalIteration bool, err error) {
defer func() {
pgErr := new(pgconn.PgError)
if errors.As(err, &pgErr) {
// error returned if the row is currently locked by another connection
if pgErr.Code == "55P03" {
h.log().Debug("state already locked")
err = nil
additionalIteration = false
}
}
}()
txCtx := ctx
if h.txDuration > 0 {
var cancel, cancelTx func()
// add 100ms to store current state if iteration takes too long
txCtx, cancelTx = context.WithTimeout(ctx, h.txDuration+100*time.Millisecond)
defer cancelTx()
ctx, cancel = context.WithTimeout(ctx, h.txDuration)
defer cancel()
}
ctx, spanBeginTx := tracing.NewNamedSpan(ctx, "db.BeginTx")
tx, err := h.client.BeginTx(txCtx, nil)
spanBeginTx.EndWithError(err)
if err != nil {
return false, err
}
defer func() {
if err != nil && !errors.Is(err, &executionError{}) {
rollbackErr := tx.Rollback()
h.log().OnError(rollbackErr).Debug("unable to rollback tx")
return
}
commitErr := tx.Commit()
if err == nil {
err = commitErr
}
}()
currentState, err := h.currentState(ctx, tx, config)
if err != nil {
if errors.Is(err, errJustUpdated) {
return false, nil
}
return additionalIteration, err
}
// stop execution if currentState.eventTimestamp >= config.maxCreatedAt
if config.maxPosition != 0 && currentState.position >= config.maxPosition {
return false, nil
}
var statements []*Statement
statements, additionalIteration, err = h.generateStatements(ctx, tx, currentState)
if err != nil {
return additionalIteration, err
}
if len(statements) == 0 {
err = h.setState(tx, currentState)
return additionalIteration, err
}
lastProcessedIndex, err := h.executeStatements(ctx, tx, currentState, statements)
h.log().OnError(err).WithField("lastProcessedIndex", lastProcessedIndex).Debug("execution of statements failed")
if lastProcessedIndex < 0 {
return false, err
}
currentState.position = statements[lastProcessedIndex].Position
currentState.offset = statements[lastProcessedIndex].offset
currentState.aggregateID = statements[lastProcessedIndex].AggregateID
currentState.aggregateType = statements[lastProcessedIndex].AggregateType
currentState.sequence = statements[lastProcessedIndex].Sequence
currentState.eventTimestamp = statements[lastProcessedIndex].CreationDate
err = h.setState(tx, currentState)
return additionalIteration, err
}
func (h *Handler) generateStatements(ctx context.Context, tx *sql.Tx, currentState *state) (_ []*Statement, additionalIteration bool, err error) {
if h.triggerWithoutEvents != nil {
stmt, err := h.triggerWithoutEvents(pseudo.NewScheduledEvent(ctx, time.Now(), currentState.instanceID))
if err != nil {
return nil, false, err
}
return []*Statement{stmt}, false, nil
}
events, err := h.es.Filter(ctx, h.eventQuery(currentState))
if err != nil {
h.log().WithError(err).Debug("filter eventstore failed")
return nil, false, err
}
eventAmount := len(events)
statements, err := h.eventsToStatements(tx, events, currentState)
if err != nil || len(statements) == 0 {
return nil, false, err
}
idx := skipPreviouslyReduced(statements, currentState)
if idx+1 == len(statements) {
currentState.position = statements[len(statements)-1].Position
currentState.offset = statements[len(statements)-1].offset
currentState.aggregateID = statements[len(statements)-1].AggregateID
currentState.aggregateType = statements[len(statements)-1].AggregateType
currentState.sequence = statements[len(statements)-1].Sequence
currentState.eventTimestamp = statements[len(statements)-1].CreationDate
return nil, false, nil
}
statements = statements[idx+1:]
additionalIteration = eventAmount == int(h.bulkLimit)
if len(statements) < len(events) {
// retry immediately if statements failed
additionalIteration = true
}
return statements, additionalIteration, nil
}
func skipPreviouslyReduced(statements []*Statement, currentState *state) int {
for i, statement := range statements {
if statement.Position == currentState.position &&
statement.AggregateID == currentState.aggregateID &&
statement.AggregateType == currentState.aggregateType &&
statement.Sequence == currentState.sequence {
return i
}
}
return -1
}
func (h *Handler) executeStatements(ctx context.Context, tx *sql.Tx, currentState *state, statements []*Statement) (lastProcessedIndex int, err error) {
lastProcessedIndex = -1
for i, statement := range statements {
select {
case <-ctx.Done():
break
default:
err := h.executeStatement(ctx, tx, currentState, statement)
if err != nil {
return lastProcessedIndex, err
}
lastProcessedIndex = i
}
}
return lastProcessedIndex, nil
}
func (h *Handler) executeStatement(ctx context.Context, tx *sql.Tx, currentState *state, statement *Statement) (err error) {
if statement.Execute == nil {
return nil
}
_, err = tx.Exec("SAVEPOINT exec")
if err != nil {
h.log().WithError(err).Debug("create savepoint failed")
return err
}
var shouldContinue bool
defer func() {
_, errSave := tx.Exec("RELEASE SAVEPOINT exec")
if err == nil {
err = errSave
}
}()
if err = statement.Execute(tx, h.projection.Name()); err != nil {
h.log().WithError(err).Error("statement execution failed")
shouldContinue = h.handleFailedStmt(tx, failureFromStatement(statement, err))
if shouldContinue {
return nil
}
return &executionError{parent: err}
}
return nil
}
func (h *Handler) eventQuery(currentState *state) *eventstore.SearchQueryBuilder {
builder := eventstore.NewSearchQueryBuilder(eventstore.ColumnsEvent).
AwaitOpenTransactions().
Limit(uint64(h.bulkLimit)).
AllowTimeTravel().
OrderAsc().
InstanceID(currentState.instanceID)
if currentState.position > 0 {
// decrease position by 10 because builder.PositionAfter filters for position > and we need position >=
builder = builder.PositionAfter(math.Float64frombits(math.Float64bits(currentState.position) - 10))
if currentState.offset > 0 {
builder = builder.Offset(currentState.offset)
}
}
for aggregateType, eventTypes := range h.eventTypes {
query := builder.
AddQuery().
AggregateTypes(aggregateType).
EventTypes(eventTypes...)
builder = query.Builder()
}
return builder
}
// ProjectionName returns the name of the underlying projection.
func (h *Handler) ProjectionName() string {
return h.projection.Name()
}