From 34361c6f827679284d306aafcd4795d17dc08799 Mon Sep 17 00:00:00 2001 From: Mike Poindexter Date: Thu, 29 Aug 2024 23:08:54 -0700 Subject: [PATCH] Fix FKs on sqlite migrations (#2083) --- hscontrol/db/db.go | 79 ++++++++++++++++++++++++++++++++++++++++------ 1 file changed, 69 insertions(+), 10 deletions(-) diff --git a/hscontrol/db/db.go b/hscontrol/db/db.go index 99c3aa68..accf439e 100644 --- a/hscontrol/db/db.go +++ b/hscontrol/db/db.go @@ -123,21 +123,16 @@ func NewHeadscaleDatabase( } } - // Only run automigrate Route table if it does not exist. It has only been - // changed once, when machines where renamed to nodes, which is covered - // further up. This whole initial integration is a mess and if AutoMigrate - // is ran on a 0.22 to 0.23 update, it will wipe all the routes. + // Remove any invalid routes associated with a node that does not exist. if tx.Migrator().HasTable(&types.Route{}) && tx.Migrator().HasTable(&types.Node{}) { err := tx.Exec("delete from routes where node_id not in (select id from nodes)").Error if err != nil { return err } } - if !tx.Migrator().HasTable(&types.Route{}) { - err = tx.AutoMigrate(&types.Route{}) - if err != nil { - return err - } + err = tx.AutoMigrate(&types.Route{}) + if err != nil { + return err } err = tx.AutoMigrate(&types.Node{}) @@ -421,7 +416,7 @@ func NewHeadscaleDatabase( }, ) - if err = migrations.Migrate(); err != nil { + if err := runMigrations(cfg, dbConn, migrations); err != nil { log.Fatal().Err(err).Msgf("Migration failed: %v", err) } @@ -545,6 +540,70 @@ func openDB(cfg types.DatabaseConfig) (*gorm.DB, error) { ) } +func runMigrations(cfg types.DatabaseConfig, dbConn *gorm.DB, migrations *gormigrate.Gormigrate) error { + // Turn off foreign keys for the duration of the migration if using sqllite to + // prevent data loss due to the way the GORM migrator handles certain schema + // changes. + if cfg.Type == types.DatabaseSqlite { + var fkEnabled int + if err := dbConn.Raw("PRAGMA foreign_keys").Scan(&fkEnabled).Error; err != nil { + return fmt.Errorf("checking foreign key status: %w", err) + } + if fkEnabled == 1 { + if err := dbConn.Exec("PRAGMA foreign_keys = OFF").Error; err != nil { + return fmt.Errorf("disabling foreign keys: %w", err) + } + defer dbConn.Exec("PRAGMA foreign_keys = ON") + } + } + + if err := migrations.Migrate(); err != nil { + return err + } + + // Since we disabled foreign keys for the migration, we need to check for + // constraint violations manually at the end of the migration. + if cfg.Type == types.DatabaseSqlite { + type constraintViolation struct { + Table string + RowID int + Parent string + ConstraintIndex int + } + + var violatedConstraints []constraintViolation + + rows, err := dbConn.Raw("PRAGMA foreign_key_check").Rows() + if err != nil { + return err + } + + for rows.Next() { + var violation constraintViolation + if err := rows.Scan(&violation.Table, &violation.RowID, &violation.Parent, &violation.ConstraintIndex); err != nil { + return err + } + + violatedConstraints = append(violatedConstraints, violation) + } + _ = rows.Close() + + if len(violatedConstraints) > 0 { + for _, violation := range violatedConstraints { + log.Error(). + Str("table", violation.Table). + Int("row_id", violation.RowID). + Str("parent", violation.Parent). + Msg("Foreign key constraint violated") + } + + return fmt.Errorf("foreign key constraints violated") + } + } + + return nil +} + func (hsdb *HSDatabase) PingDB(ctx context.Context) error { ctx, cancel := context.WithTimeout(ctx, time.Second) defer cancel()