diff --git a/db.go b/db.go index 72386ca8..41949525 100644 --- a/db.go +++ b/db.go @@ -106,6 +106,7 @@ func (h *Headscale) initDB() error { Err(err). Str("enabled_route", prefix.String()). Msg("Error parsing enabled_route") + continue } @@ -114,6 +115,7 @@ func (h *Headscale) initDB() error { log.Info(). Str("enabled_route", prefix.String()). Msg("Route already migrated to new table, skipping") + continue } @@ -335,6 +337,7 @@ func (i *IPPrefix) Scan(destination interface{}) error { return err } *i = IPPrefix(prefix) + return nil default: return fmt.Errorf("%w: unexpected data type %T", ErrCannotParsePrefix, destination) @@ -344,6 +347,7 @@ func (i *IPPrefix) Scan(destination interface{}) error { // Value return json value, implement driver.Valuer interface. func (i IPPrefix) Value() (driver.Value, error) { prefixStr := netip.Prefix(i).String() + return prefixStr, nil } diff --git a/machine.go b/machine.go index 0ac56d88..9be7204e 100644 --- a/machine.go +++ b/machine.go @@ -941,6 +941,7 @@ func (h *Headscale) GetAdvertisedRoutes(machine *Machine) ([]netip.Prefix, error Err(err). Str("machine", machine.Hostname). Msg("Could not get advertised routes for machine") + return nil, err } @@ -966,6 +967,7 @@ func (h *Headscale) GetEnabledRoutes(machine *Machine) ([]netip.Prefix, error) { Err(err). Str("machine", machine.Hostname). Msg("Could not get enabled routes for machine") + return nil, err } @@ -986,6 +988,7 @@ func (h *Headscale) IsRoutesEnabled(machine *Machine, routeStr string) bool { enabledRoutes, err := h.GetEnabledRoutes(machine) if err != nil { log.Error().Err(err).Msg("Could not get enabled routes") + return false } @@ -1106,9 +1109,9 @@ func (h *Headscale) EnableAutoApprovedRoutes(machine *Machine) error { } } - for _, approvedRoute := range approvedRoutes { - approvedRoute.Enabled = true - err = h.db.Save(&approvedRoute).Error + for i, approvedRoute := range approvedRoutes { + approvedRoutes[i].Enabled = true + err = h.db.Save(&approvedRoutes[i]).Error if err != nil { log.Err(err). Str("approvedRoute", approvedRoute.String()). @@ -1122,25 +1125,6 @@ func (h *Headscale) EnableAutoApprovedRoutes(machine *Machine) error { return nil } -func (h *Headscale) RoutesToProto(machine *Machine) *v1.Routes { - availableRoutes, err := h.GetAdvertisedRoutes(machine) - if err != nil { - log.Error().Err(err).Msg("Could not get advertised routes") - return nil - } - - enabledRoutes, err := h.GetEnabledRoutes(machine) - if err != nil { - log.Error().Err(err).Msg("Could not get enabled routes") - return nil - } - - return &v1.Routes{ - AdvertisedRoutes: ipPrefixToString(availableRoutes), - EnabledRoutes: ipPrefixToString(enabledRoutes), - } -} - func (h *Headscale) generateGivenName(suppliedName string, randomSuffix bool) (string, error) { normalizedHostname, err := NormalizeToFQDNRules( suppliedName, diff --git a/machine_test.go b/machine_test.go index c5073233..6cbf2821 100644 --- a/machine_test.go +++ b/machine_test.go @@ -1159,7 +1159,9 @@ func (s *Suite) TestAutoApproveRoutes(c *check.C) { machine0ByID, err := app.GetMachineByID(0) c.Assert(err, check.IsNil) - app.EnableAutoApprovedRoutes(machine0ByID) + err = app.EnableAutoApprovedRoutes(machine0ByID) + c.Assert(err, check.IsNil) + enabledRoutes, err := app.GetEnabledRoutes(machine0ByID) c.Assert(err, check.IsNil) c.Assert(enabledRoutes, check.HasLen, 3) diff --git a/protocol_common_poll.go b/protocol_common_poll.go index 7df01c86..8732c707 100644 --- a/protocol_common_poll.go +++ b/protocol_common_poll.go @@ -53,7 +53,15 @@ func (h *Headscale) handlePollCommon( } // update routes with peer information - h.EnableAutoApprovedRoutes(machine) + err = h.EnableAutoApprovedRoutes(machine) + if err != nil { + log.Error(). + Caller(). + Bool("noise", isNoise). + Str("machine", machine.Hostname). + Err(err). + Msg("Error running auto approved routes") + } } // From Tailscale client: diff --git a/routes.go b/routes.go index 221db60f..f1b1913f 100644 --- a/routes.go +++ b/routes.go @@ -1,6 +1,7 @@ package headscale import ( + "errors" "fmt" "net/netip" @@ -44,10 +45,11 @@ func (rs Routes) toPrefixes() []netip.Prefix { for i, r := range rs { prefixes[i] = netip.Prefix(r.Prefix) } + return prefixes } -// isUniquePrefix returns if there is another machine providing the same route already +// isUniquePrefix returns if there is another machine providing the same route already. func (h *Headscale) isUniquePrefix(route Route) bool { var count int64 h.db. @@ -56,6 +58,7 @@ func (h *Headscale) isUniquePrefix(route Route) bool { route.Prefix, route.MachineID, true, true).Count(&count) + return count == 0 } @@ -65,11 +68,11 @@ func (h *Headscale) getPrimaryRoute(prefix netip.Prefix) (*Route, error) { Preload("Machine"). Where("prefix = ? AND advertised = ? AND enabled = ? AND is_primary = ?", IPPrefix(prefix), true, true, true). First(&route).Error - if err != nil && err != gorm.ErrRecordNotFound { + if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) { return nil, err } - if err == gorm.ErrRecordNotFound { + if errors.Is(err, gorm.ErrRecordNotFound) { return nil, gorm.ErrRecordNotFound } @@ -77,7 +80,7 @@ func (h *Headscale) getPrimaryRoute(prefix netip.Prefix) (*Route, error) { } // getMachinePrimaryRoutes returns the routes that are enabled and marked as primary (for subnet failover) -// Exit nodes are not considered for this, as they are never marked as Primary +// Exit nodes are not considered for this, as they are never marked as Primary. func (h *Headscale) getMachinePrimaryRoutes(m *Machine) ([]Route, error) { var routes []Route err := h.db. @@ -103,24 +106,22 @@ func (h *Headscale) processMachineRoutes(machine *Machine) error { advertisedRoutes[prefix] = false } - for _, route := range currentRoutes { + for pos, route := range currentRoutes { if _, ok := advertisedRoutes[netip.Prefix(route.Prefix)]; ok { if !route.Advertised { - route.Advertised = true - err := h.db.Save(&route).Error + currentRoutes[pos].Advertised = true + err := h.db.Save(¤tRoutes[pos]).Error if err != nil { return err } } advertisedRoutes[netip.Prefix(route.Prefix)] = true - } else { - if route.Advertised { - route.Advertised = false - route.Enabled = false - err := h.db.Save(&route).Error - if err != nil { - return err - } + } else if route.Advertised { + currentRoutes[pos].Advertised = false + currentRoutes[pos].Enabled = false + err := h.db.Save(¤tRoutes[pos]).Error + if err != nil { + return err } } } @@ -150,25 +151,26 @@ func (h *Headscale) handlePrimarySubnetFailover() error { Preload("Machine"). Where("advertised = ? AND enabled = ?", true, true). Find(&routes).Error - if err != nil && err != gorm.ErrRecordNotFound { + if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) { log.Error().Err(err).Msg("error getting routes") } - for _, route := range routes { + for pos, route := range routes { if route.isExitRoute() { continue } if !route.IsPrimary { _, err := h.getPrimaryRoute(netip.Prefix(route.Prefix)) - if h.isUniquePrefix(route) || err == gorm.ErrRecordNotFound { - route.IsPrimary = true - err := h.db.Save(&route).Error + if h.isUniquePrefix(route) || errors.Is(err, gorm.ErrRecordNotFound) { + routes[pos].IsPrimary = true + err := h.db.Save(&routes[pos]).Error if err != nil { log.Error().Err(err).Msg("error marking route as primary") return err } + continue } } @@ -193,16 +195,17 @@ func (h *Headscale) handlePrimarySubnetFailover() error { route.MachineID, true, true). Find(&newPrimaryRoutes).Error - if err != nil && err != gorm.ErrRecordNotFound { + if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) { log.Error().Err(err).Msg("error finding new primary route") return err } var newPrimaryRoute *Route - for _, r := range newPrimaryRoutes { + for pos, r := range newPrimaryRoutes { if r.Machine.isOnline() { - newPrimaryRoute = &r + newPrimaryRoute = &newPrimaryRoutes[pos] + break } } @@ -212,6 +215,7 @@ func (h *Headscale) handlePrimarySubnetFailover() error { Str("machine", route.Machine.Hostname). Str("prefix", netip.Prefix(route.Prefix).String()). Msgf("no alternative primary route found") + continue } @@ -222,8 +226,8 @@ func (h *Headscale) handlePrimarySubnetFailover() error { Msgf("found new primary route") // disable the old primary route - route.IsPrimary = false - err = h.db.Save(&route).Error + routes[pos].IsPrimary = false + err = h.db.Save(&routes[pos]).Error if err != nil { log.Error().Err(err).Msg("error disabling old primary route")