diff --git a/cmd/setup/config_change.go b/cmd/setup/config_change.go new file mode 100644 index 0000000000..d996b9d4d7 --- /dev/null +++ b/cmd/setup/config_change.go @@ -0,0 +1,60 @@ +package setup + +import ( + "context" + + "github.com/zitadel/zitadel/internal/command" + "github.com/zitadel/zitadel/internal/config/systemdefaults" + "github.com/zitadel/zitadel/internal/eventstore" +) + +type externalConfigChange struct { + es *eventstore.Eventstore + ExternalDomain string `json:"externalDomain"` + ExternalSecure bool `json:"externalSecure"` + ExternalPort uint16 `json:"externalPort"` + + currentExternalDomain string + currentExternalSecure bool + currentExternalPort uint16 +} + +func (mig *externalConfigChange) SetLastExecution(lastRun map[string]interface{}) { + mig.currentExternalDomain, _ = lastRun["externalDomain"].(string) + externalPort, _ := lastRun["externalPort"].(float64) + mig.currentExternalPort = uint16(externalPort) + mig.currentExternalSecure, _ = lastRun["externalSecure"].(bool) +} + +func (mig *externalConfigChange) Check() bool { + return mig.currentExternalSecure != mig.ExternalSecure || + mig.currentExternalPort != mig.ExternalPort || + mig.currentExternalDomain != mig.ExternalDomain +} + +func (mig *externalConfigChange) Execute(ctx context.Context) error { + cmd, err := command.StartCommands(mig.es, + systemdefaults.SystemDefaults{}, + nil, + nil, + nil, + mig.ExternalDomain, + mig.ExternalSecure, + mig.ExternalPort, + nil, + nil, + nil, + nil, + nil, + nil, + nil) + + if err != nil { + return err + } + return cmd.ChangeSystemConfig(ctx, mig.currentExternalDomain, mig.currentExternalPort, mig.currentExternalSecure) +} + +func (mig *externalConfigChange) String() string { + return "config_change" +} diff --git a/cmd/setup/setup.go b/cmd/setup/setup.go index 5a92329230..e562d28e93 100644 --- a/cmd/setup/setup.go +++ b/cmd/setup/setup.go @@ -54,6 +54,8 @@ func Flags(cmd *cobra.Command) { } func Setup(config *Config, steps *Steps, masterKey string) { + logging.Info("setup started") + dbClient, err := database.Connect(config.Database) logging.OnError(err).Fatal("unable to connect to database") @@ -76,6 +78,15 @@ func Setup(config *Config, steps *Steps, masterKey string) { steps.S3DefaultInstance.externalSecure = config.ExternalSecure steps.S3DefaultInstance.externalPort = config.ExternalPort + repeatableSteps := []migration.RepeatableMigration{ + &externalConfigChange{ + es: eventstoreClient, + ExternalDomain: config.ExternalDomain, + ExternalPort: config.ExternalPort, + ExternalSecure: config.ExternalSecure, + }, + } + ctx := context.Background() err = migration.Migrate(ctx, eventstoreClient, steps.s1ProjectionTable) logging.OnError(err).Fatal("unable to migrate step 1") @@ -83,14 +94,9 @@ func Setup(config *Config, steps *Steps, masterKey string) { logging.OnError(err).Fatal("unable to migrate step 2") err = migration.Migrate(ctx, eventstoreClient, steps.S3DefaultInstance) logging.OnError(err).Fatal("unable to migrate step 3") -} -func initSteps(v *viper.Viper, files ...string) func() { - return func() { - for _, file := range files { - v.SetConfigFile(file) - err := v.MergeInConfig() - logging.WithFields("file", file).OnError(err).Warn("unable to read setup file") - } + for _, repeatableStep := range repeatableSteps { + err = migration.Migrate(ctx, eventstoreClient, repeatableStep) + logging.OnError(err).Fatalf("unable to migrate repeatable step: %s", repeatableStep.String()) } } diff --git a/internal/command/instance.go b/internal/command/instance.go index 4db7158d6f..f99e68d12b 100644 --- a/internal/command/instance.go +++ b/internal/command/instance.go @@ -398,6 +398,41 @@ func (c *Commands) SetDefaultOrg(ctx context.Context, orgID string) (*domain.Obj }, nil } +func (c *Commands) ChangeSystemConfig(ctx context.Context, externalDomain string, externalPort uint16, externalSecure bool) error { + validations, err := c.prepareChangeSystemConfig(externalDomain, externalPort, externalSecure)(ctx, c.eventstore.Filter) + if err != nil { + return err + } + for instanceID, instanceValidations := range validations { + if len(instanceValidations.Validations) == 0 { + continue + } + ctx := authz.WithConsole(authz.WithInstanceID(ctx, instanceID), instanceValidations.ProjectID, instanceValidations.ConsoleAppID) + cmds, err := preparation.PrepareCommands(ctx, c.eventstore.Filter, instanceValidations.Validations...) + if err != nil { + return err + } + _, err = c.eventstore.Push(ctx, cmds...) + if err != nil { + return err + } + } + return nil +} + +func (c *Commands) prepareChangeSystemConfig(externalDomain string, externalPort uint16, externalSecure bool) func(ctx context.Context, filter preparation.FilterToQueryReducer) (map[string]*SystemConfigChangesValidation, error) { + return func(ctx context.Context, filter preparation.FilterToQueryReducer) (map[string]*SystemConfigChangesValidation, error) { + if externalDomain == "" || externalPort == 0 { + return nil, nil + } + writeModel, err := getSystemConfigWriteModel(ctx, filter, externalDomain, c.externalDomain, externalPort, c.externalPort, externalSecure, c.externalSecure) + if err != nil { + return nil, err + } + return writeModel.NewChangedEvents(c), nil + } +} + func prepareAddInstance(a *instance.Aggregate, instanceName string, defaultLanguage language.Tag) preparation.Validation { return func() (preparation.CreateCommands, error) { return func(ctx context.Context, filter preparation.FilterToQueryReducer) ([]eventstore.Command, error) { @@ -494,3 +529,17 @@ func getInstanceWriteModel(ctx context.Context, filter preparation.FilterToQuery err = writeModel.Reduce() return writeModel, err } + +func getSystemConfigWriteModel(ctx context.Context, filter preparation.FilterToQueryReducer, externalDomain, newExternalDomain string, externalPort, newExternalPort uint16, externalSecure, newExternalSecure bool) (*SystemConfigWriteModel, error) { + writeModel := NewSystemConfigWriteModel(externalDomain, newExternalDomain, externalPort, newExternalPort, externalSecure, newExternalSecure) + events, err := filter(ctx, writeModel.Query()) + if err != nil { + return nil, err + } + if len(events) == 0 { + return writeModel, nil + } + writeModel.AppendEvents(events...) + err = writeModel.Reduce() + return writeModel, err +} diff --git a/internal/command/instance_domain.go b/internal/command/instance_domain.go index 80fd2e53cd..a0db372c1e 100644 --- a/internal/command/instance_domain.go +++ b/internal/command/instance_domain.go @@ -95,33 +95,68 @@ func (c *Commands) addInstanceDomain(a *instance.Aggregate, instanceDomain strin events := []eventstore.Command{ instance.NewDomainAddedEvent(ctx, &a.Aggregate, instanceDomain, generated), } - appWriteModel, err := getOIDCAppWriteModel(ctx, filter, authz.GetInstance(ctx).ProjectID(), authz.GetInstance(ctx).ConsoleApplicationID(), "") + consoleChangeEvent, err := c.updateConsoleRedirectURIs(ctx, filter, instanceDomain) if err != nil { return nil, err } - if appWriteModel.State.Exists() { - redirectUrls := append(appWriteModel.RedirectUris, http.BuildHTTP(instanceDomain, c.externalPort, c.externalSecure)+consoleRedirectPath) - logoutUrls := append(appWriteModel.PostLogoutRedirectUris, http.BuildHTTP(instanceDomain, c.externalPort, c.externalSecure)+consolePostLogoutPath) - consoleChangeEvent, err := project.NewOIDCConfigChangedEvent( - ctx, - ProjectAggregateFromWriteModel(&appWriteModel.WriteModel), - appWriteModel.AppID, - []project.OIDCConfigChanges{ - project.ChangeRedirectURIs(redirectUrls), - project.ChangePostLogoutRedirectURIs(logoutUrls), - }, - ) - if err != nil { - return nil, err - } - events = append(events, consoleChangeEvent) - } - - return events, nil + return append(events, consoleChangeEvent), nil }, nil } } +func (c *Commands) prepareUpdateConsoleRedirectURIs(instanceDomain string) preparation.Validation { + return func() (preparation.CreateCommands, error) { + if instanceDomain = strings.TrimSpace(instanceDomain); instanceDomain == "" { + return nil, errors.ThrowInvalidArgument(nil, "INST-E3j3s", "Errors.Invalid.Argument") + } + return func(ctx context.Context, filter preparation.FilterToQueryReducer) ([]eventstore.Command, error) { + consoleChangeEvent, err := c.updateConsoleRedirectURIs(ctx, filter, instanceDomain) + if err != nil { + return nil, err + } + return []eventstore.Command{ + consoleChangeEvent, + }, nil + }, nil + } +} + +func (c *Commands) updateConsoleRedirectURIs(ctx context.Context, filter preparation.FilterToQueryReducer, instanceDomain string) (*project.OIDCConfigChangedEvent, error) { + appWriteModel, err := getOIDCAppWriteModel(ctx, filter, authz.GetInstance(ctx).ProjectID(), authz.GetInstance(ctx).ConsoleApplicationID(), "") + if err != nil { + return nil, err + } + if !appWriteModel.State.Exists() { + return nil, nil + } + redirectURI := http.BuildHTTP(instanceDomain, c.externalPort, c.externalSecure) + consoleRedirectPath + changes := make([]project.OIDCConfigChanges, 0, 2) + if !containsURI(appWriteModel.RedirectUris, redirectURI) { + changes = append(changes, project.ChangeRedirectURIs(append(appWriteModel.RedirectUris, redirectURI))) + } + postLogoutRedirectURI := http.BuildHTTP(instanceDomain, c.externalPort, c.externalSecure) + consolePostLogoutPath + if !containsURI(appWriteModel.PostLogoutRedirectUris, postLogoutRedirectURI) { + changes = append(changes, project.ChangePostLogoutRedirectURIs(append(appWriteModel.PostLogoutRedirectUris, postLogoutRedirectURI))) + } + return project.NewOIDCConfigChangedEvent( + ctx, + ProjectAggregateFromWriteModel(&appWriteModel.WriteModel), + appWriteModel.AppID, + changes, + ) +} + +//checkUpdateConsoleRedirectURIs validates if the required console uri is present in the redirect_uris and post_logout_redirect_uris +//it will return true only if present in both list, otherwise false +func (c *Commands) checkUpdateConsoleRedirectURIs(instanceDomain string, redirectURIs, postLogoutRedirectURIs []string) bool { + redirectURI := http.BuildHTTP(instanceDomain, c.externalPort, c.externalSecure) + consoleRedirectPath + if !containsURI(redirectURIs, redirectURI) { + return false + } + postLogoutRedirectURI := http.BuildHTTP(instanceDomain, c.externalPort, c.externalSecure) + consolePostLogoutPath + return containsURI(postLogoutRedirectURIs, postLogoutRedirectURI) +} + func setPrimaryInstanceDomain(a *instance.Aggregate, instanceDomain string) preparation.Validation { return func() (preparation.CreateCommands, error) { if instanceDomain = strings.TrimSpace(instanceDomain); instanceDomain == "" { @@ -174,3 +209,12 @@ func getInstanceDomainWriteModel(ctx context.Context, filter preparation.FilterT err = domainWriteModel.Reduce() return domainWriteModel, err } + +func containsURI(uris []string, uri string) bool { + for _, u := range uris { + if u == uri { + return true + } + } + return false +} diff --git a/internal/command/system_model.go b/internal/command/system_model.go new file mode 100644 index 0000000000..9ed0684f6b --- /dev/null +++ b/internal/command/system_model.go @@ -0,0 +1,184 @@ +package command + +import ( + "strings" + + "github.com/zitadel/zitadel/internal/command/preparation" + "github.com/zitadel/zitadel/internal/eventstore" + "github.com/zitadel/zitadel/internal/repository/instance" + "github.com/zitadel/zitadel/internal/repository/project" +) + +type SystemConfigWriteModel struct { + eventstore.WriteModel + + Instances map[string]*systemConfigChangesInstanceModel + externalDomain string + externalPort uint16 + externalSecure bool + newExternalDomain string + newExternalPort uint16 + newExternalSecure bool +} + +type systemConfigChangesInstanceModel struct { + Domains []string + GeneratedDomain string + ProjectID string + ConsoleAppID string + RedirectUris []string + PostLogoutRedirectUris []string +} + +func NewSystemConfigWriteModel(externalDomain, newExternalDomain string, externalPort, newExternalPort uint16, externalSecure, newExternalSecure bool) *SystemConfigWriteModel { + return &SystemConfigWriteModel{ + WriteModel: eventstore.WriteModel{}, + Instances: make(map[string]*systemConfigChangesInstanceModel), + externalDomain: externalDomain, + externalPort: externalPort, + externalSecure: externalSecure, + newExternalDomain: newExternalDomain, + newExternalPort: newExternalPort, + newExternalSecure: newExternalSecure, + } +} + +func (wm *SystemConfigWriteModel) Reduce() error { + for _, event := range wm.Events { + switch e := event.(type) { + case *instance.InstanceAddedEvent: + wm.Instances[e.Aggregate().InstanceID] = &systemConfigChangesInstanceModel{} + case *instance.InstanceRemovedEvent: + delete(wm.Instances, e.Aggregate().InstanceID) + case *instance.DomainAddedEvent: + if !e.Generated && e.Domain != wm.externalDomain && e.Domain != wm.newExternalDomain { + continue + } + if e.Generated && !strings.HasSuffix(e.Domain, wm.externalDomain) && !strings.HasSuffix(e.Domain, wm.newExternalDomain) { + continue + } + wm.Instances[e.Aggregate().InstanceID].Domains = append(wm.Instances[e.Aggregate().InstanceID].Domains, e.Domain) + case *instance.DomainRemovedEvent: + domains := wm.Instances[e.Aggregate().InstanceID].Domains + for i, domain := range domains { + if domain == e.Domain { + domains[i] = domains[len(domains)-1] + domains[len(domains)-1] = "" + wm.Instances[e.Aggregate().InstanceID].Domains = domains[:len(domains)-1] + break + } + } + case *instance.ProjectSetEvent: + wm.Instances[e.Aggregate().InstanceID].ProjectID = e.ProjectID + case *instance.ConsoleSetEvent: + wm.Instances[e.Aggregate().InstanceID].ConsoleAppID = e.AppID + case *project.OIDCConfigAddedEvent: + if wm.Instances[e.Aggregate().InstanceID].ConsoleAppID != e.AppID { + continue + } + wm.Instances[e.Aggregate().InstanceID].RedirectUris = e.RedirectUris + wm.Instances[e.Aggregate().InstanceID].PostLogoutRedirectUris = e.PostLogoutRedirectUris + case *project.OIDCConfigChangedEvent: + if wm.Instances[e.Aggregate().InstanceID].ConsoleAppID != e.AppID { + continue + } + if e.RedirectUris != nil { + wm.Instances[e.Aggregate().InstanceID].RedirectUris = *e.RedirectUris + } + if e.PostLogoutRedirectUris != nil { + wm.Instances[e.Aggregate().InstanceID].PostLogoutRedirectUris = *e.PostLogoutRedirectUris + } + } + } + return nil +} + +func (wm *SystemConfigWriteModel) Query() *eventstore.SearchQueryBuilder { + return eventstore.NewSearchQueryBuilder(eventstore.ColumnsEvent). + AddQuery(). + AggregateTypes(instance.AggregateType). + EventTypes( + instance.InstanceAddedEventType, + instance.InstanceRemovedEventType, + instance.InstanceDomainAddedEventType, + instance.InstanceDomainRemovedEventType, + instance.ProjectSetEventType, + instance.ConsoleSetEventType, + ). + Or(). + AggregateTypes(project.AggregateType). + EventTypes( + project.OIDCConfigAddedType, + project.OIDCConfigChangedType, + ). + Builder() +} + +type SystemConfigChangesValidation struct { + ProjectID string + ConsoleAppID string + Validations []preparation.Validation + InstanceID string +} + +func (wm *SystemConfigWriteModel) NewChangedEvents(commands *Commands) map[string]*SystemConfigChangesValidation { + var newCustomDomainExists, isInstanceOfCustomDomain bool + var instanceOfCustomDomain string + cmds := make(map[string]*SystemConfigChangesValidation) + for i, inst := range wm.Instances { + cmds[i] = &SystemConfigChangesValidation{ + InstanceID: i, + ProjectID: inst.ProjectID, + ConsoleAppID: inst.ConsoleAppID, + } + //check each instance separately for changes (using the generated domain) and check if there's an existing custom domain + newCustomDomainExists, isInstanceOfCustomDomain = wm.changeConfig(cmds[i], inst, commands) + if isInstanceOfCustomDomain || newCustomDomainExists { + instanceOfCustomDomain = i + } + } + //handle the custom domain at last + if newCustomDomainExists { + //if the domain itself already exists, then only check if the uris of the console app exist as well + wm.changeURIs(cmds[instanceOfCustomDomain], wm.Instances[instanceOfCustomDomain], commands, wm.newExternalDomain) + return cmds + } + //otherwise the add instance domain will take care of the uris + cmds[instanceOfCustomDomain].Validations = append(cmds[instanceOfCustomDomain].Validations, commands.addInstanceDomain(instance.NewAggregate(instanceOfCustomDomain), wm.newExternalDomain, false)) + return cmds +} + +func (wm *SystemConfigWriteModel) changeConfig(validation *SystemConfigChangesValidation, inst *systemConfigChangesInstanceModel, commands *Commands) (newCustomDomainExists, isInstanceOfCustomDomain bool) { + var newGeneratedDomain string + var newGeneratedDomainExists bool + for _, domain := range inst.Domains { + if domain == wm.newExternalDomain { + newCustomDomainExists = true + continue + } + if domain != wm.newExternalDomain && strings.HasSuffix(domain, wm.newExternalDomain) { + newGeneratedDomainExists = true + } + if !newCustomDomainExists && domain == wm.externalDomain { + isInstanceOfCustomDomain = true + } + if domain != wm.externalDomain && strings.HasSuffix(domain, wm.externalDomain) { + newGeneratedDomain = strings.TrimSuffix(domain, wm.externalDomain) + wm.newExternalDomain + } + } + if newGeneratedDomainExists { + //if the domain itself already exists, then only check if the uris of the console app exist as well + wm.changeURIs(validation, inst, commands, newGeneratedDomain) + return newCustomDomainExists, isInstanceOfCustomDomain + } + //otherwise the add instance domain will take care of the uris + validation.Validations = append(validation.Validations, commands.addInstanceDomain(instance.NewAggregate(validation.InstanceID), newGeneratedDomain, true)) + return newCustomDomainExists, isInstanceOfCustomDomain +} + +func (wm *SystemConfigWriteModel) changeURIs(validation *SystemConfigChangesValidation, inst *systemConfigChangesInstanceModel, commands *Commands, domain string) { + if commands.checkUpdateConsoleRedirectURIs(domain, inst.RedirectUris, inst.PostLogoutRedirectUris) { + return + } + validation.Validations = append(validation.Validations, commands.prepareUpdateConsoleRedirectURIs(domain)) +} diff --git a/internal/migration/command.go b/internal/migration/command.go index 16ac4eb188..5e43b35b37 100644 --- a/internal/migration/command.go +++ b/internal/migration/command.go @@ -15,20 +15,23 @@ import ( type SetupStep struct { eventstore.BaseEvent `json:"-"` migration Migration - Name string `json:"name"` - Error error `json:"error,omitempty"` + Name string `json:"name"` + Error error `json:"error,omitempty"` + LastRun interface{} `json:"lastRun,omitempty"` } func (s *SetupStep) UnmarshalJSON(data []byte) error { fields := struct { - Name string `json:"name,"` - Error *errors.CaosError `json:"error"` + Name string `json:"name,"` + Error *errors.CaosError `json:"error"` + LastRun map[string]interface{} `json:"lastRun,omitempty"` }{} if err := json.Unmarshal(data, &fields); err != nil { return err } s.Name = fields.Name s.Error = fields.Error + s.LastRun = fields.LastRun return nil } @@ -46,15 +49,21 @@ func setupStartedCmd(migration Migration) eventstore.Command { func setupDoneCmd(migration Migration, err error) eventstore.Command { ctx := authz.SetCtxData(service.WithService(context.Background(), "system"), authz.CtxData{UserID: "system", OrgID: "SYSTEM", ResourceOwner: "SYSTEM"}) + typ := doneType + var lastRun interface{} + if repeatable, ok := migration.(RepeatableMigration); ok { + typ = repeatableDoneType + lastRun = repeatable + } + if err != nil { + typ = failedType + } + s := &SetupStep{ migration: migration, Name: migration.String(), Error: err, - } - - typ := doneType - if err != nil { - typ = failedType + LastRun: lastRun, } s.BaseEvent = *eventstore.NewBaseEventForPush( @@ -75,7 +84,8 @@ func (s *SetupStep) UniqueConstraints() []*eventstore.EventUniqueConstraint { return []*eventstore.EventUniqueConstraint{ eventstore.NewAddGlobalEventUniqueConstraint("migration_started", s.migration.String(), "Errors.Step.Started.AlreadyExists"), } - case failedType: + case failedType, + repeatableDoneType: return []*eventstore.EventUniqueConstraint{ eventstore.NewRemoveGlobalEventUniqueConstraint("migration_started", s.migration.String()), } @@ -90,6 +100,7 @@ func RegisterMappers(es *eventstore.Eventstore) { es.RegisterFilterEventMapper(startedType, SetupMapper) es.RegisterFilterEventMapper(doneType, SetupMapper) es.RegisterFilterEventMapper(failedType, SetupMapper) + es.RegisterFilterEventMapper(repeatableDoneType, SetupMapper) } func SetupMapper(event *repository.Event) (eventstore.Event, error) { diff --git a/internal/migration/migration.go b/internal/migration/migration.go index 0dc5d14564..2f27a3715c 100644 --- a/internal/migration/migration.go +++ b/internal/migration/migration.go @@ -10,11 +10,12 @@ import ( ) const ( - startedType = eventstore.EventType("system.migration.started") - doneType = eventstore.EventType("system.migration.done") - failedType = eventstore.EventType("system.migration.failed") - aggregateType = eventstore.AggregateType("system") - aggregateID = "SYSTEM" + startedType = eventstore.EventType("system.migration.started") + doneType = eventstore.EventType("system.migration.done") + failedType = eventstore.EventType("system.migration.failed") + repeatableDoneType = eventstore.EventType("system.migration.repeatable.done") + aggregateType = eventstore.AggregateType("system") + aggregateID = "SYSTEM" ) type Migration interface { @@ -22,7 +23,15 @@ type Migration interface { Execute(context.Context) error } +type RepeatableMigration interface { + Migration + SetLastExecution(lastRun map[string]interface{}) + Check() bool +} + func Migrate(ctx context.Context, es *eventstore.Eventstore, migration Migration) (err error) { + logging.Infof("verify migration %s", migration.String()) + if should, err := shouldExec(ctx, es, migration); !should || err != nil { return err } @@ -31,6 +40,7 @@ func Migrate(ctx context.Context, es *eventstore.Eventstore, migration Migration return err } + logging.Infof("starting migration %s", migration.String()) err = migration.Execute(ctx) logging.OnError(err).Error("migration failed") @@ -48,7 +58,7 @@ func shouldExec(ctx context.Context, es *eventstore.Eventstore, migration Migrat AddQuery(). AggregateTypes(aggregateType). AggregateIDs(aggregateID). - EventTypes(startedType, doneType, failedType). + EventTypes(startedType, doneType, repeatableDoneType, failedType). Builder()) if err != nil { return false, err @@ -68,10 +78,23 @@ func shouldExec(ctx context.Context, es *eventstore.Eventstore, migration Migrat switch event.Type() { case startedType, failedType: isStarted = !isStarted - case doneType: - return false, nil + case doneType, + repeatableDoneType: + repeatable, ok := migration.(RepeatableMigration) + if !ok { + return false, nil + } + isStarted = false + repeatable.SetLastExecution(e.LastRun.(map[string]interface{})) } } - return !isStarted, nil + if isStarted { + return false, nil + } + repeatable, ok := migration.(RepeatableMigration) + if !ok { + return true, nil + } + return repeatable.Check(), nil }