mirror of
https://github.com/zitadel/zitadel.git
synced 2025-08-12 10:49:25 +00:00
chore: move the go code into a subfolder
This commit is contained in:
99
apps/api/internal/migration/command.go
Normal file
99
apps/api/internal/migration/command.go
Normal file
@@ -0,0 +1,99 @@
|
||||
package migration
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/api/authz"
|
||||
"github.com/zitadel/zitadel/internal/api/service"
|
||||
"github.com/zitadel/zitadel/internal/eventstore"
|
||||
"github.com/zitadel/zitadel/internal/zerrors"
|
||||
)
|
||||
|
||||
func init() {
|
||||
eventstore.RegisterFilterEventMapper(SystemAggregate, StartedType, SetupMapper)
|
||||
eventstore.RegisterFilterEventMapper(SystemAggregate, DoneType, SetupMapper)
|
||||
eventstore.RegisterFilterEventMapper(SystemAggregate, failedType, SetupMapper)
|
||||
eventstore.RegisterFilterEventMapper(SystemAggregate, repeatableDoneType, SetupMapper)
|
||||
}
|
||||
|
||||
// SetupStep is the command pushed on the eventstore
|
||||
type SetupStep struct {
|
||||
eventstore.BaseEvent `json:"-"`
|
||||
migration Migration
|
||||
Name string `json:"name"`
|
||||
Error any `json:"error,omitempty"`
|
||||
LastRun any `json:"lastRun,omitempty"`
|
||||
}
|
||||
|
||||
func setupStartedCmd(ctx context.Context, migration Migration) eventstore.Command {
|
||||
ctx = authz.SetCtxData(service.WithService(ctx, "system"), authz.CtxData{UserID: "system", OrgID: "SYSTEM", ResourceOwner: "SYSTEM"})
|
||||
return &SetupStep{
|
||||
BaseEvent: *eventstore.NewBaseEventForPush(
|
||||
ctx,
|
||||
eventstore.NewAggregate(ctx, SystemAggregateID, SystemAggregate, "v1"),
|
||||
StartedType),
|
||||
migration: migration,
|
||||
Name: migration.String(),
|
||||
}
|
||||
}
|
||||
|
||||
func setupDoneCmd(ctx context.Context, migration Migration, err error) eventstore.Command {
|
||||
ctx = authz.SetCtxData(service.WithService(ctx, "system"), authz.CtxData{UserID: "system", OrgID: "SYSTEM", ResourceOwner: "SYSTEM"})
|
||||
typ := DoneType
|
||||
var lastRun interface{}
|
||||
if repeatable, ok := migration.(RepeatableMigration); ok {
|
||||
typ = repeatableDoneType
|
||||
lastRun = repeatable
|
||||
}
|
||||
|
||||
s := &SetupStep{
|
||||
migration: migration,
|
||||
Name: migration.String(),
|
||||
LastRun: lastRun,
|
||||
}
|
||||
if err != nil {
|
||||
typ = failedType
|
||||
s.Error = err.Error()
|
||||
}
|
||||
|
||||
s.BaseEvent = *eventstore.NewBaseEventForPush(
|
||||
ctx,
|
||||
eventstore.NewAggregate(ctx, SystemAggregateID, SystemAggregate, "v1"),
|
||||
typ)
|
||||
|
||||
return s
|
||||
}
|
||||
|
||||
func (s *SetupStep) Payload() interface{} {
|
||||
return s
|
||||
}
|
||||
|
||||
func (s *SetupStep) UniqueConstraints() []*eventstore.UniqueConstraint {
|
||||
switch s.Type() {
|
||||
case StartedType:
|
||||
return []*eventstore.UniqueConstraint{
|
||||
eventstore.NewAddGlobalUniqueConstraint("migration_started", s.migration.String(), "Errors.Step.Started.AlreadyExists"),
|
||||
}
|
||||
case failedType,
|
||||
repeatableDoneType:
|
||||
return []*eventstore.UniqueConstraint{
|
||||
eventstore.NewRemoveGlobalUniqueConstraint("migration_started", s.migration.String()),
|
||||
}
|
||||
default:
|
||||
return []*eventstore.UniqueConstraint{
|
||||
eventstore.NewAddGlobalUniqueConstraint("migration_done", s.migration.String(), "Errors.Step.Done.AlreadyExists"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func SetupMapper(event eventstore.Event) (eventstore.Event, error) {
|
||||
step := &SetupStep{
|
||||
BaseEvent: *eventstore.BaseEventFromRepo(event),
|
||||
}
|
||||
err := event.Unmarshal(step)
|
||||
if err != nil {
|
||||
return nil, zerrors.ThrowInternal(err, "IAM-hYp7M", "unable to unmarshal step")
|
||||
}
|
||||
|
||||
return step, nil
|
||||
}
|
43
apps/api/internal/migration/count_trigger.sql
Normal file
43
apps/api/internal/migration/count_trigger.sql
Normal file
@@ -0,0 +1,43 @@
|
||||
{{ define "count_trigger" -}}
|
||||
CREATE OR REPLACE TRIGGER count_{{ .Resource }}
|
||||
AFTER INSERT OR DELETE
|
||||
ON {{ .Table }}
|
||||
FOR EACH ROW
|
||||
EXECUTE FUNCTION projections.count_resource(
|
||||
'{{ .ParentType }}',
|
||||
'{{ .InstanceIDColumn }}',
|
||||
'{{ .ParentIDColumn }}',
|
||||
'{{ .Resource }}'
|
||||
);
|
||||
|
||||
CREATE OR REPLACE TRIGGER truncate_{{ .Resource }}_counts
|
||||
AFTER TRUNCATE
|
||||
ON {{ .Table }}
|
||||
FOR EACH STATEMENT
|
||||
EXECUTE FUNCTION projections.delete_table_counts();
|
||||
|
||||
-- Prevent inserts and deletes while we populate the counts.
|
||||
LOCK TABLE {{ .Table }} IN SHARE MODE;
|
||||
|
||||
-- Populate the resource counts for the existing data in the table.
|
||||
INSERT INTO projections.resource_counts(
|
||||
instance_id,
|
||||
table_name,
|
||||
parent_type,
|
||||
parent_id,
|
||||
resource_name,
|
||||
amount
|
||||
)
|
||||
SELECT
|
||||
{{ .InstanceIDColumn }},
|
||||
'{{ .Table }}',
|
||||
'{{ .ParentType }}',
|
||||
{{ .ParentIDColumn }},
|
||||
'{{ .Resource }}',
|
||||
COUNT(*) AS amount
|
||||
FROM {{ .Table }}
|
||||
GROUP BY ({{ .InstanceIDColumn }}, {{ .ParentIDColumn }})
|
||||
ON CONFLICT (instance_id, table_name, parent_type, parent_id) DO
|
||||
UPDATE SET updated_at = now(), amount = EXCLUDED.amount;
|
||||
|
||||
{{- end -}}
|
13
apps/api/internal/migration/delete_parent_counts_trigger.sql
Normal file
13
apps/api/internal/migration/delete_parent_counts_trigger.sql
Normal file
@@ -0,0 +1,13 @@
|
||||
{{ define "delete_parent_counts_trigger" -}}
|
||||
|
||||
CREATE OR REPLACE TRIGGER delete_parent_counts_trigger
|
||||
AFTER DELETE
|
||||
ON {{ .Table }}
|
||||
FOR EACH ROW
|
||||
EXECUTE FUNCTION projections.delete_parent_counts(
|
||||
'{{ .ParentType }}',
|
||||
'{{ .InstanceIDColumn }}',
|
||||
'{{ .ParentIDColumn }}'
|
||||
);
|
||||
|
||||
{{- end -}}
|
165
apps/api/internal/migration/migration.go
Normal file
165
apps/api/internal/migration/migration.go
Normal file
@@ -0,0 +1,165 @@
|
||||
package migration
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"time"
|
||||
|
||||
"github.com/zitadel/logging"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/eventstore"
|
||||
"github.com/zitadel/zitadel/internal/zerrors"
|
||||
)
|
||||
|
||||
const (
|
||||
StartedType = eventstore.EventType("system.migration.started")
|
||||
DoneType = eventstore.EventType("system.migration.done")
|
||||
failedType = eventstore.EventType("system.migration.failed")
|
||||
repeatableDoneType = eventstore.EventType("system.migration.repeatable.done")
|
||||
SystemAggregate = eventstore.AggregateType("system")
|
||||
SystemAggregateID = "SYSTEM"
|
||||
)
|
||||
|
||||
var (
|
||||
errMigrationAlreadyStarted = errors.New("already started")
|
||||
)
|
||||
|
||||
type Migration interface {
|
||||
String() string
|
||||
Execute(ctx context.Context, startedEvent eventstore.Event) error
|
||||
}
|
||||
|
||||
type errCheckerMigration interface {
|
||||
Migration
|
||||
ContinueOnErr(err error) bool
|
||||
}
|
||||
|
||||
type RepeatableMigration interface {
|
||||
Migration
|
||||
|
||||
// Check if the migration should be executed again.
|
||||
// True will repeat the migration, false will not.
|
||||
Check(lastRun map[string]any) bool
|
||||
}
|
||||
|
||||
func Migrate(ctx context.Context, es *eventstore.Eventstore, migration Migration) (err error) {
|
||||
logging.WithFields("name", migration.String()).Info("verify migration")
|
||||
|
||||
continueOnErr := func(err error) bool {
|
||||
return false
|
||||
}
|
||||
errChecker, ok := migration.(errCheckerMigration)
|
||||
if ok {
|
||||
continueOnErr = errChecker.ContinueOnErr
|
||||
}
|
||||
|
||||
should, err := checkExec(ctx, es, migration)
|
||||
if err != nil && !continueOnErr(err) {
|
||||
return err
|
||||
}
|
||||
if !should {
|
||||
return nil
|
||||
}
|
||||
|
||||
startedEvent, err := es.Push(ctx, setupStartedCmd(ctx, migration))
|
||||
if err != nil && !continueOnErr(err) {
|
||||
return err
|
||||
}
|
||||
|
||||
logging.WithFields("name", migration.String()).Info("starting migration")
|
||||
err = migration.Execute(ctx, startedEvent[0])
|
||||
logging.WithFields("name", migration.String()).OnError(err).Error("migration failed")
|
||||
|
||||
_, pushErr := es.Push(ctx, setupDoneCmd(ctx, migration, err))
|
||||
logging.WithFields("name", migration.String()).OnError(pushErr).Error("migration finish failed")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return pushErr
|
||||
}
|
||||
|
||||
func LastStuckStep(ctx context.Context, es *eventstore.Eventstore) (*SetupStep, error) {
|
||||
var states StepStates
|
||||
err := es.FilterToQueryReducer(ctx, &states)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
step := states.lastByState(StepStarted)
|
||||
if step == nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
return step.SetupStep, nil
|
||||
}
|
||||
|
||||
var _ Migration = (*cancelMigration)(nil)
|
||||
|
||||
type cancelMigration struct {
|
||||
name string
|
||||
}
|
||||
|
||||
// Execute implements Migration
|
||||
func (*cancelMigration) Execute(context.Context, eventstore.Event) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// String implements Migration
|
||||
func (m *cancelMigration) String() string {
|
||||
return m.name
|
||||
}
|
||||
|
||||
var errCancelStep = zerrors.ThrowError(nil, "MIGRA-zo86K", "migration canceled manually")
|
||||
|
||||
func CancelStep(ctx context.Context, es *eventstore.Eventstore, step *SetupStep) error {
|
||||
_, err := es.Push(ctx, setupDoneCmd(ctx, &cancelMigration{name: step.Name}, errCancelStep))
|
||||
return err
|
||||
}
|
||||
|
||||
// checkExec ensures that only one setup step is done concurrently
|
||||
// if a setup step is already started, it calls shouldExec after some time again
|
||||
func checkExec(ctx context.Context, es *eventstore.Eventstore, migration Migration) (bool, error) {
|
||||
timer := time.NewTimer(0)
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return false, zerrors.ThrowInternal(nil, "MIGR-as3f7", "Errors.Internal")
|
||||
case <-timer.C:
|
||||
should, err := shouldExec(ctx, es, migration)
|
||||
if err != nil {
|
||||
if !errors.Is(err, errMigrationAlreadyStarted) {
|
||||
return false, err
|
||||
}
|
||||
logging.WithFields("migration step", migration.String()).
|
||||
Warn("migration already started, will check again in 5 seconds")
|
||||
timer.Reset(5 * time.Second)
|
||||
break
|
||||
}
|
||||
return should, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func shouldExec(ctx context.Context, es *eventstore.Eventstore, migration Migration) (should bool, err error) {
|
||||
var states StepStates
|
||||
err = es.FilterToQueryReducer(ctx, &states)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
step := states.byName(migration.String())
|
||||
if step == nil {
|
||||
return true, nil
|
||||
}
|
||||
if step.state == StepFailed {
|
||||
return true, nil
|
||||
}
|
||||
if step.state == StepStarted {
|
||||
return false, errMigrationAlreadyStarted
|
||||
}
|
||||
|
||||
repeatable, ok := migration.(RepeatableMigration)
|
||||
if !ok {
|
||||
return step.state != StepDone, nil
|
||||
}
|
||||
lastRun, _ := step.LastRun.(map[string]interface{})
|
||||
return repeatable.Check(lastRun), nil
|
||||
}
|
86
apps/api/internal/migration/step.go
Normal file
86
apps/api/internal/migration/step.go
Normal file
@@ -0,0 +1,86 @@
|
||||
package migration
|
||||
|
||||
import "github.com/zitadel/zitadel/internal/eventstore"
|
||||
|
||||
var _ eventstore.QueryReducer = (*StepStates)(nil)
|
||||
|
||||
type Step struct {
|
||||
*SetupStep
|
||||
|
||||
state StepState
|
||||
}
|
||||
|
||||
type StepStates struct {
|
||||
eventstore.ReadModel
|
||||
Steps []*Step
|
||||
}
|
||||
|
||||
// Query implements eventstore.QueryReducer.
|
||||
func (*StepStates) Query() *eventstore.SearchQueryBuilder {
|
||||
return eventstore.NewSearchQueryBuilder(eventstore.ColumnsEvent).
|
||||
InstanceID(""). // to make sure we can use an appropriate index
|
||||
AddQuery().
|
||||
AggregateTypes(SystemAggregate).
|
||||
AggregateIDs(SystemAggregateID).
|
||||
EventTypes(StartedType, DoneType, repeatableDoneType, failedType).
|
||||
Builder()
|
||||
}
|
||||
|
||||
// Reduce implements eventstore.QueryReducer.
|
||||
func (s *StepStates) Reduce() error {
|
||||
for _, event := range s.Events {
|
||||
step := event.(*SetupStep)
|
||||
state := s.byName(step.Name)
|
||||
if state == nil {
|
||||
state = new(Step)
|
||||
s.Steps = append(s.Steps, state)
|
||||
}
|
||||
state.SetupStep = step
|
||||
switch step.EventType {
|
||||
case StartedType:
|
||||
state.state = StepStarted
|
||||
case DoneType:
|
||||
state.state = StepDone
|
||||
case repeatableDoneType:
|
||||
state.state = StepDone
|
||||
case failedType:
|
||||
state.state = StepFailed
|
||||
}
|
||||
}
|
||||
return s.ReadModel.Reduce()
|
||||
}
|
||||
|
||||
func (s *StepStates) byName(name string) *Step {
|
||||
for _, step := range s.Steps {
|
||||
if step.Name == name {
|
||||
return step
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *StepStates) lastByState(stepState StepState) (step *Step) {
|
||||
for _, state := range s.Steps {
|
||||
if state.state != stepState {
|
||||
continue
|
||||
}
|
||||
if step == nil {
|
||||
step = state
|
||||
continue
|
||||
}
|
||||
if step.CreatedAt().After(state.CreatedAt()) {
|
||||
continue
|
||||
}
|
||||
|
||||
step = state
|
||||
}
|
||||
return step
|
||||
}
|
||||
|
||||
type StepState int32
|
||||
|
||||
const (
|
||||
StepStarted StepState = iota
|
||||
StepDone
|
||||
StepFailed
|
||||
)
|
94
apps/api/internal/migration/step_test.go
Normal file
94
apps/api/internal/migration/step_test.go
Normal file
@@ -0,0 +1,94 @@
|
||||
package migration
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/eventstore"
|
||||
)
|
||||
|
||||
func TestStepStates_lastByState(t *testing.T) {
|
||||
now := time.Now()
|
||||
past := now.Add(-10 * time.Millisecond)
|
||||
tests := []struct {
|
||||
name string
|
||||
fields *StepStates
|
||||
arg StepState
|
||||
want *Step
|
||||
}{
|
||||
{
|
||||
name: "no events reduced invalid state",
|
||||
fields: &StepStates{},
|
||||
arg: -1,
|
||||
},
|
||||
{
|
||||
name: "no events reduced by valid state",
|
||||
fields: &StepStates{},
|
||||
arg: StepDone,
|
||||
},
|
||||
{
|
||||
name: "no state found",
|
||||
fields: &StepStates{
|
||||
Steps: []*Step{
|
||||
{
|
||||
SetupStep: &SetupStep{
|
||||
Name: "done",
|
||||
},
|
||||
state: StepDone,
|
||||
},
|
||||
{
|
||||
SetupStep: &SetupStep{
|
||||
Name: "failed",
|
||||
},
|
||||
state: StepFailed,
|
||||
},
|
||||
},
|
||||
},
|
||||
arg: StepStarted,
|
||||
},
|
||||
{
|
||||
name: "found",
|
||||
fields: &StepStates{
|
||||
Steps: []*Step{
|
||||
{
|
||||
SetupStep: &SetupStep{
|
||||
BaseEvent: eventstore.BaseEvent{
|
||||
Creation: past,
|
||||
},
|
||||
},
|
||||
state: StepStarted,
|
||||
},
|
||||
{
|
||||
SetupStep: &SetupStep{
|
||||
BaseEvent: eventstore.BaseEvent{
|
||||
Creation: now,
|
||||
},
|
||||
},
|
||||
state: StepStarted,
|
||||
},
|
||||
},
|
||||
},
|
||||
arg: StepStarted,
|
||||
want: &Step{
|
||||
state: StepStarted,
|
||||
SetupStep: &SetupStep{
|
||||
BaseEvent: eventstore.BaseEvent{
|
||||
Creation: now,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
s := &StepStates{
|
||||
ReadModel: tt.fields.ReadModel,
|
||||
Steps: tt.fields.Steps,
|
||||
}
|
||||
if gotStep := s.lastByState(tt.arg); !reflect.DeepEqual(gotStep, tt.want) {
|
||||
t.Errorf("StepStates.lastByState() = %v, want %v", *gotStep, *tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
127
apps/api/internal/migration/trigger.go
Normal file
127
apps/api/internal/migration/trigger.go
Normal file
@@ -0,0 +1,127 @@
|
||||
package migration
|
||||
|
||||
import (
|
||||
"context"
|
||||
"embed"
|
||||
"fmt"
|
||||
"strings"
|
||||
"text/template"
|
||||
|
||||
"github.com/mitchellh/mapstructure"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/database"
|
||||
"github.com/zitadel/zitadel/internal/domain"
|
||||
"github.com/zitadel/zitadel/internal/eventstore"
|
||||
)
|
||||
|
||||
const (
|
||||
countTriggerTmpl = "count_trigger"
|
||||
deleteParentCountsTmpl = "delete_parent_counts_trigger"
|
||||
)
|
||||
|
||||
var (
|
||||
//go:embed *.sql
|
||||
templateFS embed.FS
|
||||
templates = template.Must(template.ParseFS(templateFS, "*.sql"))
|
||||
)
|
||||
|
||||
// CountTrigger registers the existing projections.count_trigger function.
|
||||
// The trigger than takes care of keeping count of existing
|
||||
// rows in the source table.
|
||||
// It also pre-populates the projections.resource_counts table with
|
||||
// the counts for the given table.
|
||||
//
|
||||
// During the population of the resource_counts table,
|
||||
// the source table is share-locked to prevent concurrent modifications.
|
||||
// Projection handlers will be halted until the lock is released.
|
||||
// SELECT statements are not blocked by the lock.
|
||||
//
|
||||
// This migration repeats when any of the arguments are changed,
|
||||
// such as renaming of a projection table.
|
||||
func CountTrigger(
|
||||
db *database.DB,
|
||||
table string,
|
||||
parentType domain.CountParentType,
|
||||
instanceIDColumn string,
|
||||
parentIDColumn string,
|
||||
resource string,
|
||||
) RepeatableMigration {
|
||||
return &triggerMigration{
|
||||
triggerConfig: triggerConfig{
|
||||
Table: table,
|
||||
ParentType: parentType.String(),
|
||||
InstanceIDColumn: instanceIDColumn,
|
||||
ParentIDColumn: parentIDColumn,
|
||||
Resource: resource,
|
||||
},
|
||||
db: db,
|
||||
templateName: countTriggerTmpl,
|
||||
}
|
||||
}
|
||||
|
||||
// DeleteParentCountsTrigger
|
||||
//
|
||||
// This migration repeats when any of the arguments are changed,
|
||||
// such as renaming of a projection table.
|
||||
func DeleteParentCountsTrigger(
|
||||
db *database.DB,
|
||||
table string,
|
||||
parentType domain.CountParentType,
|
||||
instanceIDColumn string,
|
||||
parentIDColumn string,
|
||||
resource string,
|
||||
) RepeatableMigration {
|
||||
return &triggerMigration{
|
||||
triggerConfig: triggerConfig{
|
||||
Table: table,
|
||||
ParentType: parentType.String(),
|
||||
InstanceIDColumn: instanceIDColumn,
|
||||
ParentIDColumn: parentIDColumn,
|
||||
Resource: resource,
|
||||
},
|
||||
db: db,
|
||||
templateName: deleteParentCountsTmpl,
|
||||
}
|
||||
}
|
||||
|
||||
type triggerMigration struct {
|
||||
triggerConfig
|
||||
db *database.DB
|
||||
templateName string
|
||||
}
|
||||
|
||||
// String implements [Migration] and [fmt.Stringer].
|
||||
func (m *triggerMigration) String() string {
|
||||
return fmt.Sprintf("repeatable_%s_%s", m.Resource, m.templateName)
|
||||
}
|
||||
|
||||
// Execute implements [Migration]
|
||||
func (m *triggerMigration) Execute(ctx context.Context, _ eventstore.Event) error {
|
||||
var query strings.Builder
|
||||
err := templates.ExecuteTemplate(&query, m.templateName, m.triggerConfig)
|
||||
if err != nil {
|
||||
return fmt.Errorf("%s: execute trigger template: %w", m, err)
|
||||
}
|
||||
_, err = m.db.ExecContext(ctx, query.String())
|
||||
if err != nil {
|
||||
return fmt.Errorf("%s: exec trigger query: %w", m, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type triggerConfig struct {
|
||||
Table string `json:"table,omitempty" mapstructure:"table"`
|
||||
ParentType string `json:"parent_type,omitempty" mapstructure:"parent_type"`
|
||||
InstanceIDColumn string `json:"instance_id_column,omitempty" mapstructure:"instance_id_column"`
|
||||
ParentIDColumn string `json:"parent_id_column,omitempty" mapstructure:"parent_id_column"`
|
||||
Resource string `json:"resource,omitempty" mapstructure:"resource"`
|
||||
}
|
||||
|
||||
// Check implements [RepeatableMigration].
|
||||
func (c *triggerConfig) Check(lastRun map[string]any) bool {
|
||||
var dst triggerConfig
|
||||
if err := mapstructure.Decode(lastRun, &dst); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return dst != *c
|
||||
}
|
253
apps/api/internal/migration/trigger_test.go
Normal file
253
apps/api/internal/migration/trigger_test.go
Normal file
@@ -0,0 +1,253 @@
|
||||
package migration
|
||||
|
||||
import (
|
||||
"context"
|
||||
"regexp"
|
||||
"testing"
|
||||
|
||||
"github.com/DATA-DOG/go-sqlmock"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/database"
|
||||
)
|
||||
|
||||
const (
|
||||
expCountTriggerQuery = `CREATE OR REPLACE TRIGGER count_resource
|
||||
AFTER INSERT OR DELETE
|
||||
ON table
|
||||
FOR EACH ROW
|
||||
EXECUTE FUNCTION projections.count_resource(
|
||||
'instance',
|
||||
'instance_id',
|
||||
'parent_id',
|
||||
'resource'
|
||||
);
|
||||
|
||||
CREATE OR REPLACE TRIGGER truncate_resource_counts
|
||||
AFTER TRUNCATE
|
||||
ON table
|
||||
FOR EACH STATEMENT
|
||||
EXECUTE FUNCTION projections.delete_table_counts();
|
||||
|
||||
-- Prevent inserts and deletes while we populate the counts.
|
||||
LOCK TABLE table IN SHARE MODE;
|
||||
|
||||
-- Populate the resource counts for the existing data in the table.
|
||||
INSERT INTO projections.resource_counts(
|
||||
instance_id,
|
||||
table_name,
|
||||
parent_type,
|
||||
parent_id,
|
||||
resource_name,
|
||||
amount
|
||||
)
|
||||
SELECT
|
||||
instance_id,
|
||||
'table',
|
||||
'instance',
|
||||
parent_id,
|
||||
'resource',
|
||||
COUNT(*) AS amount
|
||||
FROM table
|
||||
GROUP BY (instance_id, parent_id)
|
||||
ON CONFLICT (instance_id, table_name, parent_type, parent_id) DO
|
||||
UPDATE SET updated_at = now(), amount = EXCLUDED.amount;`
|
||||
|
||||
expDeleteParentCountsQuery = `CREATE OR REPLACE TRIGGER delete_parent_counts_trigger
|
||||
AFTER DELETE
|
||||
ON table
|
||||
FOR EACH ROW
|
||||
EXECUTE FUNCTION projections.delete_parent_counts(
|
||||
'instance',
|
||||
'instance_id',
|
||||
'parent_id'
|
||||
);`
|
||||
)
|
||||
|
||||
func Test_triggerMigration_Execute(t *testing.T) {
|
||||
type fields struct {
|
||||
triggerConfig triggerConfig
|
||||
templateName string
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
fields fields
|
||||
expects func(sqlmock.Sqlmock)
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "template error",
|
||||
fields: fields{
|
||||
triggerConfig: triggerConfig{
|
||||
Table: "table",
|
||||
ParentType: "instance",
|
||||
InstanceIDColumn: "instance_id",
|
||||
ParentIDColumn: "parent_id",
|
||||
Resource: "resource",
|
||||
},
|
||||
templateName: "foo",
|
||||
},
|
||||
expects: func(_ sqlmock.Sqlmock) {},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "db error",
|
||||
fields: fields{
|
||||
triggerConfig: triggerConfig{
|
||||
Table: "table",
|
||||
ParentType: "instance",
|
||||
InstanceIDColumn: "instance_id",
|
||||
ParentIDColumn: "parent_id",
|
||||
Resource: "resource",
|
||||
},
|
||||
templateName: countTriggerTmpl,
|
||||
},
|
||||
expects: func(mock sqlmock.Sqlmock) {
|
||||
mock.ExpectExec(regexp.QuoteMeta(expCountTriggerQuery)).
|
||||
WillReturnError(assert.AnError)
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "count trigger",
|
||||
fields: fields{
|
||||
triggerConfig: triggerConfig{
|
||||
Table: "table",
|
||||
ParentType: "instance",
|
||||
InstanceIDColumn: "instance_id",
|
||||
ParentIDColumn: "parent_id",
|
||||
Resource: "resource",
|
||||
},
|
||||
templateName: countTriggerTmpl,
|
||||
},
|
||||
expects: func(mock sqlmock.Sqlmock) {
|
||||
mock.ExpectExec(regexp.QuoteMeta(expCountTriggerQuery)).
|
||||
WithoutArgs().
|
||||
WillReturnResult(
|
||||
sqlmock.NewResult(1, 1),
|
||||
)
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "count trigger",
|
||||
fields: fields{
|
||||
triggerConfig: triggerConfig{
|
||||
Table: "table",
|
||||
ParentType: "instance",
|
||||
InstanceIDColumn: "instance_id",
|
||||
ParentIDColumn: "parent_id",
|
||||
Resource: "resource",
|
||||
},
|
||||
templateName: deleteParentCountsTmpl,
|
||||
},
|
||||
expects: func(mock sqlmock.Sqlmock) {
|
||||
mock.ExpectExec(regexp.QuoteMeta(expDeleteParentCountsQuery)).
|
||||
WithoutArgs().
|
||||
WillReturnResult(
|
||||
sqlmock.NewResult(1, 1),
|
||||
)
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
db, mock, err := sqlmock.New()
|
||||
require.NoError(t, err)
|
||||
defer func() {
|
||||
err := mock.ExpectationsWereMet()
|
||||
require.NoError(t, err)
|
||||
}()
|
||||
defer db.Close()
|
||||
tt.expects(mock)
|
||||
mock.ExpectClose()
|
||||
|
||||
m := &triggerMigration{
|
||||
db: &database.DB{
|
||||
DB: db,
|
||||
},
|
||||
triggerConfig: tt.fields.triggerConfig,
|
||||
templateName: tt.fields.templateName,
|
||||
}
|
||||
err = m.Execute(context.Background(), nil)
|
||||
if tt.wantErr {
|
||||
assert.Error(t, err)
|
||||
return
|
||||
}
|
||||
require.NoError(t, err)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_triggerConfig_Check(t *testing.T) {
|
||||
type fields struct {
|
||||
Table string
|
||||
ParentType string
|
||||
InstanceIDColumn string
|
||||
ParentIDColumn string
|
||||
Resource string
|
||||
}
|
||||
type args struct {
|
||||
lastRun map[string]any
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
fields fields
|
||||
args args
|
||||
want bool
|
||||
}{
|
||||
{
|
||||
name: "should",
|
||||
fields: fields{
|
||||
Table: "users2",
|
||||
ParentType: "instance",
|
||||
InstanceIDColumn: "instance_id",
|
||||
ParentIDColumn: "parent_id",
|
||||
Resource: "user",
|
||||
},
|
||||
args: args{
|
||||
lastRun: map[string]any{
|
||||
"table": "users1",
|
||||
"parent_type": "instance",
|
||||
"instance_id_column": "instance_id",
|
||||
"parent_id_column": "parent_id",
|
||||
"resource": "user",
|
||||
},
|
||||
},
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "should not",
|
||||
fields: fields{
|
||||
Table: "users1",
|
||||
ParentType: "instance",
|
||||
InstanceIDColumn: "instance_id",
|
||||
ParentIDColumn: "parent_id",
|
||||
Resource: "user",
|
||||
},
|
||||
args: args{
|
||||
lastRun: map[string]any{
|
||||
"table": "users1",
|
||||
"parent_type": "instance",
|
||||
"instance_id_column": "instance_id",
|
||||
"parent_id_column": "parent_id",
|
||||
"resource": "user",
|
||||
},
|
||||
},
|
||||
want: false,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
c := &triggerConfig{
|
||||
Table: tt.fields.Table,
|
||||
ParentType: tt.fields.ParentType,
|
||||
InstanceIDColumn: tt.fields.InstanceIDColumn,
|
||||
ParentIDColumn: tt.fields.ParentIDColumn,
|
||||
Resource: tt.fields.Resource,
|
||||
}
|
||||
got := c.Check(tt.args.lastRun)
|
||||
assert.Equal(t, tt.want, got)
|
||||
})
|
||||
}
|
||||
}
|
Reference in New Issue
Block a user