mirror of
https://github.com/juanfont/headscale.git
synced 2025-08-12 07:47:36 +00:00
Replace database locks with transactions (#1701)
This commits removes the locks used to guard data integrity for the database and replaces them with Transactions, turns out that SQL had a way to deal with this all along. This reduces the complexity we had with multiple locks that might stack or recurse (database, nofitifer, mapper). All notifications and state updates are now triggered _after_ a database change. Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>
This commit is contained in:
@@ -15,22 +15,25 @@ var (
|
||||
ErrUserStillHasNodes = errors.New("user not empty: node(s) found")
|
||||
)
|
||||
|
||||
func (hsdb *HSDatabase) CreateUser(name string) (*types.User, error) {
|
||||
return Write(hsdb.DB, func(tx *gorm.DB) (*types.User, error) {
|
||||
return CreateUser(tx, name)
|
||||
})
|
||||
}
|
||||
|
||||
// CreateUser creates a new User. Returns error if could not be created
|
||||
// or another user already exists.
|
||||
func (hsdb *HSDatabase) CreateUser(name string) (*types.User, error) {
|
||||
hsdb.mu.Lock()
|
||||
defer hsdb.mu.Unlock()
|
||||
|
||||
func CreateUser(tx *gorm.DB, name string) (*types.User, error) {
|
||||
err := util.CheckForFQDNRules(name)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
user := types.User{}
|
||||
if err := hsdb.db.Where("name = ?", name).First(&user).Error; err == nil {
|
||||
if err := tx.Where("name = ?", name).First(&user).Error; err == nil {
|
||||
return nil, ErrUserExists
|
||||
}
|
||||
user.Name = name
|
||||
if err := hsdb.db.Create(&user).Error; err != nil {
|
||||
if err := tx.Create(&user).Error; err != nil {
|
||||
log.Error().
|
||||
Str("func", "CreateUser").
|
||||
Err(err).
|
||||
@@ -42,18 +45,21 @@ func (hsdb *HSDatabase) CreateUser(name string) (*types.User, error) {
|
||||
return &user, nil
|
||||
}
|
||||
|
||||
func (hsdb *HSDatabase) DestroyUser(name string) error {
|
||||
return hsdb.Write(func(tx *gorm.DB) error {
|
||||
return DestroyUser(tx, name)
|
||||
})
|
||||
}
|
||||
|
||||
// DestroyUser destroys a User. Returns error if the User does
|
||||
// not exist or if there are nodes associated with it.
|
||||
func (hsdb *HSDatabase) DestroyUser(name string) error {
|
||||
hsdb.mu.Lock()
|
||||
defer hsdb.mu.Unlock()
|
||||
|
||||
user, err := hsdb.getUser(name)
|
||||
func DestroyUser(tx *gorm.DB, name string) error {
|
||||
user, err := GetUser(tx, name)
|
||||
if err != nil {
|
||||
return ErrUserNotFound
|
||||
}
|
||||
|
||||
nodes, err := hsdb.listNodesByUser(name)
|
||||
nodes, err := ListNodesByUser(tx, name)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -61,32 +67,35 @@ func (hsdb *HSDatabase) DestroyUser(name string) error {
|
||||
return ErrUserStillHasNodes
|
||||
}
|
||||
|
||||
keys, err := hsdb.listPreAuthKeys(name)
|
||||
keys, err := ListPreAuthKeys(tx, name)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
for _, key := range keys {
|
||||
err = hsdb.destroyPreAuthKey(key)
|
||||
err = DestroyPreAuthKey(tx, key)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
if result := hsdb.db.Unscoped().Delete(&user); result.Error != nil {
|
||||
if result := tx.Unscoped().Delete(&user); result.Error != nil {
|
||||
return result.Error
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (hsdb *HSDatabase) RenameUser(oldName, newName string) error {
|
||||
return hsdb.Write(func(tx *gorm.DB) error {
|
||||
return RenameUser(tx, oldName, newName)
|
||||
})
|
||||
}
|
||||
|
||||
// RenameUser renames a User. Returns error if the User does
|
||||
// not exist or if another User exists with the new name.
|
||||
func (hsdb *HSDatabase) RenameUser(oldName, newName string) error {
|
||||
hsdb.mu.Lock()
|
||||
defer hsdb.mu.Unlock()
|
||||
|
||||
func RenameUser(tx *gorm.DB, oldName, newName string) error {
|
||||
var err error
|
||||
oldUser, err := hsdb.getUser(oldName)
|
||||
oldUser, err := GetUser(tx, oldName)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -94,7 +103,7 @@ func (hsdb *HSDatabase) RenameUser(oldName, newName string) error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = hsdb.getUser(newName)
|
||||
_, err = GetUser(tx, newName)
|
||||
if err == nil {
|
||||
return ErrUserExists
|
||||
}
|
||||
@@ -104,24 +113,22 @@ func (hsdb *HSDatabase) RenameUser(oldName, newName string) error {
|
||||
|
||||
oldUser.Name = newName
|
||||
|
||||
if result := hsdb.db.Save(&oldUser); result.Error != nil {
|
||||
if result := tx.Save(&oldUser); result.Error != nil {
|
||||
return result.Error
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetUser fetches a user by name.
|
||||
func (hsdb *HSDatabase) GetUser(name string) (*types.User, error) {
|
||||
hsdb.mu.RLock()
|
||||
defer hsdb.mu.RUnlock()
|
||||
|
||||
return hsdb.getUser(name)
|
||||
return Read(hsdb.DB, func(rx *gorm.DB) (*types.User, error) {
|
||||
return GetUser(rx, name)
|
||||
})
|
||||
}
|
||||
|
||||
func (hsdb *HSDatabase) getUser(name string) (*types.User, error) {
|
||||
func GetUser(tx *gorm.DB, name string) (*types.User, error) {
|
||||
user := types.User{}
|
||||
if result := hsdb.db.First(&user, "name = ?", name); errors.Is(
|
||||
if result := tx.First(&user, "name = ?", name); errors.Is(
|
||||
result.Error,
|
||||
gorm.ErrRecordNotFound,
|
||||
) {
|
||||
@@ -131,17 +138,16 @@ func (hsdb *HSDatabase) getUser(name string) (*types.User, error) {
|
||||
return &user, nil
|
||||
}
|
||||
|
||||
// ListUsers gets all the existing users.
|
||||
func (hsdb *HSDatabase) ListUsers() ([]types.User, error) {
|
||||
hsdb.mu.RLock()
|
||||
defer hsdb.mu.RUnlock()
|
||||
|
||||
return hsdb.listUsers()
|
||||
return Read(hsdb.DB, func(rx *gorm.DB) ([]types.User, error) {
|
||||
return ListUsers(rx)
|
||||
})
|
||||
}
|
||||
|
||||
func (hsdb *HSDatabase) listUsers() ([]types.User, error) {
|
||||
// ListUsers gets all the existing users.
|
||||
func ListUsers(tx *gorm.DB) ([]types.User, error) {
|
||||
users := []types.User{}
|
||||
if err := hsdb.db.Find(&users).Error; err != nil {
|
||||
if err := tx.Find(&users).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -149,46 +155,42 @@ func (hsdb *HSDatabase) listUsers() ([]types.User, error) {
|
||||
}
|
||||
|
||||
// ListNodesByUser gets all the nodes in a given user.
|
||||
func (hsdb *HSDatabase) ListNodesByUser(name string) (types.Nodes, error) {
|
||||
hsdb.mu.RLock()
|
||||
defer hsdb.mu.RUnlock()
|
||||
|
||||
return hsdb.listNodesByUser(name)
|
||||
}
|
||||
|
||||
func (hsdb *HSDatabase) listNodesByUser(name string) (types.Nodes, error) {
|
||||
func ListNodesByUser(tx *gorm.DB, name string) (types.Nodes, error) {
|
||||
err := util.CheckForFQDNRules(name)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
user, err := hsdb.getUser(name)
|
||||
user, err := GetUser(tx, name)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
nodes := types.Nodes{}
|
||||
if err := hsdb.db.Preload("AuthKey").Preload("AuthKey.User").Preload("User").Where(&types.Node{UserID: user.ID}).Find(&nodes).Error; err != nil {
|
||||
if err := tx.Preload("AuthKey").Preload("AuthKey.User").Preload("User").Where(&types.Node{UserID: user.ID}).Find(&nodes).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return nodes, nil
|
||||
}
|
||||
|
||||
// AssignNodeToUser assigns a Node to a user.
|
||||
func (hsdb *HSDatabase) AssignNodeToUser(node *types.Node, username string) error {
|
||||
hsdb.mu.Lock()
|
||||
defer hsdb.mu.Unlock()
|
||||
return hsdb.Write(func(tx *gorm.DB) error {
|
||||
return AssignNodeToUser(tx, node, username)
|
||||
})
|
||||
}
|
||||
|
||||
// AssignNodeToUser assigns a Node to a user.
|
||||
func AssignNodeToUser(tx *gorm.DB, node *types.Node, username string) error {
|
||||
err := util.CheckForFQDNRules(username)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
user, err := hsdb.getUser(username)
|
||||
user, err := GetUser(tx, username)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
node.User = *user
|
||||
if result := hsdb.db.Save(&node); result.Error != nil {
|
||||
if result := tx.Save(&node); result.Error != nil {
|
||||
return result.Error
|
||||
}
|
||||
|
||||
|
Reference in New Issue
Block a user