package crdb

import (
	"context"
	"database/sql"
	"fmt"
	"strconv"
	"strings"
	"time"

	"github.com/zitadel/logging"

	"github.com/zitadel/zitadel/internal/database"
	"github.com/zitadel/zitadel/internal/errors"
	"github.com/zitadel/zitadel/internal/id"
)

const (
	// lockStmtFormat is the upsert used to acquire, renew and release locks.
	//
	// Format verbs:
	//   %[1]s lock table name
	//   %[2]s comma separated VALUES tuples, one per instance
	//   %[3]d index of the placeholder holding the array of instance ids
	//
	// Placeholders: $1 locker id, $2 lock duration, $3 projection name,
	// $4.. one instance id per VALUES tuple, followed by the instance id array.
	//
	// An existing row is only overwritten if it is already owned by the same
	// locker or if its lock has expired.
	lockStmtFormat = "INSERT INTO %[1]s" +
		" (locker_id, locked_until, projection_name, instance_id) VALUES %[2]s" +
		" ON CONFLICT (projection_name, instance_id)" +
		" DO UPDATE SET locker_id = $1, locked_until = now()+$2::INTERVAL" +
		" WHERE %[1]s.projection_name = $3 AND %[1]s.instance_id = ANY ($%[3]d) AND (%[1]s.locker_id = $1 OR %[1]s.locked_until < now())"
)

// Locker locks a projection for a set of instances.
type Locker interface {
	// Lock tries to acquire the lock right away and keeps renewing it until
	// ctx is done. The outcome of every attempt (nil on success) is sent on
	// the returned channel, which is closed as soon as ctx is done.
	Lock(ctx context.Context, lockDuration time.Duration, instanceIDs ...string) <-chan error
	// Unlock executes the lock statement with a duration of 0 so that locks
	// held by this locker expire immediately.
	Unlock(instanceIDs ...string) error
}

// locker is the database/sql backed implementation of Locker.
type locker struct {
	client *sql.DB
	// lockStmt renders lockStmtFormat for the configured lock table.
	lockStmt func(values string, instances int) string
	// workerName identifies this locker in the locker_id column.
	workerName string
	// projectionName is written to and matched against the projection_name column.
	projectionName string
}

// NewLocker returns a Locker which stores the locks of projectionName in lockTable.
// Each locker gets its own generated worker id; if the id cannot be generated
// the error is logged with Panic.
func NewLocker(client *sql.DB, lockTable, projectionName string) Locker {
	workerName, err := id.SonyFlakeGenerator().Next()
	logging.OnError(err).Panic("unable to generate lockID")
	return &locker{
		client: client,
		lockStmt: func(values string, instances int) string {
			return fmt.Sprintf(lockStmtFormat, lockTable, values, instances)
		},
		workerName:     workerName,
		projectionName: projectionName,
	}
}

// Lock starts a goroutine which acquires the lock immediately and renews it
// until ctx is done. See Locker.Lock for the semantics of the returned channel.
func (h *locker) Lock(ctx context.Context, lockDuration time.Duration, instanceIDs ...string) <-chan error {
	errs := make(chan error)
	go h.handleLock(ctx, errs, lockDuration, instanceIDs...)
	return errs
}

// handleLock is the renew loop behind Lock. It owns errs: every renew result
// is offered on errs and errs is closed when the loop exits.
//
// The loop exits as soon as ctx is done, even while a result is waiting to be
// delivered. Without that, a consumer that stopped receiving would block this
// goroutine on the send forever and errs would never be closed.
func (h *locker) handleLock(ctx context.Context, errs chan error, lockDuration time.Duration, instanceIDs ...string) {
	// a timer of 0 makes the first lock attempt happen immediately
	renewLock := time.NewTimer(0)
	// deferred calls run in reverse order: errs is closed first, then the timer is stopped
	defer renewLock.Stop()
	defer close(errs)

	for {
		select {
		case <-renewLock.C:
			err := h.renewLock(ctx, lockDuration, instanceIDs...)
			select {
			case errs <- err:
			case <-ctx.Done():
				// nobody picked up the result before cancellation; drop it and stop
				return
			}
			// refresh the lock 500ms before it times out.
			// 500ms should be enough for one transaction.
			// NOTE: for lockDuration <= 500ms the interval is not positive,
			// so the timer fires again right away.
			renewLock.Reset(lockDuration - (500 * time.Millisecond))
		case <-ctx.Done():
			return
		}
	}
}

// renewLock executes the lock statement once.
// It returns an internal error if the statement fails and an already exists
// error if no row was affected, which means another locker holds the lock.
func (h *locker) renewLock(ctx context.Context, lockDuration time.Duration, instanceIDs ...string) error {
	lockStmt, values := h.lockStatement(lockDuration, instanceIDs)
	res, err := h.client.ExecContext(ctx, lockStmt, values...)
	if err != nil {
		return errors.ThrowInternal(err, "CRDB-uaDoR", "unable to execute lock")
	}
	// The error of RowsAffected is deliberately ignored: if the count is
	// unavailable rows is 0 and the lock is reported as not acquired.
	if rows, _ := res.RowsAffected(); rows == 0 {
		return errors.ThrowAlreadyExists(nil, "CRDB-mmi4J", "projection already locked")
	}
	return nil
}

// Unlock releases the locks of the given instances by executing the lock
// statement with a duration of 0, which sets locked_until to now().
// It is not bound to a context.
func (h *locker) Unlock(instanceIDs ...string) error {
	lockStmt, values := h.lockStatement(0, instanceIDs)
	_, err := h.client.Exec(lockStmt, values...)
	if err != nil {
		return errors.ThrowUnknown(err, "CRDB-JjfwO", "unlock failed")
	}
	return nil
}

// lockStatement builds the lock statement and its arguments.
//
// Argument layout: [0] worker name ($1), [1] lock duration ($2),
// [2] projection name ($3), [3..] one instance id per VALUES tuple ($4..),
// [last] all instance ids as array, referenced by ANY ($n).
func (h *locker) lockStatement(lockDuration time.Duration, instanceIDs database.StringArray) (string, []interface{}) {
	valueQueries := make([]string, len(instanceIDs))
	values := make([]interface{}, len(instanceIDs)+4)
	values[0] = h.workerName
	// The duration is handed to the driver as time.Duration and cast to
	// INTERVAL by the statement.
	// NOTE(review): a unit-less number is interpreted as seconds by crdb
	// (https://www.cockroachlabs.com/docs/stable/interval.html) - confirm the
	// configured driver encodes time.Duration as an interval and not as a
	// plain integer.
	values[1] = lockDuration
	values[2] = h.projectionName
	for i, instanceID := range instanceIDs {
		valueQueries[i] = "($1, now()+$2::INTERVAL, $3, $" + strconv.Itoa(i+4) + ")"
		values[i+3] = instanceID
	}
	// the array argument is the last one, so its placeholder index equals len(values)
	values[len(values)-1] = instanceIDs
	return h.lockStmt(strings.Join(valueQueries, ", "), len(values)), values
}