From eff529f2c566a9b4ffe141af7fbbfee0afa40d7e Mon Sep 17 00:00:00 2001 From: Kristoffer Dalby Date: Mon, 17 Jul 2023 13:35:05 +0200 Subject: [PATCH] introduce rw lock for db, ish... Signed-off-by: Kristoffer Dalby --- hscontrol/auth.go | 7 +- hscontrol/db/api_key.go | 21 +++ hscontrol/db/db.go | 2 + hscontrol/db/machine.go | 263 +++++++++++++++++++++++------------ hscontrol/db/machine_test.go | 37 +---- hscontrol/db/preauth_keys.go | 37 ++++- hscontrol/db/routes.go | 89 ++++++++++-- hscontrol/db/users.go | 53 +++++-- hscontrol/db/users_test.go | 6 +- hscontrol/grpcv1.go | 7 +- hscontrol/oidc.go | 2 +- hscontrol/poll.go | 1 + 12 files changed, 369 insertions(+), 156 deletions(-) diff --git a/hscontrol/auth.go b/hscontrol/auth.go index 43dfd2b0..78db626b 100644 --- a/hscontrol/auth.go +++ b/hscontrol/auth.go @@ -309,7 +309,7 @@ func (h *Headscale) handleAuthKey( machine.NodeKey = nodeKey machine.AuthKeyID = uint(pak.ID) - err := h.db.RefreshMachine(machine, registerRequest.Expiry) + err := h.db.MachineSetExpiry(machine, registerRequest.Expiry) if err != nil { log.Error(). Caller(). @@ -510,7 +510,8 @@ func (h *Headscale) handleMachineLogOut( Str("machine", machine.Hostname). Msg("Client requested logout") - err := h.db.ExpireMachine(&machine) + now := time.Now() + err := h.db.MachineSetExpiry(&machine, now) if err != nil { log.Error(). Caller(). @@ -552,7 +553,7 @@ func (h *Headscale) handleMachineLogOut( } if machine.IsEphemeral() { - err = h.db.HardDeleteMachine(&machine) + err = h.db.DeleteMachine(&machine) if err != nil { log.Error(). Err(err). diff --git a/hscontrol/db/api_key.go b/hscontrol/db/api_key.go index 4e4030eb..bc8dc2bb 100644 --- a/hscontrol/db/api_key.go +++ b/hscontrol/db/api_key.go @@ -22,6 +22,9 @@ var ErrAPIKeyFailedToParse = errors.New("failed to parse ApiKey") func (hsdb *HSDatabase) CreateAPIKey( expiration *time.Time, ) (string, *types.APIKey, error) { + hsdb.mu.Lock() + defer hsdb.mu.Unlock() + prefix, err := util.GenerateRandomStringURLSafe(apiPrefixLength) if err != nil { return "", nil, err @@ -55,6 +58,9 @@ func (hsdb *HSDatabase) CreateAPIKey( // ListAPIKeys returns the list of ApiKeys for a user. func (hsdb *HSDatabase) ListAPIKeys() ([]types.APIKey, error) { + hsdb.mu.RLock() + defer hsdb.mu.RUnlock() + keys := []types.APIKey{} if err := hsdb.db.Find(&keys).Error; err != nil { return nil, err @@ -65,6 +71,9 @@ func (hsdb *HSDatabase) ListAPIKeys() ([]types.APIKey, error) { // GetAPIKey returns a ApiKey for a given key. func (hsdb *HSDatabase) GetAPIKey(prefix string) (*types.APIKey, error) { + hsdb.mu.RLock() + defer hsdb.mu.RUnlock() + key := types.APIKey{} if result := hsdb.db.First(&key, "prefix = ?", prefix); result.Error != nil { return nil, result.Error @@ -75,6 +84,9 @@ func (hsdb *HSDatabase) GetAPIKey(prefix string) (*types.APIKey, error) { // GetAPIKeyByID returns a ApiKey for a given id. func (hsdb *HSDatabase) GetAPIKeyByID(id uint64) (*types.APIKey, error) { + hsdb.mu.RLock() + defer hsdb.mu.RUnlock() + key := types.APIKey{} if result := hsdb.db.Find(&types.APIKey{ID: id}).First(&key); result.Error != nil { return nil, result.Error @@ -86,6 +98,9 @@ func (hsdb *HSDatabase) GetAPIKeyByID(id uint64) (*types.APIKey, error) { // DestroyAPIKey destroys a ApiKey. Returns error if the ApiKey // does not exist. func (hsdb *HSDatabase) DestroyAPIKey(key types.APIKey) error { + hsdb.mu.Lock() + defer hsdb.mu.Unlock() + if result := hsdb.db.Unscoped().Delete(key); result.Error != nil { return result.Error } @@ -95,6 +110,9 @@ func (hsdb *HSDatabase) DestroyAPIKey(key types.APIKey) error { // ExpireAPIKey marks a ApiKey as expired. func (hsdb *HSDatabase) ExpireAPIKey(key *types.APIKey) error { + hsdb.mu.Lock() + defer hsdb.mu.Unlock() + if err := hsdb.db.Model(&key).Update("Expiration", time.Now()).Error; err != nil { return err } @@ -103,6 +121,9 @@ func (hsdb *HSDatabase) ExpireAPIKey(key *types.APIKey) error { } func (hsdb *HSDatabase) ValidateAPIKey(keyStr string) (bool, error) { + hsdb.mu.RLock() + defer hsdb.mu.RUnlock() + prefix, hash, found := strings.Cut(keyStr, ".") if !found { return false, ErrAPIKeyFailedToParse diff --git a/hscontrol/db/db.go b/hscontrol/db/db.go index ea6ce21f..19bf9425 100644 --- a/hscontrol/db/db.go +++ b/hscontrol/db/db.go @@ -40,6 +40,8 @@ type HSDatabase struct { db *gorm.DB notifier *notifier.Notifier + mu sync.RWMutex + ipAllocationMutex sync.Mutex ipPrefixes []netip.Prefix diff --git a/hscontrol/db/machine.go b/hscontrol/db/machine.go index 47dfaa12..033492e1 100644 --- a/hscontrol/db/machine.go +++ b/hscontrol/db/machine.go @@ -36,6 +36,13 @@ var ( // ListPeers returns all peers of machine, regardless of any Policy or if the node is expired. func (hsdb *HSDatabase) ListPeers(machine *types.Machine) (types.Machines, error) { + hsdb.mu.RLock() + defer hsdb.mu.RUnlock() + + return hsdb.listPeers(machine) +} + +func (hsdb *HSDatabase) listPeers(machine *types.Machine) (types.Machines, error) { log.Trace(). Caller(). Str("machine", machine.Hostname). @@ -63,6 +70,13 @@ func (hsdb *HSDatabase) ListPeers(machine *types.Machine) (types.Machines, error } func (hsdb *HSDatabase) ListMachines() ([]types.Machine, error) { + hsdb.mu.RLock() + defer hsdb.mu.RUnlock() + + return hsdb.listMachines() +} + +func (hsdb *HSDatabase) listMachines() ([]types.Machine, error) { machines := []types.Machine{} if err := hsdb.db. Preload("AuthKey"). @@ -77,6 +91,13 @@ func (hsdb *HSDatabase) ListMachines() ([]types.Machine, error) { } func (hsdb *HSDatabase) ListMachinesByGivenName(givenName string) (types.Machines, error) { + hsdb.mu.RLock() + defer hsdb.mu.RUnlock() + + return hsdb.listMachinesByGivenName(givenName) +} + +func (hsdb *HSDatabase) listMachinesByGivenName(givenName string) (types.Machines, error) { machines := types.Machines{} if err := hsdb.db. Preload("AuthKey"). @@ -92,6 +113,9 @@ func (hsdb *HSDatabase) ListMachinesByGivenName(givenName string) (types.Machine // GetMachine finds a Machine by name and user and returns the Machine struct. func (hsdb *HSDatabase) GetMachine(user string, name string) (*types.Machine, error) { + hsdb.mu.RLock() + defer hsdb.mu.RUnlock() + machines, err := hsdb.ListMachinesByUser(user) if err != nil { return nil, err @@ -111,15 +135,17 @@ func (hsdb *HSDatabase) GetMachineByGivenName( user string, givenName string, ) (*types.Machine, error) { - machines, err := hsdb.ListMachinesByUser(user) - if err != nil { - return nil, err - } + hsdb.mu.RLock() + defer hsdb.mu.RUnlock() - for _, m := range machines { - if m.GivenName == givenName { - return &m, nil - } + machine := types.Machine{} + if err := hsdb.db. + Preload("AuthKey"). + Preload("AuthKey.User"). + Preload("User"). + Preload("Routes"). + Where("given_name = ?", givenName).First(&machine).Error; err != nil { + return nil, err } return nil, ErrMachineNotFound @@ -127,6 +153,9 @@ func (hsdb *HSDatabase) GetMachineByGivenName( // GetMachineByID finds a Machine by ID and returns the Machine struct. func (hsdb *HSDatabase) GetMachineByID(id uint64) (*types.Machine, error) { + hsdb.mu.RLock() + defer hsdb.mu.RUnlock() + mach := types.Machine{} if result := hsdb.db. Preload("AuthKey"). @@ -144,6 +173,9 @@ func (hsdb *HSDatabase) GetMachineByID(id uint64) (*types.Machine, error) { func (hsdb *HSDatabase) GetMachineByMachineKey( machineKey key.MachinePublic, ) (*types.Machine, error) { + hsdb.mu.RLock() + defer hsdb.mu.RUnlock() + mach := types.Machine{} if result := hsdb.db. Preload("AuthKey"). @@ -161,6 +193,9 @@ func (hsdb *HSDatabase) GetMachineByMachineKey( func (hsdb *HSDatabase) GetMachineByNodeKey( nodeKey key.NodePublic, ) (*types.Machine, error) { + hsdb.mu.RLock() + defer hsdb.mu.RUnlock() + machine := types.Machine{} if result := hsdb.db. Preload("AuthKey"). @@ -179,6 +214,9 @@ func (hsdb *HSDatabase) GetMachineByNodeKey( func (hsdb *HSDatabase) GetMachineByAnyKey( machineKey key.MachinePublic, nodeKey key.NodePublic, oldNodeKey key.NodePublic, ) (*types.Machine, error) { + hsdb.mu.RLock() + defer hsdb.mu.RUnlock() + machine := types.Machine{} if result := hsdb.db. Preload("AuthKey"). @@ -195,10 +233,10 @@ func (hsdb *HSDatabase) GetMachineByAnyKey( return &machine, nil } -// TODO(kradalby): rename this, it sounds like a mix of getting and setting to db -// UpdateMachineFromDatabase takes a Machine struct pointer (typically already loaded from database -// and updates it with the latest data from the database. -func (hsdb *HSDatabase) UpdateMachineFromDatabase(machine *types.Machine) error { +func (hsdb *HSDatabase) MachineReloadFromDatabase(machine *types.Machine) error { + hsdb.mu.RLock() + defer hsdb.mu.RUnlock() + if result := hsdb.db.Find(machine).First(&machine); result.Error != nil { return result.Error } @@ -211,46 +249,36 @@ func (hsdb *HSDatabase) SetTags( machine *types.Machine, tags []string, ) error { + hsdb.mu.Lock() + defer hsdb.mu.Unlock() + newTags := []string{} for _, tag := range tags { if !util.StringOrPrefixListContains(newTags, tag) { newTags = append(newTags, tag) } } - machine.ForcedTags = newTags - hsdb.notifier.NotifyWithIgnore(types.StateUpdate{ - Type: types.StatePeerChanged, - Changed: []uint64{machine.ID}, - }, machine.MachineKey) - - if err := hsdb.db.Save(machine).Error; err != nil { + if err := hsdb.db.Model(machine).Updates(types.Machine{ + ForcedTags: newTags, + }).Error; err != nil { return fmt.Errorf("failed to update tags for machine in the database: %w", err) } - return nil -} - -// ExpireMachine takes a Machine struct and sets the expire field to now. -func (hsdb *HSDatabase) ExpireMachine(machine *types.Machine) error { - now := time.Now() - machine.Expiry = &now - hsdb.notifier.NotifyWithIgnore(types.StateUpdate{ Type: types.StatePeerChanged, Changed: []uint64{machine.ID}, }, machine.MachineKey) - if err := hsdb.db.Save(machine).Error; err != nil { - return fmt.Errorf("failed to expire machine in the database: %w", err) - } - return nil } // RenameMachine takes a Machine struct and a new GivenName for the machines // and renames it. func (hsdb *HSDatabase) RenameMachine(machine *types.Machine, newName string) error { + hsdb.mu.Lock() + defer hsdb.mu.Unlock() + err := util.CheckForFQDNRules( newName, ) @@ -260,82 +288,93 @@ func (hsdb *HSDatabase) RenameMachine(machine *types.Machine, newName string) er Str("func", "RenameMachine"). Str("machine", machine.Hostname). Str("newName", newName). - Err(err) + Err(err). + Msg("failed to rename machine") return err } machine.GivenName = newName + if err := hsdb.db.Model(machine).Updates(types.Machine{ + GivenName: newName, + }).Error; err != nil { + return fmt.Errorf("failed to rename machine in the database: %w", err) + } + hsdb.notifier.NotifyWithIgnore(types.StateUpdate{ Type: types.StatePeerChanged, Changed: []uint64{machine.ID}, }, machine.MachineKey) - if err := hsdb.db.Save(machine).Error; err != nil { - return fmt.Errorf("failed to rename machine in the database: %w", err) - } - return nil } -// RefreshMachine takes a Machine struct and a new expiry time. -func (hsdb *HSDatabase) RefreshMachine(machine *types.Machine, expiry time.Time) error { +// MachineSetExpiry takes a Machine struct and a new expiry time. +func (hsdb *HSDatabase) MachineSetExpiry(machine *types.Machine, expiry time.Time) error { + hsdb.mu.Lock() + defer hsdb.mu.Unlock() + + return hsdb.machineSetExpiry(machine, expiry) +} + +func (hsdb *HSDatabase) machineSetExpiry(machine *types.Machine, expiry time.Time) error { now := time.Now() - machine.LastSuccessfulUpdate = &now - machine.Expiry = &expiry - - hsdb.notifier.NotifyWithIgnore(types.StateUpdate{ - Type: types.StatePeerChanged, - Changed: []uint64{machine.ID}, - }, machine.MachineKey) - - if err := hsdb.db.Save(machine).Error; err != nil { + if err := hsdb.db.Model(machine).Updates(types.Machine{ + LastSuccessfulUpdate: &now, + Expiry: &expiry, + }).Error; err != nil { return fmt.Errorf( "failed to refresh machine (update expiration) in the database: %w", err, ) } + hsdb.notifier.NotifyWithIgnore(types.StateUpdate{ + Type: types.StatePeerChanged, + Changed: []uint64{machine.ID}, + }, machine.MachineKey) + return nil } -// DeleteMachine softs deletes a Machine from the database. +// DeleteMachine deletes a Machine from the database. func (hsdb *HSDatabase) DeleteMachine(machine *types.Machine) error { - err := hsdb.DeleteMachineRoutes(machine) + hsdb.mu.Lock() + defer hsdb.mu.Unlock() + + return hsdb.deleteMachine(machine) +} + +func (hsdb *HSDatabase) deleteMachine(machine *types.Machine) error { + err := hsdb.deleteMachineRoutes(machine) if err != nil { return err } - if err := hsdb.db.Delete(&machine).Error; err != nil { + // Unscoped causes the machine to be fully removed from the database. + if err := hsdb.db.Unscoped().Delete(&machine).Error; err != nil { return err } + hsdb.notifier.NotifyAll(types.StateUpdate{ + Type: types.StatePeerRemoved, + Removed: []tailcfg.NodeID{tailcfg.NodeID(machine.ID)}, + }) + return nil } func (hsdb *HSDatabase) TouchMachine(machine *types.Machine) error { - return hsdb.db.Updates(types.Machine{ - ID: machine.ID, + hsdb.mu.Lock() + defer hsdb.mu.Unlock() + + return hsdb.db.Model(machine).Updates(types.Machine{ LastSeen: machine.LastSeen, LastSuccessfulUpdate: machine.LastSuccessfulUpdate, }).Error } -// HardDeleteMachine hard deletes a Machine from the database. -func (hsdb *HSDatabase) HardDeleteMachine(machine *types.Machine) error { - err := hsdb.DeleteMachineRoutes(machine) - if err != nil { - return err - } - - if err := hsdb.db.Unscoped().Delete(&machine).Error; err != nil { - return err - } - - return nil -} - func (hsdb *HSDatabase) RegisterMachineFromAuthCallback( cache *cache.Cache, nodeKeyStr string, @@ -343,6 +382,9 @@ func (hsdb *HSDatabase) RegisterMachineFromAuthCallback( machineExpiry *time.Time, registrationMethod string, ) (*types.Machine, error) { + hsdb.mu.Lock() + defer hsdb.mu.Unlock() + nodeKey := key.NodePublic{} err := nodeKey.UnmarshalText([]byte(nodeKeyStr)) if err != nil { @@ -358,7 +400,7 @@ func (hsdb *HSDatabase) RegisterMachineFromAuthCallback( if machineInterface, ok := cache.Get(util.NodePublicKeyStripPrefix(nodeKey)); ok { if registrationMachine, ok := machineInterface.(types.Machine); ok { - user, err := hsdb.GetUser(userName) + user, err := hsdb.getUser(userName) if err != nil { return nil, fmt.Errorf( "failed to find user in register machine from auth callback, %w", @@ -379,7 +421,7 @@ func (hsdb *HSDatabase) RegisterMachineFromAuthCallback( registrationMachine.Expiry = machineExpiry } - machine, err := hsdb.RegisterMachine( + machine, err := hsdb.registerMachine( registrationMachine, ) @@ -397,8 +439,14 @@ func (hsdb *HSDatabase) RegisterMachineFromAuthCallback( } // RegisterMachine is executed from the CLI to register a new Machine using its MachineKey. -func (hsdb *HSDatabase) RegisterMachine(machine types.Machine, -) (*types.Machine, error) { +func (hsdb *HSDatabase) RegisterMachine(machine types.Machine) (*types.Machine, error) { + hsdb.mu.Lock() + defer hsdb.mu.Unlock() + + return hsdb.registerMachine(machine) +} + +func (hsdb *HSDatabase) registerMachine(machine types.Machine) (*types.Machine, error) { log.Debug(). Str("machine", machine.Hostname). Str("machine_key", machine.MachineKey). @@ -456,9 +504,12 @@ func (hsdb *HSDatabase) RegisterMachine(machine types.Machine, // MachineSetNodeKey sets the node key of a machine and saves it to the database. func (hsdb *HSDatabase) MachineSetNodeKey(machine *types.Machine, nodeKey key.NodePublic) error { - machine.NodeKey = util.NodePublicKeyStripPrefix(nodeKey) + hsdb.mu.Lock() + defer hsdb.mu.Unlock() - if err := hsdb.db.Save(machine).Error; err != nil { + if err := hsdb.db.Model(machine).Updates(types.Machine{ + NodeKey: util.NodePublicKeyStripPrefix(nodeKey), + }).Error; err != nil { return err } @@ -468,11 +519,14 @@ func (hsdb *HSDatabase) MachineSetNodeKey(machine *types.Machine, nodeKey key.No // MachineSetMachineKey sets the machine key of a machine and saves it to the database. func (hsdb *HSDatabase) MachineSetMachineKey( machine *types.Machine, - nodeKey key.MachinePublic, + machineKey key.MachinePublic, ) error { - machine.MachineKey = util.MachinePublicKeyStripPrefix(nodeKey) + hsdb.mu.Lock() + defer hsdb.mu.Unlock() - if err := hsdb.db.Save(machine).Error; err != nil { + if err := hsdb.db.Model(machine).Updates(types.Machine{ + MachineKey: util.MachinePublicKeyStripPrefix(machineKey), + }).Error; err != nil { return err } @@ -482,6 +536,9 @@ func (hsdb *HSDatabase) MachineSetMachineKey( // MachineSave saves a machine object to the database, prefer to use a specific save method rather // than this. It is intended to be used when we are changing or. func (hsdb *HSDatabase) MachineSave(machine *types.Machine) error { + hsdb.mu.Lock() + defer hsdb.mu.Unlock() + if err := hsdb.db.Save(machine).Error; err != nil { return err } @@ -491,6 +548,13 @@ func (hsdb *HSDatabase) MachineSave(machine *types.Machine) error { // GetAdvertisedRoutes returns the routes that are be advertised by the given machine. func (hsdb *HSDatabase) GetAdvertisedRoutes(machine *types.Machine) ([]netip.Prefix, error) { + hsdb.mu.RLock() + defer hsdb.mu.RUnlock() + + return hsdb.getAdvertisedRoutes(machine) +} + +func (hsdb *HSDatabase) getAdvertisedRoutes(machine *types.Machine) ([]netip.Prefix, error) { routes := types.Routes{} err := hsdb.db. @@ -516,6 +580,13 @@ func (hsdb *HSDatabase) GetAdvertisedRoutes(machine *types.Machine) ([]netip.Pre // GetEnabledRoutes returns the routes that are enabled for the machine. func (hsdb *HSDatabase) GetEnabledRoutes(machine *types.Machine) ([]netip.Prefix, error) { + hsdb.mu.RLock() + defer hsdb.mu.RUnlock() + + return hsdb.getEnabledRoutes(machine) +} + +func (hsdb *HSDatabase) getEnabledRoutes(machine *types.Machine) ([]netip.Prefix, error) { routes := types.Routes{} err := hsdb.db. @@ -541,12 +612,15 @@ func (hsdb *HSDatabase) GetEnabledRoutes(machine *types.Machine) ([]netip.Prefix } func (hsdb *HSDatabase) IsRoutesEnabled(machine *types.Machine, routeStr string) bool { + hsdb.mu.RLock() + defer hsdb.mu.RUnlock() + route, err := netip.ParsePrefix(routeStr) if err != nil { return false } - enabledRoutes, err := hsdb.GetEnabledRoutes(machine) + enabledRoutes, err := hsdb.getEnabledRoutes(machine) if err != nil { log.Error().Err(err).Msg("Could not get enabled routes") @@ -575,7 +649,10 @@ func OnlineMachineMap(peers types.Machines) map[tailcfg.NodeID]bool { func (hsdb *HSDatabase) ListOnlineMachines( machine *types.Machine, ) (map[tailcfg.NodeID]bool, error) { - peers, err := hsdb.ListPeers(machine) + hsdb.mu.RLock() + defer hsdb.mu.RUnlock() + + peers, err := hsdb.listPeers(machine) if err != nil { return nil, err } @@ -595,7 +672,7 @@ func (hsdb *HSDatabase) enableRoutes(machine *types.Machine, routeStrs ...string newRoutes[index] = route } - advertisedRoutes, err := hsdb.GetAdvertisedRoutes(machine) + advertisedRoutes, err := hsdb.getAdvertisedRoutes(machine) if err != nil { return err } @@ -642,7 +719,7 @@ func (hsdb *HSDatabase) enableRoutes(machine *types.Machine, routeStrs ...string return nil } -func (hsdb *HSDatabase) generateGivenName(suppliedName string, randomSuffix bool) (string, error) { +func generateGivenName(suppliedName string, randomSuffix bool) (string, error) { normalizedHostname, err := util.NormalizeToFQDNRulesConfigFromViper( suppliedName, ) @@ -669,20 +746,23 @@ func (hsdb *HSDatabase) generateGivenName(suppliedName string, randomSuffix bool } func (hsdb *HSDatabase) GenerateGivenName(machineKey string, suppliedName string) (string, error) { - givenName, err := hsdb.generateGivenName(suppliedName, false) + hsdb.mu.RLock() + defer hsdb.mu.RUnlock() + + givenName, err := generateGivenName(suppliedName, false) if err != nil { return "", err } // Tailscale rules (may differ) https://tailscale.com/kb/1098/machine-names/ - machines, err := hsdb.ListMachinesByGivenName(givenName) + machines, err := hsdb.listMachinesByGivenName(givenName) if err != nil { return "", err } for _, machine := range machines { if machine.MachineKey != machineKey && machine.GivenName == givenName { - postfixedName, err := hsdb.generateGivenName(suppliedName, true) + postfixedName, err := generateGivenName(suppliedName, true) if err != nil { return "", err } @@ -695,7 +775,10 @@ func (hsdb *HSDatabase) GenerateGivenName(machineKey string, suppliedName string } func (hsdb *HSDatabase) ExpireEphemeralMachines(inactivityThreshhold time.Duration) { - users, err := hsdb.ListUsers() + hsdb.mu.Lock() + defer hsdb.mu.Unlock() + + users, err := hsdb.listUsers() if err != nil { log.Error().Err(err).Msg("Error listing users") @@ -703,7 +786,7 @@ func (hsdb *HSDatabase) ExpireEphemeralMachines(inactivityThreshhold time.Durati } for _, user := range users { - machines, err := hsdb.ListMachinesByUser(user.Name) + machines, err := hsdb.listMachinesByUser(user.Name) if err != nil { log.Error(). Err(err). @@ -724,7 +807,7 @@ func (hsdb *HSDatabase) ExpireEphemeralMachines(inactivityThreshhold time.Durati Str("machine", machine.Hostname). Msg("Ephemeral client removed from database") - err = hsdb.HardDeleteMachine(&machines[idx]) + err = hsdb.deleteMachine(&machines[idx]) if err != nil { log.Error(). Err(err). @@ -744,12 +827,15 @@ func (hsdb *HSDatabase) ExpireEphemeralMachines(inactivityThreshhold time.Durati } func (hsdb *HSDatabase) ExpireExpiredMachines(lastCheck time.Time) time.Time { + hsdb.mu.Lock() + defer hsdb.mu.Unlock() + // use the time of the start of the function to ensure we // dont miss some machines by returning it _after_ we have // checked everything. started := time.Now() - users, err := hsdb.ListUsers() + users, err := hsdb.listUsers() if err != nil { log.Error().Err(err).Msg("Error listing users") @@ -757,7 +843,7 @@ func (hsdb *HSDatabase) ExpireExpiredMachines(lastCheck time.Time) time.Time { } for _, user := range users { - machines, err := hsdb.ListMachinesByUser(user.Name) + machines, err := hsdb.listMachinesByUser(user.Name) if err != nil { log.Error(). Err(err). @@ -773,7 +859,8 @@ func (hsdb *HSDatabase) ExpireExpiredMachines(lastCheck time.Time) time.Time { machine.Expiry.After(lastCheck) { expired = append(expired, tailcfg.NodeID(machine.ID)) - err := hsdb.ExpireMachine(&machines[index]) + now := time.Now() + err := hsdb.machineSetExpiry(&machines[index], now) if err != nil { log.Error(). Err(err). diff --git a/hscontrol/db/machine_test.go b/hscontrol/db/machine_test.go index 0220bb81..7f837e06 100644 --- a/hscontrol/db/machine_test.go +++ b/hscontrol/db/machine_test.go @@ -127,28 +127,6 @@ func (s *Suite) TestGetMachineByAnyNodeKey(c *check.C) { c.Assert(err, check.IsNil) } -func (s *Suite) TestDeleteMachine(c *check.C) { - user, err := db.CreateUser("test") - c.Assert(err, check.IsNil) - machine := types.Machine{ - ID: 0, - MachineKey: "foo", - NodeKey: "bar", - DiscoKey: "faa", - Hostname: "testmachine", - UserID: user.ID, - RegisterMethod: util.RegisterMethodAuthKey, - AuthKeyID: uint(1), - } - db.db.Save(&machine) - - err = db.DeleteMachine(&machine) - c.Assert(err, check.IsNil) - - _, err = db.GetMachine(user.Name, "testmachine") - c.Assert(err, check.NotNil) -} - func (s *Suite) TestHardDeleteMachine(c *check.C) { user, err := db.CreateUser("test") c.Assert(err, check.IsNil) @@ -164,7 +142,7 @@ func (s *Suite) TestHardDeleteMachine(c *check.C) { } db.db.Save(&machine) - err = db.HardDeleteMachine(&machine) + err = db.DeleteMachine(&machine) c.Assert(err, check.IsNil) _, err = db.GetMachine(user.Name, "testmachine3") @@ -329,7 +307,8 @@ func (s *Suite) TestExpireMachine(c *check.C) { c.Assert(machineFromDB.IsExpired(), check.Equals, false) - err = db.ExpireMachine(machineFromDB) + now := time.Now() + err = db.MachineSetExpiry(machineFromDB, now) c.Assert(err, check.IsNil) c.Assert(machineFromDB.IsExpired(), check.Equals, true) @@ -450,14 +429,12 @@ func TestHeadscale_generateGivenName(t *testing.T) { } tests := []struct { name string - db *HSDatabase args args want *regexp.Regexp wantErr bool }{ { name: "simple machine name generation", - db: &HSDatabase{}, args: args{ suppliedName: "testmachine", randomSuffix: false, @@ -467,7 +444,6 @@ func TestHeadscale_generateGivenName(t *testing.T) { }, { name: "machine name with 53 chars", - db: &HSDatabase{}, args: args{ suppliedName: "testmaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaachine", randomSuffix: false, @@ -477,7 +453,6 @@ func TestHeadscale_generateGivenName(t *testing.T) { }, { name: "machine name with 63 chars", - db: &HSDatabase{}, args: args{ suppliedName: "machineeee12345678901234567890123456789012345678901234567890123", randomSuffix: false, @@ -487,7 +462,6 @@ func TestHeadscale_generateGivenName(t *testing.T) { }, { name: "machine name with 64 chars", - db: &HSDatabase{}, args: args{ suppliedName: "machineeee123456789012345678901234567890123456789012345678901234", randomSuffix: false, @@ -497,7 +471,6 @@ func TestHeadscale_generateGivenName(t *testing.T) { }, { name: "machine name with 73 chars", - db: &HSDatabase{}, args: args{ suppliedName: "machineeee123456789012345678901234567890123456789012345678901234567890123", randomSuffix: false, @@ -507,7 +480,6 @@ func TestHeadscale_generateGivenName(t *testing.T) { }, { name: "machine name with random suffix", - db: &HSDatabase{}, args: args{ suppliedName: "test", randomSuffix: true, @@ -517,7 +489,6 @@ func TestHeadscale_generateGivenName(t *testing.T) { }, { name: "machine name with 63 chars with random suffix", - db: &HSDatabase{}, args: args{ suppliedName: "machineeee12345678901234567890123456789012345678901234567890123", randomSuffix: true, @@ -528,7 +499,7 @@ func TestHeadscale_generateGivenName(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - got, err := tt.db.generateGivenName(tt.args.suppliedName, tt.args.randomSuffix) + got, err := generateGivenName(tt.args.suppliedName, tt.args.randomSuffix) if (err != nil) != tt.wantErr { t.Errorf( "Headscale.GenerateGivenName() error = %v, wantErr %v", diff --git a/hscontrol/db/preauth_keys.go b/hscontrol/db/preauth_keys.go index abb79c34..ec7ab232 100644 --- a/hscontrol/db/preauth_keys.go +++ b/hscontrol/db/preauth_keys.go @@ -28,6 +28,10 @@ func (hsdb *HSDatabase) CreatePreAuthKey( expiration *time.Time, aclTags []string, ) (*types.PreAuthKey, error) { + // TODO(kradalby): figure out this lock + // hsdb.mu.Lock() + // defer hsdb.mu.Unlock() + user, err := hsdb.GetUser(userName) if err != nil { return nil, err @@ -92,7 +96,14 @@ func (hsdb *HSDatabase) CreatePreAuthKey( // ListPreAuthKeys returns the list of PreAuthKeys for a user. func (hsdb *HSDatabase) ListPreAuthKeys(userName string) ([]types.PreAuthKey, error) { - user, err := hsdb.GetUser(userName) + hsdb.mu.RLock() + defer hsdb.mu.RUnlock() + + return hsdb.listPreAuthKeys(userName) +} + +func (hsdb *HSDatabase) listPreAuthKeys(userName string) ([]types.PreAuthKey, error) { + user, err := hsdb.getUser(userName) if err != nil { return nil, err } @@ -107,6 +118,9 @@ func (hsdb *HSDatabase) ListPreAuthKeys(userName string) ([]types.PreAuthKey, er // GetPreAuthKey returns a PreAuthKey for a given key. func (hsdb *HSDatabase) GetPreAuthKey(user string, key string) (*types.PreAuthKey, error) { + hsdb.mu.RLock() + defer hsdb.mu.RUnlock() + pak, err := hsdb.ValidatePreAuthKey(key) if err != nil { return nil, err @@ -122,6 +136,13 @@ func (hsdb *HSDatabase) GetPreAuthKey(user string, key string) (*types.PreAuthKe // DestroyPreAuthKey destroys a preauthkey. Returns error if the PreAuthKey // does not exist. func (hsdb *HSDatabase) DestroyPreAuthKey(pak types.PreAuthKey) error { + hsdb.mu.Lock() + defer hsdb.mu.Unlock() + + return hsdb.destroyPreAuthKey(pak) +} + +func (hsdb *HSDatabase) destroyPreAuthKey(pak types.PreAuthKey) error { return hsdb.db.Transaction(func(db *gorm.DB) error { if result := db.Unscoped().Where(types.PreAuthKeyACLTag{PreAuthKeyID: pak.ID}).Delete(&types.PreAuthKeyACLTag{}); result.Error != nil { return result.Error @@ -137,6 +158,9 @@ func (hsdb *HSDatabase) DestroyPreAuthKey(pak types.PreAuthKey) error { // MarkExpirePreAuthKey marks a PreAuthKey as expired. func (hsdb *HSDatabase) ExpirePreAuthKey(k *types.PreAuthKey) error { + hsdb.mu.Lock() + defer hsdb.mu.Unlock() + if err := hsdb.db.Model(&k).Update("Expiration", time.Now()).Error; err != nil { return err } @@ -146,6 +170,9 @@ func (hsdb *HSDatabase) ExpirePreAuthKey(k *types.PreAuthKey) error { // UsePreAuthKey marks a PreAuthKey as used. func (hsdb *HSDatabase) UsePreAuthKey(k *types.PreAuthKey) error { + hsdb.mu.Lock() + defer hsdb.mu.Unlock() + k.Used = true if err := hsdb.db.Save(k).Error; err != nil { return fmt.Errorf("failed to update key used status in the database: %w", err) @@ -157,6 +184,9 @@ func (hsdb *HSDatabase) UsePreAuthKey(k *types.PreAuthKey) error { // ValidatePreAuthKey does the heavy lifting for validation of the PreAuthKey coming from a node // If returns no error and a PreAuthKey, it can be used. func (hsdb *HSDatabase) ValidatePreAuthKey(k string) (*types.PreAuthKey, error) { + hsdb.mu.RLock() + defer hsdb.mu.RUnlock() + pak := types.PreAuthKey{} if result := hsdb.db.Preload("User").Preload("ACLTags").First(&pak, "key = ?", k); errors.Is( result.Error, @@ -174,7 +204,10 @@ func (hsdb *HSDatabase) ValidatePreAuthKey(k string) (*types.PreAuthKey, error) } machines := types.Machines{} - if err := hsdb.db.Preload("AuthKey").Where(&types.Machine{AuthKeyID: uint(pak.ID)}).Find(&machines).Error; err != nil { + if err := hsdb.db. + Preload("AuthKey"). + Where(&types.Machine{AuthKeyID: uint(pak.ID)}). + Find(&machines).Error; err != nil { return nil, err } diff --git a/hscontrol/db/routes.go b/hscontrol/db/routes.go index 90ec3b1d..26a08f37 100644 --- a/hscontrol/db/routes.go +++ b/hscontrol/db/routes.go @@ -13,6 +13,13 @@ import ( var ErrRouteIsNotAvailable = errors.New("route is not available") func (hsdb *HSDatabase) GetRoutes() (types.Routes, error) { + hsdb.mu.RLock() + defer hsdb.mu.RUnlock() + + return hsdb.getRoutes() +} + +func (hsdb *HSDatabase) getRoutes() (types.Routes, error) { var routes types.Routes err := hsdb.db.Preload("Machine").Find(&routes).Error if err != nil { @@ -23,6 +30,13 @@ func (hsdb *HSDatabase) GetRoutes() (types.Routes, error) { } func (hsdb *HSDatabase) GetMachineAdvertisedRoutes(machine *types.Machine) (types.Routes, error) { + hsdb.mu.RLock() + defer hsdb.mu.RUnlock() + + return hsdb.getMachineAdvertisedRoutes(machine) +} + +func (hsdb *HSDatabase) getMachineAdvertisedRoutes(machine *types.Machine) (types.Routes, error) { var routes types.Routes err := hsdb.db. Preload("Machine"). @@ -36,6 +50,13 @@ func (hsdb *HSDatabase) GetMachineAdvertisedRoutes(machine *types.Machine) (type } func (hsdb *HSDatabase) GetMachineRoutes(m *types.Machine) (types.Routes, error) { + hsdb.mu.RLock() + defer hsdb.mu.RUnlock() + + return hsdb.getMachineRoutes(m) +} + +func (hsdb *HSDatabase) getMachineRoutes(m *types.Machine) (types.Routes, error) { var routes types.Routes err := hsdb.db. Preload("Machine"). @@ -49,6 +70,13 @@ func (hsdb *HSDatabase) GetMachineRoutes(m *types.Machine) (types.Routes, error) } func (hsdb *HSDatabase) GetRoute(id uint64) (*types.Route, error) { + hsdb.mu.RLock() + defer hsdb.mu.RUnlock() + + return hsdb.getRoute(id) +} + +func (hsdb *HSDatabase) getRoute(id uint64) (*types.Route, error) { var route types.Route err := hsdb.db.Preload("Machine").First(&route, id).Error if err != nil { @@ -59,7 +87,14 @@ func (hsdb *HSDatabase) GetRoute(id uint64) (*types.Route, error) { } func (hsdb *HSDatabase) EnableRoute(id uint64) error { - route, err := hsdb.GetRoute(id) + hsdb.mu.Lock() + defer hsdb.mu.Unlock() + + return hsdb.enableRoute(id) +} + +func (hsdb *HSDatabase) enableRoute(id uint64) error { + route, err := hsdb.getRoute(id) if err != nil { return err } @@ -79,7 +114,10 @@ func (hsdb *HSDatabase) EnableRoute(id uint64) error { } func (hsdb *HSDatabase) DisableRoute(id uint64) error { - route, err := hsdb.GetRoute(id) + hsdb.mu.Lock() + defer hsdb.mu.Unlock() + + route, err := hsdb.getRoute(id) if err != nil { return err } @@ -95,10 +133,10 @@ func (hsdb *HSDatabase) DisableRoute(id uint64) error { return err } - return hsdb.HandlePrimarySubnetFailover() + return hsdb.handlePrimarySubnetFailover() } - routes, err := hsdb.GetMachineRoutes(&route.Machine) + routes, err := hsdb.getMachineRoutes(&route.Machine) if err != nil { return err } @@ -114,11 +152,14 @@ func (hsdb *HSDatabase) DisableRoute(id uint64) error { } } - return hsdb.HandlePrimarySubnetFailover() + return hsdb.handlePrimarySubnetFailover() } func (hsdb *HSDatabase) DeleteRoute(id uint64) error { - route, err := hsdb.GetRoute(id) + hsdb.mu.Lock() + defer hsdb.mu.Unlock() + + route, err := hsdb.getRoute(id) if err != nil { return err } @@ -131,10 +172,10 @@ func (hsdb *HSDatabase) DeleteRoute(id uint64) error { return err } - return hsdb.HandlePrimarySubnetFailover() + return hsdb.handlePrimarySubnetFailover() } - routes, err := hsdb.GetMachineRoutes(&route.Machine) + routes, err := hsdb.getMachineRoutes(&route.Machine) if err != nil { return err } @@ -150,11 +191,11 @@ func (hsdb *HSDatabase) DeleteRoute(id uint64) error { return err } - return hsdb.HandlePrimarySubnetFailover() + return hsdb.handlePrimarySubnetFailover() } -func (hsdb *HSDatabase) DeleteMachineRoutes(m *types.Machine) error { - routes, err := hsdb.GetMachineRoutes(m) +func (hsdb *HSDatabase) deleteMachineRoutes(m *types.Machine) error { + routes, err := hsdb.getMachineRoutes(m) if err != nil { return err } @@ -165,7 +206,7 @@ func (hsdb *HSDatabase) DeleteMachineRoutes(m *types.Machine) error { } } - return hsdb.HandlePrimarySubnetFailover() + return hsdb.handlePrimarySubnetFailover() } // isUniquePrefix returns if there is another machine providing the same route already. @@ -201,6 +242,9 @@ func (hsdb *HSDatabase) getPrimaryRoute(prefix netip.Prefix) (*types.Route, erro // getMachinePrimaryRoutes returns the routes that are enabled and marked as primary (for subnet failover) // Exit nodes are not considered for this, as they are never marked as Primary. func (hsdb *HSDatabase) GetMachinePrimaryRoutes(m *types.Machine) (types.Routes, error) { + hsdb.mu.RLock() + defer hsdb.mu.RUnlock() + var routes types.Routes err := hsdb.db. Preload("Machine"). @@ -214,6 +258,13 @@ func (hsdb *HSDatabase) GetMachinePrimaryRoutes(m *types.Machine) (types.Routes, } func (hsdb *HSDatabase) ProcessMachineRoutes(machine *types.Machine) error { + hsdb.mu.Lock() + defer hsdb.mu.Unlock() + + return hsdb.processMachineRoutes(machine) +} + +func (hsdb *HSDatabase) processMachineRoutes(machine *types.Machine) error { currentRoutes := types.Routes{} err := hsdb.db.Where("machine_id = ?", machine.ID).Find(¤tRoutes).Error if err != nil { @@ -264,6 +315,13 @@ func (hsdb *HSDatabase) ProcessMachineRoutes(machine *types.Machine) error { } func (hsdb *HSDatabase) HandlePrimarySubnetFailover() error { + hsdb.mu.Lock() + defer hsdb.mu.Unlock() + + return hsdb.handlePrimarySubnetFailover() +} + +func (hsdb *HSDatabase) handlePrimarySubnetFailover() error { // first, get all the enabled routes var routes types.Routes err := hsdb.db. @@ -388,11 +446,14 @@ func (hsdb *HSDatabase) EnableAutoApprovedRoutes( aclPolicy *policy.ACLPolicy, machine *types.Machine, ) error { + hsdb.mu.Lock() + defer hsdb.mu.Unlock() + if len(machine.IPAddresses) == 0 { return nil // This machine has no IPAddresses, so can't possibly match any autoApprovers ACLs } - routes, err := hsdb.GetMachineAdvertisedRoutes(machine) + routes, err := hsdb.getMachineAdvertisedRoutes(machine) if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) { log.Error(). Caller(). @@ -445,7 +506,7 @@ func (hsdb *HSDatabase) EnableAutoApprovedRoutes( } for _, approvedRoute := range approvedRoutes { - err := hsdb.EnableRoute(uint64(approvedRoute.ID)) + err := hsdb.enableRoute(uint64(approvedRoute.ID)) if err != nil { log.Err(err). Str("approvedRoute", approvedRoute.String()). diff --git a/hscontrol/db/users.go b/hscontrol/db/users.go index ce186751..5af4660b 100644 --- a/hscontrol/db/users.go +++ b/hscontrol/db/users.go @@ -18,6 +18,9 @@ var ( // CreateUser creates a new User. Returns error if could not be created // or another user already exists. func (hsdb *HSDatabase) CreateUser(name string) (*types.User, error) { + hsdb.mu.Lock() + defer hsdb.mu.Unlock() + err := util.CheckForFQDNRules(name) if err != nil { return nil, err @@ -42,12 +45,15 @@ func (hsdb *HSDatabase) CreateUser(name string) (*types.User, error) { // DestroyUser destroys a User. Returns error if the User does // not exist or if there are machines associated with it. func (hsdb *HSDatabase) DestroyUser(name string) error { - user, err := hsdb.GetUser(name) + hsdb.mu.Lock() + defer hsdb.mu.Unlock() + + user, err := hsdb.getUser(name) if err != nil { return ErrUserNotFound } - machines, err := hsdb.ListMachinesByUser(name) + machines, err := hsdb.listMachinesByUser(name) if err != nil { return err } @@ -55,12 +61,12 @@ func (hsdb *HSDatabase) DestroyUser(name string) error { return ErrUserStillHasNodes } - keys, err := hsdb.ListPreAuthKeys(name) + keys, err := hsdb.listPreAuthKeys(name) if err != nil { return err } for _, key := range keys { - err = hsdb.DestroyPreAuthKey(key) + err = hsdb.destroyPreAuthKey(key) if err != nil { return err } @@ -76,8 +82,11 @@ func (hsdb *HSDatabase) DestroyUser(name string) error { // RenameUser renames a User. Returns error if the User does // not exist or if another User exists with the new name. func (hsdb *HSDatabase) RenameUser(oldName, newName string) error { + hsdb.mu.Lock() + defer hsdb.mu.Unlock() + var err error - oldUser, err := hsdb.GetUser(oldName) + oldUser, err := hsdb.getUser(oldName) if err != nil { return err } @@ -85,7 +94,7 @@ func (hsdb *HSDatabase) RenameUser(oldName, newName string) error { if err != nil { return err } - _, err = hsdb.GetUser(newName) + _, err = hsdb.getUser(newName) if err == nil { return ErrUserExists } @@ -104,6 +113,13 @@ func (hsdb *HSDatabase) RenameUser(oldName, newName string) error { // GetUser fetches a user by name. func (hsdb *HSDatabase) GetUser(name string) (*types.User, error) { + hsdb.mu.RLock() + defer hsdb.mu.RUnlock() + + return hsdb.getUser(name) +} + +func (hsdb *HSDatabase) getUser(name string) (*types.User, error) { user := types.User{} if result := hsdb.db.First(&user, "name = ?", name); errors.Is( result.Error, @@ -117,6 +133,13 @@ func (hsdb *HSDatabase) GetUser(name string) (*types.User, error) { // ListUsers gets all the existing users. func (hsdb *HSDatabase) ListUsers() ([]types.User, error) { + hsdb.mu.RLock() + defer hsdb.mu.RUnlock() + + return hsdb.listUsers() +} + +func (hsdb *HSDatabase) listUsers() ([]types.User, error) { users := []types.User{} if err := hsdb.db.Find(&users).Error; err != nil { return nil, err @@ -127,11 +150,18 @@ func (hsdb *HSDatabase) ListUsers() ([]types.User, error) { // ListMachinesByUser gets all the nodes in a given user. func (hsdb *HSDatabase) ListMachinesByUser(name string) (types.Machines, error) { + hsdb.mu.RLock() + defer hsdb.mu.RUnlock() + + return hsdb.listMachinesByUser(name) +} + +func (hsdb *HSDatabase) listMachinesByUser(name string) (types.Machines, error) { err := util.CheckForFQDNRules(name) if err != nil { return nil, err } - user, err := hsdb.GetUser(name) + user, err := hsdb.getUser(name) if err != nil { return nil, err } @@ -144,13 +174,16 @@ func (hsdb *HSDatabase) ListMachinesByUser(name string) (types.Machines, error) return machines, nil } -// SetMachineUser assigns a Machine to a user. -func (hsdb *HSDatabase) SetMachineUser(machine *types.Machine, username string) error { +// AssignMachineToUser assigns a Machine to a user. +func (hsdb *HSDatabase) AssignMachineToUser(machine *types.Machine, username string) error { + hsdb.mu.Lock() + defer hsdb.mu.Unlock() + err := util.CheckForFQDNRules(username) if err != nil { return err } - user, err := hsdb.GetUser(username) + user, err := hsdb.getUser(username) if err != nil { return err } diff --git a/hscontrol/db/users_test.go b/hscontrol/db/users_test.go index bc468b23..97b3e6d7 100644 --- a/hscontrol/db/users_test.go +++ b/hscontrol/db/users_test.go @@ -114,15 +114,15 @@ func (s *Suite) TestSetMachineUser(c *check.C) { db.db.Save(&machine) c.Assert(machine.UserID, check.Equals, oldUser.ID) - err = db.SetMachineUser(&machine, newUser.Name) + err = db.AssignMachineToUser(&machine, newUser.Name) c.Assert(err, check.IsNil) c.Assert(machine.UserID, check.Equals, newUser.ID) c.Assert(machine.User.Name, check.Equals, newUser.Name) - err = db.SetMachineUser(&machine, "non-existing-user") + err = db.AssignMachineToUser(&machine, "non-existing-user") c.Assert(err, check.Equals, ErrUserNotFound) - err = db.SetMachineUser(&machine, newUser.Name) + err = db.AssignMachineToUser(&machine, newUser.Name) c.Assert(err, check.IsNil) c.Assert(machine.UserID, check.Equals, newUser.ID) c.Assert(machine.User.Name, check.Equals, newUser.Name) diff --git a/hscontrol/grpcv1.go b/hscontrol/grpcv1.go index 74950c20..292c8d84 100644 --- a/hscontrol/grpcv1.go +++ b/hscontrol/grpcv1.go @@ -275,8 +275,11 @@ func (api headscaleV1APIServer) ExpireMachine( return nil, err } - api.h.db.ExpireMachine( + now := time.Now() + + api.h.db.MachineSetExpiry( machine, + now, ) log.Trace(). @@ -358,7 +361,7 @@ func (api headscaleV1APIServer) MoveMachine( return nil, err } - err = api.h.db.SetMachineUser(machine, request.GetUser()) + err = api.h.db.AssignMachineToUser(machine, request.GetUser()) if err != nil { return nil, err } diff --git a/hscontrol/oidc.go b/hscontrol/oidc.go index 66383838..010bcb15 100644 --- a/hscontrol/oidc.go +++ b/hscontrol/oidc.go @@ -523,7 +523,7 @@ func (h *Headscale) validateMachineForOIDCCallback( Str("machine", machine.Hostname). Msg("machine already registered, reauthenticating") - err := h.db.RefreshMachine(machine, expiry) + err := h.db.MachineSetExpiry(machine, expiry) if err != nil { util.LogErr(err, "Failed to refresh machine") http.Error( diff --git a/hscontrol/poll.go b/hscontrol/poll.go index bf7a0f49..77161fce 100644 --- a/hscontrol/poll.go +++ b/hscontrol/poll.go @@ -107,6 +107,7 @@ func (h *Headscale) handlePoll( machine.LastSeen = &now } + // TODO(kradalby): Save specific stuff, not whole object. if err := h.db.MachineSave(machine); err != nil { logErr(err, "Failed to persist/update machine in the database") http.Error(writer, "", http.StatusInternalServerError)