From 42a5c6147ef12d876965ebe91bf82f5e1a0d8ee9 Mon Sep 17 00:00:00 2001 From: Elio Bischof Date: Thu, 26 Sep 2024 10:31:43 +0200 Subject: [PATCH] mirror error handling --- cmd/initialise/init.go | 5 +- cmd/mirror/auth.go | 32 ++++-- cmd/mirror/event_store.go | 122 +++++++++++++--------- cmd/mirror/mirror.go | 62 +++++++---- cmd/mirror/projections.go | 66 +++++++----- cmd/mirror/system.go | 62 ++++++++--- cmd/mirror/verify.go | 43 +++++--- cmd/ready/ready.go | 8 +- cmd/setup/setup.go | 4 +- cmd/start/start.go | 12 +-- internal/{socket => unixsocket}/listen.go | 2 +- internal/{socket => unixsocket}/socket.go | 2 +- 12 files changed, 272 insertions(+), 148 deletions(-) rename internal/{socket => unixsocket}/listen.go (98%) rename internal/{socket => unixsocket}/socket.go (98%) diff --git a/cmd/initialise/init.go b/cmd/initialise/init.go index 5e2dccc098..1aa31ede75 100644 --- a/cmd/initialise/init.go +++ b/cmd/initialise/init.go @@ -6,10 +6,9 @@ import ( "github.com/spf13/cobra" "github.com/spf13/viper" "github.com/zitadel/logging" - "github.com/zitadel/zitadel/internal/socket" - "github.com/zitadel/zitadel/internal/database" "github.com/zitadel/zitadel/internal/database/dialect" + "github.com/zitadel/zitadel/internal/unixsocket" ) var ( @@ -72,7 +71,7 @@ func InitAll(ctx context.Context, config *Config) { } func initialise(config database.Config, steps ...func(*database.DB) error) error { - closeSocket, err := socket.ListenAndIgnore() + closeSocket, err := unixsocket.ListenAndIgnore() logging.OnError(err).Fatal("unable to listen on socket") defer closeSocket() diff --git a/cmd/mirror/auth.go b/cmd/mirror/auth.go index df94708e71..9870ab9af0 100644 --- a/cmd/mirror/auth.go +++ b/cmd/mirror/auth.go @@ -3,6 +3,7 @@ package mirror import ( "context" _ "embed" + "fmt" "io" "time" @@ -33,23 +34,29 @@ Only auth requests are mirrored`, return cmd } -func copyAuth(ctx context.Context, config *Migration) { +func copyAuth(ctx context.Context, config *Migration) error { sourceClient, err := database.Connect(config.Source, false, dialect.DBPurposeQuery) - logging.OnError(err).Fatal("unable to connect to source database") + if err != nil { + return fmt.Errorf("unable to connect to source database: %w", err) + } defer sourceClient.Close() destClient, err := database.Connect(config.Destination, false, dialect.DBPurposeEventPusher) - logging.OnError(err).Fatal("unable to connect to destination database") + if err != nil { + return fmt.Errorf("unable to connect to destination database: %w", err) + } defer destClient.Close() - copyAuthRequests(ctx, sourceClient, destClient) + return copyAuthRequests(ctx, sourceClient, destClient) } -func copyAuthRequests(ctx context.Context, source, dest *database.DB) { +func copyAuthRequests(ctx context.Context, source, dest *database.DB) error { start := time.Now() sourceConn, err := source.Conn(ctx) - logging.OnError(err).Fatal("unable to acquire connection") + if err != nil { + return fmt.Errorf("unable to acquire connection: %w", err) + } defer sourceConn.Close() r, w := io.Pipe() @@ -66,7 +73,9 @@ func copyAuthRequests(ctx context.Context, source, dest *database.DB) { }() destConn, err := dest.Conn(ctx) - logging.OnError(err).Fatal("unable to acquire connection") + if err != nil { + return fmt.Errorf("unable to acquire connection: %w", err) + } defer destConn.Close() var affected int64 @@ -85,7 +94,12 @@ func copyAuthRequests(ctx context.Context, source, dest *database.DB) { return err }) - logging.OnError(err).Fatal("unable to copy auth requests to destination") - logging.OnError(<-errs).Fatal("unable to copy auth requests from source") + if err != nil { + return fmt.Errorf("unable to copy auth requests to destination: %w", err) + } + if err = <-errs; err != nil { + return fmt.Errorf("unable to copy auth requests from source: %w", err) + } logging.WithFields("took", time.Since(start), "count", affected).Info("auth requests migrated") + return nil } diff --git a/cmd/mirror/event_store.go b/cmd/mirror/event_store.go index 2eab4eb0da..7067c04e49 100644 --- a/cmd/mirror/event_store.go +++ b/cmd/mirror/event_store.go @@ -5,6 +5,7 @@ import ( "database/sql" _ "embed" "errors" + "fmt" "io" "time" @@ -32,9 +33,9 @@ func eventstoreCmd() *cobra.Command { Long: `mirrors the eventstore of an instance from one database to another ZITADEL needs to be initialized and set up with the --for-mirror flag Migrate only copies events2 and unique constraints`, - Run: func(cmd *cobra.Command, args []string) { + RunE: func(cmd *cobra.Command, args []string) error { config := mustNewMigrationConfig(viper.GetViper()) - copyEventstore(cmd.Context(), config) + return copyEventstore(cmd.Context(), config) }, } @@ -44,17 +45,26 @@ Migrate only copies events2 and unique constraints`, return cmd } -func copyEventstore(ctx context.Context, config *Migration) { +func copyEventstore(ctx context.Context, config *Migration) error { sourceClient, err := db.Connect(config.Source, false, dialect.DBPurposeEventPusher) - logging.OnError(err).Fatal("unable to connect to source database") + if err != nil { + return fmt.Errorf("unable to connect to source database: %w", err) + } defer sourceClient.Close() destClient, err := db.Connect(config.Destination, false, dialect.DBPurposeEventPusher) - logging.OnError(err).Fatal("unable to connect to destination database") + if err != nil { + return fmt.Errorf("unable to connect to destination database: %w", err) + } defer destClient.Close() - copyEvents(ctx, sourceClient, destClient, config.EventBulkSize) - copyUniqueConstraints(ctx, sourceClient, destClient) + if err = copyEvents(ctx, sourceClient, destClient, config.EventBulkSize); err != nil { + return fmt.Errorf("unable to copy events: %w", err) + } + if err = copyUniqueConstraints(ctx, sourceClient, destClient); err != nil { + return fmt.Errorf("unable to copy unique constraints: %w", err) + } + return nil } func positionQuery(db *db.DB) string { @@ -69,19 +79,21 @@ func positionQuery(db *db.DB) string { } } -func copyEvents(ctx context.Context, source, dest *db.DB, bulkSize uint32) { +func copyEvents(ctx context.Context, source, dest *db.DB, bulkSize uint32) error { start := time.Now() reader, writer := io.Pipe() - migrationID, err := id.SonyFlakeGenerator().Next() - logging.OnError(err).Fatal("unable to generate migration id") - + if err != nil { + return fmt.Errorf("unable to generate migration id: %w", err) + } sourceConn, err := source.Conn(ctx) - logging.OnError(err).Fatal("unable to acquire source connection") - + if err != nil { + return fmt.Errorf("unable to acquire source connection: %w", err) + } destConn, err := dest.Conn(ctx) - logging.OnError(err).Fatal("unable to acquire dest connection") - + if err != nil { + return fmt.Errorf("unable to acquire dest connection: %w", err) + } sourceES := eventstore.NewEventstoreFromOne(postgres.New(source, &postgres.Config{ MaxRetries: 3, })) @@ -90,11 +102,14 @@ func copyEvents(ctx context.Context, source, dest *db.DB, bulkSize uint32) { })) previousMigration, err := queryLastSuccessfulMigration(ctx, destinationES, source.DatabaseName()) - logging.OnError(err).Fatal("unable to query latest successful migration") - + if err != nil { + return fmt.Errorf("unable to query latest successful migration: %w", err) + } maxPosition, err := writeMigrationStart(ctx, sourceES, migrationID, dest.DatabaseName()) - logging.OnError(err).Fatal("unable to write migration started event") + if err != nil { + return fmt.Errorf("unable to write migration start: %w", err) + } logging.WithFields("from", previousMigration.Position, "to", maxPosition).Info("start event migration") nextPos := make(chan bool, 1) @@ -102,7 +117,7 @@ func copyEvents(ctx context.Context, source, dest *db.DB, bulkSize uint32) { errs := make(chan error, 3) go func() { - err := sourceConn.Raw(func(driverConn interface{}) error { + goErr := sourceConn.Raw(func(driverConn interface{}) error { conn := driverConn.(*stdlib.Conn).Conn() nextPos <- true var i uint32 @@ -139,7 +154,7 @@ func copyEvents(ctx context.Context, source, dest *db.DB, bulkSize uint32) { }) writer.Close() close(nextPos) - errs <- err + errs <- goErr }() // generate next position for @@ -147,41 +162,41 @@ func copyEvents(ctx context.Context, source, dest *db.DB, bulkSize uint32) { defer close(pos) for range nextPos { var position float64 - err := dest.QueryRowContext( + goErr := dest.QueryRowContext( ctx, func(row *sql.Row) error { return row.Scan(&position) }, positionQuery(dest), ) - if err != nil { + if goErr != nil { errs <- zerrors.ThrowUnknown(err, "MIGRA-kMyPH", "unable to query next position") return } pos <- position } }() - var eventCount int64 errs <- destConn.Raw(func(driverConn interface{}) error { conn := driverConn.(*stdlib.Conn).Conn() - tag, err := conn.PgConn().CopyFrom(ctx, reader, "COPY eventstore.events2 FROM STDIN") + tag, cbErr := conn.PgConn().CopyFrom(ctx, reader, "COPY eventstore.events2 FROM STDIN") eventCount = tag.RowsAffected() - if err != nil { - return zerrors.ThrowUnknown(err, "MIGRA-DTHi7", "unable to copy events into destination") + if cbErr != nil { + return zerrors.ThrowUnknown(cbErr, "MIGRA-DTHi7", "unable to copy events into destination") } return nil }) - close(errs) - writeCopyEventsDone(ctx, destinationES, migrationID, source.DatabaseName(), maxPosition, errs) - + if err = writeCopyEventsDone(ctx, destinationES, migrationID, source.DatabaseName(), maxPosition, errs); err != nil { + return fmt.Errorf("unable to write migration done: %w", err) + } logging.WithFields("took", time.Since(start), "count", eventCount).Info("events migrated") + return nil } -func writeCopyEventsDone(ctx context.Context, es *eventstore.EventStore, id, source string, position decimal.Decimal, errs <-chan error) { +func writeCopyEventsDone(ctx context.Context, es *eventstore.EventStore, id, source string, position decimal.Decimal, errs <-chan error) error { joinedErrs := make([]error, 0, len(errs)) for err := range errs { joinedErrs = append(joinedErrs, err) @@ -190,25 +205,31 @@ func writeCopyEventsDone(ctx context.Context, es *eventstore.EventStore, id, sou if err != nil { logging.WithError(err).Error("unable to mirror events") - err := writeMigrationFailed(ctx, es, id, source, err) - logging.OnError(err).Fatal("unable to write failed event") - return + err = writeMigrationFailed(ctx, es, id, source, err) + if err != nil { + return fmt.Errorf("unable to write failed event: %w", err) + } + return nil } - err = writeMigrationSucceeded(ctx, es, id, source, position) - logging.OnError(err).Fatal("unable to write failed event") + if err = writeMigrationSucceeded(ctx, es, id, source, position); err != nil { + return fmt.Errorf("unable to write succeeded event: %w", err) + } + return nil } -func copyUniqueConstraints(ctx context.Context, source, dest *db.DB) { +func copyUniqueConstraints(ctx context.Context, source, dest *db.DB) error { start := time.Now() reader, writer := io.Pipe() errs := make(chan error, 1) sourceConn, err := source.Conn(ctx) - logging.OnError(err).Fatal("unable to acquire source connection") + if err != nil { + return fmt.Errorf("unable to acquire source connection: %w", err) + } go func() { - err := sourceConn.Raw(func(driverConn interface{}) error { + errs <- sourceConn.Raw(func(driverConn interface{}) error { conn := driverConn.(*stdlib.Conn).Conn() var stmt database.Statement stmt.WriteString("COPY (SELECT instance_id, unique_type, unique_field FROM eventstore.unique_constraints ") @@ -219,11 +240,12 @@ func copyUniqueConstraints(ctx context.Context, source, dest *db.DB) { writer.Close() return err }) - errs <- err }() destConn, err := dest.Conn(ctx) - logging.OnError(err).Fatal("unable to acquire dest connection") + if err != nil { + return fmt.Errorf("unable to acquire dest connection: %w", err) + } var eventCount int64 err = destConn.Raw(func(driverConn interface{}) error { @@ -234,18 +256,22 @@ func copyUniqueConstraints(ctx context.Context, source, dest *db.DB) { stmt.WriteString("DELETE FROM eventstore.unique_constraints ") stmt.WriteString(instanceClause()) - _, err := conn.Exec(ctx, stmt.String()) - if err != nil { - return err + _, cbErr := conn.Exec(ctx, stmt.String()) + if cbErr != nil { + return cbErr } } - tag, err := conn.PgConn().CopyFrom(ctx, reader, "COPY eventstore.unique_constraints FROM stdin") + tag, cbErr := conn.PgConn().CopyFrom(ctx, reader, "COPY eventstore.unique_constraints FROM stdin") eventCount = tag.RowsAffected() - - return err + return cbErr }) - logging.OnError(err).Fatal("unable to copy unique constraints to destination") - logging.OnError(<-errs).Fatal("unable to copy unique constraints from source") + if err != nil { + return fmt.Errorf("unable to copy unique constraints to destination: %w", err) + } + if err = <-errs; err != nil { + return fmt.Errorf("unable to copy unique constraints from source: %w", err) + } logging.WithFields("took", time.Since(start), "count", eventCount).Info("unique constraints migrated") + return nil } diff --git a/cmd/mirror/mirror.go b/cmd/mirror/mirror.go index 3fbfe1ae94..6cf78bab65 100644 --- a/cmd/mirror/mirror.go +++ b/cmd/mirror/mirror.go @@ -3,12 +3,13 @@ package mirror import ( "bytes" _ "embed" + "errors" + "fmt" + "github.com/zitadel/zitadel/internal/unixsocket" "strings" "github.com/spf13/cobra" "github.com/spf13/viper" - "github.com/zitadel/logging" - "github.com/zitadel/zitadel/cmd/key" ) @@ -19,6 +20,7 @@ var ( ) func New(configFiles *[]string) *cobra.Command { + closeSocket := func() error { return nil } cmd := &cobra.Command{ Use: "mirror", Short: "mirrors all data of ZITADEL from one database to another", @@ -34,29 +36,49 @@ Order of execution: 3. mirror event store tables 4. recompute projections 5. verify`, - PersistentPreRun: func(cmd *cobra.Command, args []string) { - err := viper.MergeConfig(bytes.NewBuffer(defaultConfig)) - logging.OnError(err).Fatal("unable to read default config") + PersistentPreRunE: func(cmd *cobra.Command, args []string) (err error) { + closeSocket, err = unixsocket.ListenAndIgnore() + if err != nil { + return fmt.Errorf("unable to listen on socket: %w", err) + } + if err = viper.MergeConfig(bytes.NewBuffer(defaultConfig)); err != nil { + return errors.New("unable to read default config") + } for _, file := range *configFiles { viper.SetConfigFile(file) - err := viper.MergeInConfig() - logging.WithFields("file", file).OnError(err).Warn("unable to read config file") + if err = viper.MergeInConfig(); err != nil { + return fmt.Errorf("unable to read config file: %w", err) + } } + return nil }, - Run: func(cmd *cobra.Command, args []string) { + RunE: func(cmd *cobra.Command, args []string) error { + defer closeSocket() config := mustNewMigrationConfig(viper.GetViper()) projectionConfig := mustNewProjectionsConfig(viper.GetViper()) masterKey, err := key.MasterKey(cmd) - logging.OnError(err).Fatal("unable to read master key") + if err != nil { + return fmt.Errorf("unable to read master key: %w", err) + } + if err = copySystem(cmd.Context(), config); err != nil { + return fmt.Errorf("unable to copy system tables: %w", err) + } + if err = copyAuth(cmd.Context(), config); err != nil { + return fmt.Errorf("unable to copy auth tables: %w", err) + } + if err = copyEventstore(cmd.Context(), config); err != nil { + return fmt.Errorf("unable to copy eventstore tables: %w", err) + } - copySystem(cmd.Context(), config) - copyAuth(cmd.Context(), config) - copyEventstore(cmd.Context(), config) - - projections(cmd.Context(), projectionConfig, masterKey) - verifyMigration(cmd.Context(), config) + if err = projections(cmd.Context(), projectionConfig, masterKey); err != nil { + return fmt.Errorf("unable to recompute projections: %w", err) + } + if err = verifyMigration(cmd.Context(), config); err != nil { + return fmt.Errorf("unable to verify migration: %w", err) + } + return nil }, } @@ -70,11 +92,11 @@ The flag should be provided if you want to execute the mirror command multiple t migrateProjectionsFlags(cmd) cmd.AddCommand( - eventstoreCmd(), - systemCmd(), - projectionsCmd(), - authCmd(), - verifyCmd(), + eventstoreCmd(closeSocket), + systemCmd(closeSocket), + projectionsCmd(closeSocket), + authCmd(closeSocket), + verifyCmd(closeSocket), ) return cmd diff --git a/cmd/mirror/projections.go b/cmd/mirror/projections.go index af7ba98c5c..b4316b9d76 100644 --- a/cmd/mirror/projections.go +++ b/cmd/mirror/projections.go @@ -3,6 +3,7 @@ package mirror import ( "context" "database/sql" + "fmt" "net/http" "sync" "time" @@ -50,13 +51,13 @@ func projectionsCmd() *cobra.Command { cmd := &cobra.Command{ Use: "projections", Short: "calls the projections synchronously", - Run: func(cmd *cobra.Command, args []string) { + RunE: func(cmd *cobra.Command, args []string) error { config := mustNewProjectionsConfig(viper.GetViper()) - masterKey, err := key.MasterKey(cmd) - logging.OnError(err).Fatal("unable to read master key") - - projections(cmd.Context(), config, masterKey) + if err != nil { + return fmt.Errorf("unable to read master key: %w", err) + } + return projections(cmd.Context(), config, masterKey) }, } @@ -100,24 +101,30 @@ func projections( ctx context.Context, config *ProjectionsConfig, masterKey string, -) { +) error { start := time.Now() - client, err := database.Connect(config.Destination, false, dialect.DBPurposeQuery) - logging.OnError(err).Fatal("unable to connect to database") - + if err != nil { + return fmt.Errorf("unable to connect to database: %w", err) + } keyStorage, err := crypto_db.NewKeyStorage(client, masterKey) - logging.OnError(err).Fatal("cannot start key storage") - + if err != nil { + return fmt.Errorf("unable to start key storage: %w", err) + } keys, err := encryption.EnsureEncryptionKeys(ctx, config.EncryptionKeys, keyStorage) - logging.OnError(err).Fatal("unable to read encryption keys") - + if err != nil { + return fmt.Errorf("unable to read encryption keys: %w", err) + } staticStorage, err := config.AssetStorage.NewStorage(client.DB) - logging.OnError(err).Fatal("unable create static storage") + if err != nil { + return fmt.Errorf("unable create static storage: %w", err) + } config.Eventstore.Querier = old_es.NewCRDB(client) esPusherDBClient, err := database.Connect(config.Destination, false, dialect.DBPurposeEventPusher) - logging.OnError(err).Fatal("unable to connect eventstore push client") + if err != nil { + return fmt.Errorf("unable to connect eventstore push client: %w", err) + } config.Eventstore.Pusher = new_es.NewEventstore(esPusherDBClient) es := eventstore.NewEventstore(config.Eventstore) esV4 := es_v4.NewEventstoreFromOne(es_v4_pg.New(client, &es_v4_pg.Config{ @@ -149,11 +156,14 @@ func projections( config.SystemAPIUsers, false, ) - logging.OnError(err).Fatal("unable to start queries") + if err != nil { + return fmt.Errorf("unable to start queries: %w", err) + } authZRepo, err := authz.Start(queries, es, client, keys.OIDC, config.ExternalSecure) - logging.OnError(err).Fatal("unable to start authz repo") - + if err != nil { + return fmt.Errorf("unable to start authz repo: %w", err) + } webAuthNConfig := &webauthn.Config{ DisplayName: config.WebAuthNName, ExternalSecure: config.ExternalSecure, @@ -185,13 +195,14 @@ func projections( config.OIDC.DefaultRefreshTokenIdleExpiration, config.DefaultInstance.SecretGenerators, ) - logging.OnError(err).Fatal("unable to start commands") - + if err != nil { + return fmt.Errorf("unable to start commands: %w", err) + } err = projection.Create(ctx, client, es, config.Projections, keys.OIDC, keys.SAML, config.SystemAPIUsers) - logging.OnError(err).Fatal("unable to start projections") - + if err != nil { + return fmt.Errorf("unable to create projections: %w", err) + } i18n.MustLoadSupportedLanguagesFromDir() - notification.Register( ctx, config.Projections.Customizations["notifications"], @@ -214,13 +225,17 @@ func projections( config.Auth.Spooler.Client = client config.Auth.Spooler.Eventstore = es authView, err := auth_view.StartView(config.Auth.Spooler.Client, keys.OIDC, queries, config.Auth.Spooler.Eventstore) - logging.OnError(err).Fatal("unable to start auth view") + if err != nil { + return fmt.Errorf("unable to start auth view: %w", err) + } auth_handler.Register(ctx, config.Auth.Spooler, authView, queries) config.Admin.Spooler.Client = client config.Admin.Spooler.Eventstore = es adminView, err := admin_view.StartView(config.Admin.Spooler.Client) - logging.OnError(err).Fatal("unable to start admin view") + if err != nil { + return fmt.Errorf("unable to start admin view: %w", err) + } admin_handler.Register(ctx, config.Admin.Spooler, adminView, staticStorage) @@ -248,6 +263,7 @@ func projections( close(failedInstances) logging.WithFields("took", time.Since(start)).Info("projections executed") + return nil } func execProjections(ctx context.Context, instances <-chan string, failedInstances chan<- string, wg *sync.WaitGroup) { diff --git a/cmd/mirror/system.go b/cmd/mirror/system.go index e16836aa8c..f487f16c4b 100644 --- a/cmd/mirror/system.go +++ b/cmd/mirror/system.go @@ -3,6 +3,7 @@ package mirror import ( "context" _ "embed" + "fmt" "io" "time" @@ -22,9 +23,9 @@ func systemCmd() *cobra.Command { Long: `mirrors the system tables of ZITADEL from one database to another ZITADEL needs to be initialized Only keys and assets are mirrored`, - Run: func(cmd *cobra.Command, args []string) { + RunE: func(cmd *cobra.Command, args []string) error { config := mustNewMigrationConfig(viper.GetViper()) - copySystem(cmd.Context(), config) + return copySystem(cmd.Context(), config) }, } @@ -33,24 +34,35 @@ Only keys and assets are mirrored`, return cmd } -func copySystem(ctx context.Context, config *Migration) { +func copySystem(ctx context.Context, config *Migration) error { sourceClient, err := database.Connect(config.Source, false, dialect.DBPurposeQuery) - logging.OnError(err).Fatal("unable to connect to source database") + if err != nil { + return fmt.Errorf("unable to connect to source database: %w", err) + } defer sourceClient.Close() destClient, err := database.Connect(config.Destination, false, dialect.DBPurposeEventPusher) - logging.OnError(err).Fatal("unable to connect to destination database") + if err != nil { + return fmt.Errorf("unable to connect to destination database: %w", err) + } defer destClient.Close() - copyAssets(ctx, sourceClient, destClient) - copyEncryptionKeys(ctx, sourceClient, destClient) + if err = copyAssets(ctx, sourceClient, destClient); err != nil { + return fmt.Errorf("unable to copy assets: %w", err) + } + if err = copyEncryptionKeys(ctx, sourceClient, destClient); err != nil { + return fmt.Errorf("unable to copy encryption keys: %w", err) + } + return nil } -func copyAssets(ctx context.Context, source, dest *database.DB) { +func copyAssets(ctx context.Context, source, dest *database.DB) error { start := time.Now() sourceConn, err := source.Conn(ctx) - logging.OnError(err).Fatal("unable to acquire source connection") + if err != nil { + return fmt.Errorf("unable to acquire source connection: %w", err) + } defer sourceConn.Close() r, w := io.Pipe() @@ -68,7 +80,9 @@ func copyAssets(ctx context.Context, source, dest *database.DB) { }() destConn, err := dest.Conn(ctx) - logging.OnError(err).Fatal("unable to acquire dest connection") + if err != nil { + return fmt.Errorf("unable to acquire dest connection: %w", err) + } defer destConn.Close() var eventCount int64 @@ -87,16 +101,23 @@ func copyAssets(ctx context.Context, source, dest *database.DB) { return err }) - logging.OnError(err).Fatal("unable to copy assets to destination") - logging.OnError(<-errs).Fatal("unable to copy assets from source") + if err != nil { + return fmt.Errorf("unable to copy assets to destination: %w", err) + } + if err = <-errs; err != nil { + return fmt.Errorf("unable to copy assets from source: %w", err) + } logging.WithFields("took", time.Since(start), "count", eventCount).Info("assets migrated") + return nil } -func copyEncryptionKeys(ctx context.Context, source, dest *database.DB) { +func copyEncryptionKeys(ctx context.Context, source, dest *database.DB) error { start := time.Now() sourceConn, err := source.Conn(ctx) - logging.OnError(err).Fatal("unable to acquire source connection") + if err != nil { + return fmt.Errorf("unable to acquire source connection: %w", err) + } defer sourceConn.Close() r, w := io.Pipe() @@ -114,7 +135,9 @@ func copyEncryptionKeys(ctx context.Context, source, dest *database.DB) { }() destConn, err := dest.Conn(ctx) - logging.OnError(err).Fatal("unable to acquire dest connection") + if err != nil { + return fmt.Errorf("unable to acquire dest connection: %w", err) + } defer destConn.Close() var eventCount int64 @@ -133,7 +156,12 @@ func copyEncryptionKeys(ctx context.Context, source, dest *database.DB) { return err }) - logging.OnError(err).Fatal("unable to copy encryption keys to destination") - logging.OnError(<-errs).Fatal("unable to copy encryption keys from source") + if err != nil { + return fmt.Errorf("unable to copy encryption keys to destination: %w", err) + } + if err = <-errs; err != nil { + return fmt.Errorf("unable to copy encryption keys from source: %w", err) + } logging.WithFields("took", time.Since(start), "count", eventCount).Info("encryption keys migrated") + return nil } diff --git a/cmd/mirror/verify.go b/cmd/mirror/verify.go index 7b90ad89aa..24bdcd76e8 100644 --- a/cmd/mirror/verify.go +++ b/cmd/mirror/verify.go @@ -33,17 +33,31 @@ var schemas = []string{ "system", } -func verifyMigration(ctx context.Context, config *Migration) { +func verifyMigration(ctx context.Context, config *Migration) error { sourceClient, err := database.Connect(config.Source, false, dialect.DBPurposeQuery) - logging.OnError(err).Fatal("unable to connect to source database") + if err != nil { + return fmt.Errorf("unable to connect to source database: %w", err) + } defer sourceClient.Close() destClient, err := database.Connect(config.Destination, false, dialect.DBPurposeEventPusher) - logging.OnError(err).Fatal("unable to connect to destination database") + if err != nil { + return fmt.Errorf("unable to connect to destination database: %w", err) + } defer destClient.Close() for _, schema := range schemas { - for _, table := range append(getTables(ctx, destClient, schema), getViews(ctx, destClient, schema)...) { + var tables []string + tables, err = getTables(ctx, destClient, schema) + if err != nil { + return fmt.Errorf("unable to get tables: %w", err) + } + var views []string + views, err = getViews(ctx, destClient, schema) + if err != nil { + return fmt.Errorf("unable to get views: %w", err) + } + for _, table := range append(tables, views...) { sourceCount := countEntries(ctx, sourceClient, table) destCount := countEntries(ctx, destClient, table) @@ -55,10 +69,11 @@ func verifyMigration(ctx context.Context, config *Migration) { entry.WithField("diff", destCount-sourceCount).Info("unequal count") } } + return nil } -func getTables(ctx context.Context, dest *database.DB, schema string) (tables []string) { - err := dest.QueryContext( +func getTables(ctx context.Context, dest *database.DB, schema string) (tables []string, err error) { + err = dest.QueryContext( ctx, func(r *sql.Rows) error { for r.Next() { @@ -73,12 +88,14 @@ func getTables(ctx context.Context, dest *database.DB, schema string) (tables [] "SELECT CONCAT(schemaname, '.', tablename) FROM pg_tables WHERE schemaname = $1", schema, ) - logging.WithFields("schema", schema).OnError(err).Fatal("unable to query tables") - return tables + if err != nil { + return nil, fmt.Errorf("unable to query tables: %w", err) + } + return tables, nil } -func getViews(ctx context.Context, dest *database.DB, schema string) (tables []string) { - err := dest.QueryContext( +func getViews(ctx context.Context, dest *database.DB, schema string) (tables []string, err error) { + err = dest.QueryContext( ctx, func(r *sql.Rows) error { for r.Next() { @@ -93,8 +110,10 @@ func getViews(ctx context.Context, dest *database.DB, schema string) (tables []s "SELECT CONCAT(schemaname, '.', viewname) FROM pg_views WHERE schemaname = $1", schema, ) - logging.WithFields("schema", schema).OnError(err).Fatal("unable to query views") - return tables + if err != nil { + return nil, fmt.Errorf("unable to query views in schema %s: %w", schema, err) + } + return tables, nil } func countEntries(ctx context.Context, client *database.DB, table string) (count int) { diff --git a/cmd/ready/ready.go b/cmd/ready/ready.go index 02fd5759a2..a46c65e2b3 100644 --- a/cmd/ready/ready.go +++ b/cmd/ready/ready.go @@ -6,7 +6,7 @@ import ( "github.com/spf13/cobra" "github.com/spf13/viper" "github.com/zitadel/logging" - "github.com/zitadel/zitadel/internal/socket" + "github.com/zitadel/zitadel/internal/unixsocket" "net" "net/http" "os" @@ -33,7 +33,7 @@ func ready(config *Config) bool { logging.Info("ready check passed") return true } - socketErr := expectTrueFromSocket(socket.ReadinessQuery) + socketErr := expectTrueFromSocket(unixsocket.ReadinessQuery) if socketErr == nil { logging.Info("ready check passed") return true @@ -43,12 +43,12 @@ func ready(config *Config) bool { return false } -func expectTrueFromSocket(query socket.SocketRequest) error { +func expectTrueFromSocket(query unixsocket.SocketRequest) error { resp, err := query.Request() if err != nil { return fmt.Errorf("socket request error: %w", err) } - if resp != socket.True { + if resp != unixsocket.True { return fmt.Errorf("zitadel process did not respond true to a readiness query") } return nil diff --git a/cmd/setup/setup.go b/cmd/setup/setup.go index d7455e8e22..49e482f7f8 100644 --- a/cmd/setup/setup.go +++ b/cmd/setup/setup.go @@ -4,7 +4,7 @@ import ( "context" "embed" _ "embed" - "github.com/zitadel/zitadel/internal/socket" + "github.com/zitadel/zitadel/internal/unixsocket" "net/http" "github.com/spf13/cobra" @@ -54,7 +54,7 @@ func New() *cobra.Command { Requirements: - cockroachdb`, Run: func(cmd *cobra.Command, args []string) { - closeSocket, err := socket.ListenAndIgnore() + closeSocket, err := unixsocket.ListenAndIgnore() logging.OnError(err).Fatal("unable to listen on socket") defer closeSocket() diff --git a/cmd/start/start.go b/cmd/start/start.go index f9dfc4e0a0..9a9da0952c 100644 --- a/cmd/start/start.go +++ b/cmd/start/start.go @@ -5,7 +5,7 @@ import ( "crypto/tls" _ "embed" "fmt" - "github.com/zitadel/zitadel/internal/socket" + "github.com/zitadel/zitadel/internal/unixsocket" "math" "net/http" "os" @@ -621,16 +621,16 @@ func checkExisting(values []string) func(string) bool { } func listenSocket(ctx context.Context) (chan<- *Server, func() error, error) { - return socket.Listen(func(server *Server, request socket.SocketRequest) (socket.SocketResponse, error) { + return unixsocket.Listen(func(server *Server, request unixsocket.SocketRequest) (unixsocket.SocketResponse, error) { switch request { - case socket.ReadinessQuery: + case unixsocket.ReadinessQuery: if readyErr := server.Queries.Health(ctx); readyErr != nil { logging.Warnf("readiness check failed: %v", readyErr) - return socket.False, nil + return unixsocket.False, nil } - return socket.True, nil + return unixsocket.True, nil default: - return socket.UnknownRequest, fmt.Errorf("unknown request: %d", request) + return unixsocket.UnknownRequest, fmt.Errorf("unknown request: %d", request) } }) } diff --git a/internal/socket/listen.go b/internal/unixsocket/listen.go similarity index 98% rename from internal/socket/listen.go rename to internal/unixsocket/listen.go index c1a6ce9692..b45ca084ab 100644 --- a/internal/socket/listen.go +++ b/internal/unixsocket/listen.go @@ -1,4 +1,4 @@ -package socket +package unixsocket import ( "fmt" diff --git a/internal/socket/socket.go b/internal/unixsocket/socket.go similarity index 98% rename from internal/socket/socket.go rename to internal/unixsocket/socket.go index 7918f3bc65..85037c7100 100644 --- a/internal/socket/socket.go +++ b/internal/unixsocket/socket.go @@ -1,4 +1,4 @@ -package socket +package unixsocket import ( "fmt"