mirror error handling

This commit is contained in:
Elio Bischof 2024-09-26 10:31:43 +02:00
parent c5ee9464f1
commit 42a5c6147e
No known key found for this signature in database
GPG Key ID: 7B383FDE4DDBF1BD
12 changed files with 272 additions and 148 deletions

View File

@ -6,10 +6,9 @@ import (
"github.com/spf13/cobra"
"github.com/spf13/viper"
"github.com/zitadel/logging"
"github.com/zitadel/zitadel/internal/socket"
"github.com/zitadel/zitadel/internal/database"
"github.com/zitadel/zitadel/internal/database/dialect"
"github.com/zitadel/zitadel/internal/unixsocket"
)
var (
@ -72,7 +71,7 @@ func InitAll(ctx context.Context, config *Config) {
}
func initialise(config database.Config, steps ...func(*database.DB) error) error {
closeSocket, err := socket.ListenAndIgnore()
closeSocket, err := unixsocket.ListenAndIgnore()
logging.OnError(err).Fatal("unable to listen on socket")
defer closeSocket()

View File

@ -3,6 +3,7 @@ package mirror
import (
"context"
_ "embed"
"fmt"
"io"
"time"
@ -33,23 +34,29 @@ Only auth requests are mirrored`,
return cmd
}
func copyAuth(ctx context.Context, config *Migration) {
func copyAuth(ctx context.Context, config *Migration) error {
sourceClient, err := database.Connect(config.Source, false, dialect.DBPurposeQuery)
logging.OnError(err).Fatal("unable to connect to source database")
if err != nil {
return fmt.Errorf("unable to connect to source database: %w", err)
}
defer sourceClient.Close()
destClient, err := database.Connect(config.Destination, false, dialect.DBPurposeEventPusher)
logging.OnError(err).Fatal("unable to connect to destination database")
if err != nil {
return fmt.Errorf("unable to connect to destination database: %w", err)
}
defer destClient.Close()
copyAuthRequests(ctx, sourceClient, destClient)
return copyAuthRequests(ctx, sourceClient, destClient)
}
func copyAuthRequests(ctx context.Context, source, dest *database.DB) {
func copyAuthRequests(ctx context.Context, source, dest *database.DB) error {
start := time.Now()
sourceConn, err := source.Conn(ctx)
logging.OnError(err).Fatal("unable to acquire connection")
if err != nil {
return fmt.Errorf("unable to acquire connection: %w", err)
}
defer sourceConn.Close()
r, w := io.Pipe()
@ -66,7 +73,9 @@ func copyAuthRequests(ctx context.Context, source, dest *database.DB) {
}()
destConn, err := dest.Conn(ctx)
logging.OnError(err).Fatal("unable to acquire connection")
if err != nil {
return fmt.Errorf("unable to acquire connection: %w", err)
}
defer destConn.Close()
var affected int64
@ -85,7 +94,12 @@ func copyAuthRequests(ctx context.Context, source, dest *database.DB) {
return err
})
logging.OnError(err).Fatal("unable to copy auth requests to destination")
logging.OnError(<-errs).Fatal("unable to copy auth requests from source")
if err != nil {
return fmt.Errorf("unable to copy auth requests to destination: %w", err)
}
if err = <-errs; err != nil {
return fmt.Errorf("unable to copy auth requests from source: %w", err)
}
logging.WithFields("took", time.Since(start), "count", affected).Info("auth requests migrated")
return nil
}

View File

@ -5,6 +5,7 @@ import (
"database/sql"
_ "embed"
"errors"
"fmt"
"io"
"time"
@ -32,9 +33,9 @@ func eventstoreCmd() *cobra.Command {
Long: `mirrors the eventstore of an instance from one database to another
ZITADEL needs to be initialized and set up with the --for-mirror flag
Migrate only copies events2 and unique constraints`,
Run: func(cmd *cobra.Command, args []string) {
RunE: func(cmd *cobra.Command, args []string) error {
config := mustNewMigrationConfig(viper.GetViper())
copyEventstore(cmd.Context(), config)
return copyEventstore(cmd.Context(), config)
},
}
@ -44,17 +45,26 @@ Migrate only copies events2 and unique constraints`,
return cmd
}
func copyEventstore(ctx context.Context, config *Migration) {
func copyEventstore(ctx context.Context, config *Migration) error {
sourceClient, err := db.Connect(config.Source, false, dialect.DBPurposeEventPusher)
logging.OnError(err).Fatal("unable to connect to source database")
if err != nil {
return fmt.Errorf("unable to connect to source database: %w", err)
}
defer sourceClient.Close()
destClient, err := db.Connect(config.Destination, false, dialect.DBPurposeEventPusher)
logging.OnError(err).Fatal("unable to connect to destination database")
if err != nil {
return fmt.Errorf("unable to connect to destination database: %w", err)
}
defer destClient.Close()
copyEvents(ctx, sourceClient, destClient, config.EventBulkSize)
copyUniqueConstraints(ctx, sourceClient, destClient)
if err = copyEvents(ctx, sourceClient, destClient, config.EventBulkSize); err != nil {
return fmt.Errorf("unable to copy events: %w", err)
}
if err = copyUniqueConstraints(ctx, sourceClient, destClient); err != nil {
return fmt.Errorf("unable to copy unique constraints: %w", err)
}
return nil
}
func positionQuery(db *db.DB) string {
@ -69,19 +79,21 @@ func positionQuery(db *db.DB) string {
}
}
func copyEvents(ctx context.Context, source, dest *db.DB, bulkSize uint32) {
func copyEvents(ctx context.Context, source, dest *db.DB, bulkSize uint32) error {
start := time.Now()
reader, writer := io.Pipe()
migrationID, err := id.SonyFlakeGenerator().Next()
logging.OnError(err).Fatal("unable to generate migration id")
if err != nil {
return fmt.Errorf("unable to generate migration id: %w", err)
}
sourceConn, err := source.Conn(ctx)
logging.OnError(err).Fatal("unable to acquire source connection")
if err != nil {
return fmt.Errorf("unable to acquire source connection: %w", err)
}
destConn, err := dest.Conn(ctx)
logging.OnError(err).Fatal("unable to acquire dest connection")
if err != nil {
return fmt.Errorf("unable to acquire dest connection: %w", err)
}
sourceES := eventstore.NewEventstoreFromOne(postgres.New(source, &postgres.Config{
MaxRetries: 3,
}))
@ -90,11 +102,14 @@ func copyEvents(ctx context.Context, source, dest *db.DB, bulkSize uint32) {
}))
previousMigration, err := queryLastSuccessfulMigration(ctx, destinationES, source.DatabaseName())
logging.OnError(err).Fatal("unable to query latest successful migration")
if err != nil {
return fmt.Errorf("unable to query latest successful migration: %w", err)
}
maxPosition, err := writeMigrationStart(ctx, sourceES, migrationID, dest.DatabaseName())
logging.OnError(err).Fatal("unable to write migration started event")
if err != nil {
return fmt.Errorf("unable to write migration start: %w", err)
}
logging.WithFields("from", previousMigration.Position, "to", maxPosition).Info("start event migration")
nextPos := make(chan bool, 1)
@ -102,7 +117,7 @@ func copyEvents(ctx context.Context, source, dest *db.DB, bulkSize uint32) {
errs := make(chan error, 3)
go func() {
err := sourceConn.Raw(func(driverConn interface{}) error {
goErr := sourceConn.Raw(func(driverConn interface{}) error {
conn := driverConn.(*stdlib.Conn).Conn()
nextPos <- true
var i uint32
@ -139,7 +154,7 @@ func copyEvents(ctx context.Context, source, dest *db.DB, bulkSize uint32) {
})
writer.Close()
close(nextPos)
errs <- err
errs <- goErr
}()
// generate next position for
@ -147,41 +162,41 @@ func copyEvents(ctx context.Context, source, dest *db.DB, bulkSize uint32) {
defer close(pos)
for range nextPos {
var position float64
err := dest.QueryRowContext(
goErr := dest.QueryRowContext(
ctx,
func(row *sql.Row) error {
return row.Scan(&position)
},
positionQuery(dest),
)
if err != nil {
if goErr != nil {
errs <- zerrors.ThrowUnknown(err, "MIGRA-kMyPH", "unable to query next position")
return
}
pos <- position
}
}()
var eventCount int64
errs <- destConn.Raw(func(driverConn interface{}) error {
conn := driverConn.(*stdlib.Conn).Conn()
tag, err := conn.PgConn().CopyFrom(ctx, reader, "COPY eventstore.events2 FROM STDIN")
tag, cbErr := conn.PgConn().CopyFrom(ctx, reader, "COPY eventstore.events2 FROM STDIN")
eventCount = tag.RowsAffected()
if err != nil {
return zerrors.ThrowUnknown(err, "MIGRA-DTHi7", "unable to copy events into destination")
if cbErr != nil {
return zerrors.ThrowUnknown(cbErr, "MIGRA-DTHi7", "unable to copy events into destination")
}
return nil
})
close(errs)
writeCopyEventsDone(ctx, destinationES, migrationID, source.DatabaseName(), maxPosition, errs)
if err = writeCopyEventsDone(ctx, destinationES, migrationID, source.DatabaseName(), maxPosition, errs); err != nil {
return fmt.Errorf("unable to write migration done: %w", err)
}
logging.WithFields("took", time.Since(start), "count", eventCount).Info("events migrated")
return nil
}
func writeCopyEventsDone(ctx context.Context, es *eventstore.EventStore, id, source string, position decimal.Decimal, errs <-chan error) {
func writeCopyEventsDone(ctx context.Context, es *eventstore.EventStore, id, source string, position decimal.Decimal, errs <-chan error) error {
joinedErrs := make([]error, 0, len(errs))
for err := range errs {
joinedErrs = append(joinedErrs, err)
@ -190,25 +205,31 @@ func writeCopyEventsDone(ctx context.Context, es *eventstore.EventStore, id, sou
if err != nil {
logging.WithError(err).Error("unable to mirror events")
err := writeMigrationFailed(ctx, es, id, source, err)
logging.OnError(err).Fatal("unable to write failed event")
return
err = writeMigrationFailed(ctx, es, id, source, err)
if err != nil {
return fmt.Errorf("unable to write failed event: %w", err)
}
return nil
}
err = writeMigrationSucceeded(ctx, es, id, source, position)
logging.OnError(err).Fatal("unable to write failed event")
if err = writeMigrationSucceeded(ctx, es, id, source, position); err != nil {
return fmt.Errorf("unable to write succeeded event: %w", err)
}
return nil
}
func copyUniqueConstraints(ctx context.Context, source, dest *db.DB) {
func copyUniqueConstraints(ctx context.Context, source, dest *db.DB) error {
start := time.Now()
reader, writer := io.Pipe()
errs := make(chan error, 1)
sourceConn, err := source.Conn(ctx)
logging.OnError(err).Fatal("unable to acquire source connection")
if err != nil {
return fmt.Errorf("unable to acquire source connection: %w", err)
}
go func() {
err := sourceConn.Raw(func(driverConn interface{}) error {
errs <- sourceConn.Raw(func(driverConn interface{}) error {
conn := driverConn.(*stdlib.Conn).Conn()
var stmt database.Statement
stmt.WriteString("COPY (SELECT instance_id, unique_type, unique_field FROM eventstore.unique_constraints ")
@ -219,11 +240,12 @@ func copyUniqueConstraints(ctx context.Context, source, dest *db.DB) {
writer.Close()
return err
})
errs <- err
}()
destConn, err := dest.Conn(ctx)
logging.OnError(err).Fatal("unable to acquire dest connection")
if err != nil {
return fmt.Errorf("unable to acquire dest connection: %w", err)
}
var eventCount int64
err = destConn.Raw(func(driverConn interface{}) error {
@ -234,18 +256,22 @@ func copyUniqueConstraints(ctx context.Context, source, dest *db.DB) {
stmt.WriteString("DELETE FROM eventstore.unique_constraints ")
stmt.WriteString(instanceClause())
_, err := conn.Exec(ctx, stmt.String())
if err != nil {
return err
_, cbErr := conn.Exec(ctx, stmt.String())
if cbErr != nil {
return cbErr
}
}
tag, err := conn.PgConn().CopyFrom(ctx, reader, "COPY eventstore.unique_constraints FROM stdin")
tag, cbErr := conn.PgConn().CopyFrom(ctx, reader, "COPY eventstore.unique_constraints FROM stdin")
eventCount = tag.RowsAffected()
return err
return cbErr
})
logging.OnError(err).Fatal("unable to copy unique constraints to destination")
logging.OnError(<-errs).Fatal("unable to copy unique constraints from source")
if err != nil {
return fmt.Errorf("unable to copy unique constraints to destination: %w", err)
}
if err = <-errs; err != nil {
return fmt.Errorf("unable to copy unique constraints from source: %w", err)
}
logging.WithFields("took", time.Since(start), "count", eventCount).Info("unique constraints migrated")
return nil
}

View File

@ -3,12 +3,13 @@ package mirror
import (
"bytes"
_ "embed"
"errors"
"fmt"
"github.com/zitadel/zitadel/internal/unixsocket"
"strings"
"github.com/spf13/cobra"
"github.com/spf13/viper"
"github.com/zitadel/logging"
"github.com/zitadel/zitadel/cmd/key"
)
@ -19,6 +20,7 @@ var (
)
func New(configFiles *[]string) *cobra.Command {
closeSocket := func() error { return nil }
cmd := &cobra.Command{
Use: "mirror",
Short: "mirrors all data of ZITADEL from one database to another",
@ -34,29 +36,49 @@ Order of execution:
3. mirror event store tables
4. recompute projections
5. verify`,
PersistentPreRun: func(cmd *cobra.Command, args []string) {
err := viper.MergeConfig(bytes.NewBuffer(defaultConfig))
logging.OnError(err).Fatal("unable to read default config")
PersistentPreRunE: func(cmd *cobra.Command, args []string) (err error) {
closeSocket, err = unixsocket.ListenAndIgnore()
if err != nil {
return fmt.Errorf("unable to listen on socket: %w", err)
}
if err = viper.MergeConfig(bytes.NewBuffer(defaultConfig)); err != nil {
return errors.New("unable to read default config")
}
for _, file := range *configFiles {
viper.SetConfigFile(file)
err := viper.MergeInConfig()
logging.WithFields("file", file).OnError(err).Warn("unable to read config file")
if err = viper.MergeInConfig(); err != nil {
return fmt.Errorf("unable to read config file: %w", err)
}
}
return nil
},
Run: func(cmd *cobra.Command, args []string) {
RunE: func(cmd *cobra.Command, args []string) error {
defer closeSocket()
config := mustNewMigrationConfig(viper.GetViper())
projectionConfig := mustNewProjectionsConfig(viper.GetViper())
masterKey, err := key.MasterKey(cmd)
logging.OnError(err).Fatal("unable to read master key")
if err != nil {
return fmt.Errorf("unable to read master key: %w", err)
}
if err = copySystem(cmd.Context(), config); err != nil {
return fmt.Errorf("unable to copy system tables: %w", err)
}
if err = copyAuth(cmd.Context(), config); err != nil {
return fmt.Errorf("unable to copy auth tables: %w", err)
}
if err = copyEventstore(cmd.Context(), config); err != nil {
return fmt.Errorf("unable to copy eventstore tables: %w", err)
}
copySystem(cmd.Context(), config)
copyAuth(cmd.Context(), config)
copyEventstore(cmd.Context(), config)
projections(cmd.Context(), projectionConfig, masterKey)
verifyMigration(cmd.Context(), config)
if err = projections(cmd.Context(), projectionConfig, masterKey); err != nil {
return fmt.Errorf("unable to recompute projections: %w", err)
}
if err = verifyMigration(cmd.Context(), config); err != nil {
return fmt.Errorf("unable to verify migration: %w", err)
}
return nil
},
}
@ -70,11 +92,11 @@ The flag should be provided if you want to execute the mirror command multiple t
migrateProjectionsFlags(cmd)
cmd.AddCommand(
eventstoreCmd(),
systemCmd(),
projectionsCmd(),
authCmd(),
verifyCmd(),
eventstoreCmd(closeSocket),
systemCmd(closeSocket),
projectionsCmd(closeSocket),
authCmd(closeSocket),
verifyCmd(closeSocket),
)
return cmd

View File

@ -3,6 +3,7 @@ package mirror
import (
"context"
"database/sql"
"fmt"
"net/http"
"sync"
"time"
@ -50,13 +51,13 @@ func projectionsCmd() *cobra.Command {
cmd := &cobra.Command{
Use: "projections",
Short: "calls the projections synchronously",
Run: func(cmd *cobra.Command, args []string) {
RunE: func(cmd *cobra.Command, args []string) error {
config := mustNewProjectionsConfig(viper.GetViper())
masterKey, err := key.MasterKey(cmd)
logging.OnError(err).Fatal("unable to read master key")
projections(cmd.Context(), config, masterKey)
if err != nil {
return fmt.Errorf("unable to read master key: %w", err)
}
return projections(cmd.Context(), config, masterKey)
},
}
@ -100,24 +101,30 @@ func projections(
ctx context.Context,
config *ProjectionsConfig,
masterKey string,
) {
) error {
start := time.Now()
client, err := database.Connect(config.Destination, false, dialect.DBPurposeQuery)
logging.OnError(err).Fatal("unable to connect to database")
if err != nil {
return fmt.Errorf("unable to connect to database: %w", err)
}
keyStorage, err := crypto_db.NewKeyStorage(client, masterKey)
logging.OnError(err).Fatal("cannot start key storage")
if err != nil {
return fmt.Errorf("unable to start key storage: %w", err)
}
keys, err := encryption.EnsureEncryptionKeys(ctx, config.EncryptionKeys, keyStorage)
logging.OnError(err).Fatal("unable to read encryption keys")
if err != nil {
return fmt.Errorf("unable to read encryption keys: %w", err)
}
staticStorage, err := config.AssetStorage.NewStorage(client.DB)
logging.OnError(err).Fatal("unable create static storage")
if err != nil {
return fmt.Errorf("unable create static storage: %w", err)
}
config.Eventstore.Querier = old_es.NewCRDB(client)
esPusherDBClient, err := database.Connect(config.Destination, false, dialect.DBPurposeEventPusher)
logging.OnError(err).Fatal("unable to connect eventstore push client")
if err != nil {
return fmt.Errorf("unable to connect eventstore push client: %w", err)
}
config.Eventstore.Pusher = new_es.NewEventstore(esPusherDBClient)
es := eventstore.NewEventstore(config.Eventstore)
esV4 := es_v4.NewEventstoreFromOne(es_v4_pg.New(client, &es_v4_pg.Config{
@ -149,11 +156,14 @@ func projections(
config.SystemAPIUsers,
false,
)
logging.OnError(err).Fatal("unable to start queries")
if err != nil {
return fmt.Errorf("unable to start queries: %w", err)
}
authZRepo, err := authz.Start(queries, es, client, keys.OIDC, config.ExternalSecure)
logging.OnError(err).Fatal("unable to start authz repo")
if err != nil {
return fmt.Errorf("unable to start authz repo: %w", err)
}
webAuthNConfig := &webauthn.Config{
DisplayName: config.WebAuthNName,
ExternalSecure: config.ExternalSecure,
@ -185,13 +195,14 @@ func projections(
config.OIDC.DefaultRefreshTokenIdleExpiration,
config.DefaultInstance.SecretGenerators,
)
logging.OnError(err).Fatal("unable to start commands")
if err != nil {
return fmt.Errorf("unable to start commands: %w", err)
}
err = projection.Create(ctx, client, es, config.Projections, keys.OIDC, keys.SAML, config.SystemAPIUsers)
logging.OnError(err).Fatal("unable to start projections")
if err != nil {
return fmt.Errorf("unable to create projections: %w", err)
}
i18n.MustLoadSupportedLanguagesFromDir()
notification.Register(
ctx,
config.Projections.Customizations["notifications"],
@ -214,13 +225,17 @@ func projections(
config.Auth.Spooler.Client = client
config.Auth.Spooler.Eventstore = es
authView, err := auth_view.StartView(config.Auth.Spooler.Client, keys.OIDC, queries, config.Auth.Spooler.Eventstore)
logging.OnError(err).Fatal("unable to start auth view")
if err != nil {
return fmt.Errorf("unable to start auth view: %w", err)
}
auth_handler.Register(ctx, config.Auth.Spooler, authView, queries)
config.Admin.Spooler.Client = client
config.Admin.Spooler.Eventstore = es
adminView, err := admin_view.StartView(config.Admin.Spooler.Client)
logging.OnError(err).Fatal("unable to start admin view")
if err != nil {
return fmt.Errorf("unable to start admin view: %w", err)
}
admin_handler.Register(ctx, config.Admin.Spooler, adminView, staticStorage)
@ -248,6 +263,7 @@ func projections(
close(failedInstances)
logging.WithFields("took", time.Since(start)).Info("projections executed")
return nil
}
func execProjections(ctx context.Context, instances <-chan string, failedInstances chan<- string, wg *sync.WaitGroup) {

View File

@ -3,6 +3,7 @@ package mirror
import (
"context"
_ "embed"
"fmt"
"io"
"time"
@ -22,9 +23,9 @@ func systemCmd() *cobra.Command {
Long: `mirrors the system tables of ZITADEL from one database to another
ZITADEL needs to be initialized
Only keys and assets are mirrored`,
Run: func(cmd *cobra.Command, args []string) {
RunE: func(cmd *cobra.Command, args []string) error {
config := mustNewMigrationConfig(viper.GetViper())
copySystem(cmd.Context(), config)
return copySystem(cmd.Context(), config)
},
}
@ -33,24 +34,35 @@ Only keys and assets are mirrored`,
return cmd
}
func copySystem(ctx context.Context, config *Migration) {
func copySystem(ctx context.Context, config *Migration) error {
sourceClient, err := database.Connect(config.Source, false, dialect.DBPurposeQuery)
logging.OnError(err).Fatal("unable to connect to source database")
if err != nil {
return fmt.Errorf("unable to connect to source database: %w", err)
}
defer sourceClient.Close()
destClient, err := database.Connect(config.Destination, false, dialect.DBPurposeEventPusher)
logging.OnError(err).Fatal("unable to connect to destination database")
if err != nil {
return fmt.Errorf("unable to connect to destination database: %w", err)
}
defer destClient.Close()
copyAssets(ctx, sourceClient, destClient)
copyEncryptionKeys(ctx, sourceClient, destClient)
if err = copyAssets(ctx, sourceClient, destClient); err != nil {
return fmt.Errorf("unable to copy assets: %w", err)
}
if err = copyEncryptionKeys(ctx, sourceClient, destClient); err != nil {
return fmt.Errorf("unable to copy encryption keys: %w", err)
}
return nil
}
func copyAssets(ctx context.Context, source, dest *database.DB) {
func copyAssets(ctx context.Context, source, dest *database.DB) error {
start := time.Now()
sourceConn, err := source.Conn(ctx)
logging.OnError(err).Fatal("unable to acquire source connection")
if err != nil {
return fmt.Errorf("unable to acquire source connection: %w", err)
}
defer sourceConn.Close()
r, w := io.Pipe()
@ -68,7 +80,9 @@ func copyAssets(ctx context.Context, source, dest *database.DB) {
}()
destConn, err := dest.Conn(ctx)
logging.OnError(err).Fatal("unable to acquire dest connection")
if err != nil {
return fmt.Errorf("unable to acquire dest connection: %w", err)
}
defer destConn.Close()
var eventCount int64
@ -87,16 +101,23 @@ func copyAssets(ctx context.Context, source, dest *database.DB) {
return err
})
logging.OnError(err).Fatal("unable to copy assets to destination")
logging.OnError(<-errs).Fatal("unable to copy assets from source")
if err != nil {
return fmt.Errorf("unable to copy assets to destination: %w", err)
}
if err = <-errs; err != nil {
return fmt.Errorf("unable to copy assets from source: %w", err)
}
logging.WithFields("took", time.Since(start), "count", eventCount).Info("assets migrated")
return nil
}
func copyEncryptionKeys(ctx context.Context, source, dest *database.DB) {
func copyEncryptionKeys(ctx context.Context, source, dest *database.DB) error {
start := time.Now()
sourceConn, err := source.Conn(ctx)
logging.OnError(err).Fatal("unable to acquire source connection")
if err != nil {
return fmt.Errorf("unable to acquire source connection: %w", err)
}
defer sourceConn.Close()
r, w := io.Pipe()
@ -114,7 +135,9 @@ func copyEncryptionKeys(ctx context.Context, source, dest *database.DB) {
}()
destConn, err := dest.Conn(ctx)
logging.OnError(err).Fatal("unable to acquire dest connection")
if err != nil {
return fmt.Errorf("unable to acquire dest connection: %w", err)
}
defer destConn.Close()
var eventCount int64
@ -133,7 +156,12 @@ func copyEncryptionKeys(ctx context.Context, source, dest *database.DB) {
return err
})
logging.OnError(err).Fatal("unable to copy encryption keys to destination")
logging.OnError(<-errs).Fatal("unable to copy encryption keys from source")
if err != nil {
return fmt.Errorf("unable to copy encryption keys to destination: %w", err)
}
if err = <-errs; err != nil {
return fmt.Errorf("unable to copy encryption keys from source: %w", err)
}
logging.WithFields("took", time.Since(start), "count", eventCount).Info("encryption keys migrated")
return nil
}

View File

@ -33,17 +33,31 @@ var schemas = []string{
"system",
}
func verifyMigration(ctx context.Context, config *Migration) {
func verifyMigration(ctx context.Context, config *Migration) error {
sourceClient, err := database.Connect(config.Source, false, dialect.DBPurposeQuery)
logging.OnError(err).Fatal("unable to connect to source database")
if err != nil {
return fmt.Errorf("unable to connect to source database: %w", err)
}
defer sourceClient.Close()
destClient, err := database.Connect(config.Destination, false, dialect.DBPurposeEventPusher)
logging.OnError(err).Fatal("unable to connect to destination database")
if err != nil {
return fmt.Errorf("unable to connect to destination database: %w", err)
}
defer destClient.Close()
for _, schema := range schemas {
for _, table := range append(getTables(ctx, destClient, schema), getViews(ctx, destClient, schema)...) {
var tables []string
tables, err = getTables(ctx, destClient, schema)
if err != nil {
return fmt.Errorf("unable to get tables: %w", err)
}
var views []string
views, err = getViews(ctx, destClient, schema)
if err != nil {
return fmt.Errorf("unable to get views: %w", err)
}
for _, table := range append(tables, views...) {
sourceCount := countEntries(ctx, sourceClient, table)
destCount := countEntries(ctx, destClient, table)
@ -55,10 +69,11 @@ func verifyMigration(ctx context.Context, config *Migration) {
entry.WithField("diff", destCount-sourceCount).Info("unequal count")
}
}
return nil
}
func getTables(ctx context.Context, dest *database.DB, schema string) (tables []string) {
err := dest.QueryContext(
func getTables(ctx context.Context, dest *database.DB, schema string) (tables []string, err error) {
err = dest.QueryContext(
ctx,
func(r *sql.Rows) error {
for r.Next() {
@ -73,12 +88,14 @@ func getTables(ctx context.Context, dest *database.DB, schema string) (tables []
"SELECT CONCAT(schemaname, '.', tablename) FROM pg_tables WHERE schemaname = $1",
schema,
)
logging.WithFields("schema", schema).OnError(err).Fatal("unable to query tables")
return tables
if err != nil {
return nil, fmt.Errorf("unable to query tables: %w", err)
}
return tables, nil
}
func getViews(ctx context.Context, dest *database.DB, schema string) (tables []string) {
err := dest.QueryContext(
func getViews(ctx context.Context, dest *database.DB, schema string) (tables []string, err error) {
err = dest.QueryContext(
ctx,
func(r *sql.Rows) error {
for r.Next() {
@ -93,8 +110,10 @@ func getViews(ctx context.Context, dest *database.DB, schema string) (tables []s
"SELECT CONCAT(schemaname, '.', viewname) FROM pg_views WHERE schemaname = $1",
schema,
)
logging.WithFields("schema", schema).OnError(err).Fatal("unable to query views")
return tables
if err != nil {
return nil, fmt.Errorf("unable to query views in schema %s: %w", schema, err)
}
return tables, nil
}
func countEntries(ctx context.Context, client *database.DB, table string) (count int) {

View File

@ -6,7 +6,7 @@ import (
"github.com/spf13/cobra"
"github.com/spf13/viper"
"github.com/zitadel/logging"
"github.com/zitadel/zitadel/internal/socket"
"github.com/zitadel/zitadel/internal/unixsocket"
"net"
"net/http"
"os"
@ -33,7 +33,7 @@ func ready(config *Config) bool {
logging.Info("ready check passed")
return true
}
socketErr := expectTrueFromSocket(socket.ReadinessQuery)
socketErr := expectTrueFromSocket(unixsocket.ReadinessQuery)
if socketErr == nil {
logging.Info("ready check passed")
return true
@ -43,12 +43,12 @@ func ready(config *Config) bool {
return false
}
func expectTrueFromSocket(query socket.SocketRequest) error {
func expectTrueFromSocket(query unixsocket.SocketRequest) error {
resp, err := query.Request()
if err != nil {
return fmt.Errorf("socket request error: %w", err)
}
if resp != socket.True {
if resp != unixsocket.True {
return fmt.Errorf("zitadel process did not respond true to a readiness query")
}
return nil

View File

@ -4,7 +4,7 @@ import (
"context"
"embed"
_ "embed"
"github.com/zitadel/zitadel/internal/socket"
"github.com/zitadel/zitadel/internal/unixsocket"
"net/http"
"github.com/spf13/cobra"
@ -54,7 +54,7 @@ func New() *cobra.Command {
Requirements:
- cockroachdb`,
Run: func(cmd *cobra.Command, args []string) {
closeSocket, err := socket.ListenAndIgnore()
closeSocket, err := unixsocket.ListenAndIgnore()
logging.OnError(err).Fatal("unable to listen on socket")
defer closeSocket()

View File

@ -5,7 +5,7 @@ import (
"crypto/tls"
_ "embed"
"fmt"
"github.com/zitadel/zitadel/internal/socket"
"github.com/zitadel/zitadel/internal/unixsocket"
"math"
"net/http"
"os"
@ -621,16 +621,16 @@ func checkExisting(values []string) func(string) bool {
}
func listenSocket(ctx context.Context) (chan<- *Server, func() error, error) {
return socket.Listen(func(server *Server, request socket.SocketRequest) (socket.SocketResponse, error) {
return unixsocket.Listen(func(server *Server, request unixsocket.SocketRequest) (unixsocket.SocketResponse, error) {
switch request {
case socket.ReadinessQuery:
case unixsocket.ReadinessQuery:
if readyErr := server.Queries.Health(ctx); readyErr != nil {
logging.Warnf("readiness check failed: %v", readyErr)
return socket.False, nil
return unixsocket.False, nil
}
return socket.True, nil
return unixsocket.True, nil
default:
return socket.UnknownRequest, fmt.Errorf("unknown request: %d", request)
return unixsocket.UnknownRequest, fmt.Errorf("unknown request: %d", request)
}
})
}

View File

@ -1,4 +1,4 @@
package socket
package unixsocket
import (
"fmt"

View File

@ -1,4 +1,4 @@
package socket
package unixsocket
import (
"fmt"