diff --git a/api_common.go b/api_common.go index c4f9c798..75cc1a5f 100644 --- a/api_common.go +++ b/api_common.go @@ -13,7 +13,7 @@ func (h *Headscale) generateMapResponse( Str("func", "generateMapResponse"). Str("machine", mapRequest.Hostinfo.Hostname). Msg("Creating Map response") - node, err := machine.toNode(h.cfg.BaseDomain, h.cfg.DNSConfig) + node, err := h.toNode(*machine, h.cfg.BaseDomain, h.cfg.DNSConfig) if err != nil { log.Error(). Caller(). @@ -37,7 +37,7 @@ func (h *Headscale) generateMapResponse( profiles := h.getMapResponseUserProfiles(*machine, peers) - nodePeers, err := peers.toNodes(h.cfg.BaseDomain, h.cfg.DNSConfig) + nodePeers, err := h.toNodes(peers, h.cfg.BaseDomain, h.cfg.DNSConfig) if err != nil { log.Error(). Caller(). diff --git a/grpcv1.go b/grpcv1.go index 25ee7777..998b9902 100644 --- a/grpcv1.go +++ b/grpcv1.go @@ -374,7 +374,7 @@ func (api headscaleV1APIServer) GetMachineRoute( } return &v1.GetMachineRouteResponse{ - Routes: machine.RoutesToProto(), + Routes: api.h.RoutesToProto(machine), }, nil } @@ -393,7 +393,7 @@ func (api headscaleV1APIServer) EnableMachineRoutes( } return &v1.EnableMachineRoutesResponse{ - Routes: machine.RoutesToProto(), + Routes: api.h.RoutesToProto(machine), }, nil } diff --git a/machine.go b/machine.go index e9dbb606..cc15248f 100644 --- a/machine.go +++ b/machine.go @@ -13,6 +13,7 @@ import ( v1 "github.com/juanfont/headscale/gen/go/headscale/v1" "github.com/rs/zerolog/log" "google.golang.org/protobuf/types/known/timestamppb" + "gorm.io/gorm" "tailscale.com/tailcfg" "tailscale.com/types/key" ) @@ -37,11 +38,6 @@ const ( maxHostnameLength = 255 ) -var ( - ExitRouteV4 = netip.MustParsePrefix("0.0.0.0/0") - ExitRouteV6 = netip.MustParsePrefix("::/0") -) - // Machine is a Headscale client. type Machine struct { ID uint64 `gorm:"primary_key"` @@ -76,9 +72,8 @@ type Machine struct { LastSuccessfulUpdate *time.Time Expiry *time.Time - HostInfo HostInfo - Endpoints StringList - EnabledRoutes IPPrefixes + HostInfo HostInfo + Endpoints StringList CreatedAt time.Time UpdatedAt time.Time @@ -595,14 +590,15 @@ func (machines MachinesP) String() string { return fmt.Sprintf("[ %s ](%d)", strings.Join(temp, ", "), len(temp)) } -func (machines Machines) toNodes( +func (h *Headscale) toNodes( + machines Machines, baseDomain string, dnsConfig *tailcfg.DNSConfig, ) ([]*tailcfg.Node, error) { nodes := make([]*tailcfg.Node, len(machines)) for index, machine := range machines { - node, err := machine.toNode(baseDomain, dnsConfig) + node, err := h.toNode(machine, baseDomain, dnsConfig) if err != nil { return nil, err } @@ -615,7 +611,8 @@ func (machines Machines) toNodes( // toNode converts a Machine into a Tailscale Node. includeRoutes is false for shared nodes // as per the expected behaviour in the official SaaS. -func (machine Machine) toNode( +func (h *Headscale) toNode( + machine Machine, baseDomain string, dnsConfig *tailcfg.DNSConfig, ) (*tailcfg.Node, error) { @@ -663,24 +660,19 @@ func (machine Machine) toNode( []netip.Prefix{}, addrs...) // we append the node own IP, as it is required by the clients - allowedIPs = append(allowedIPs, machine.EnabledRoutes...) - - // TODO(kradalby): This is kind of a hack where we say that - // all the announced routes (except exit), is presented as primary - // routes. This might be problematic if two nodes expose the same route. - // This was added to address an issue where subnet routers stopped working - // when we only populated AllowedIPs. - primaryRoutes := []netip.Prefix{} - if len(machine.EnabledRoutes) > 0 { - for _, route := range machine.EnabledRoutes { - if route == ExitRouteV4 || route == ExitRouteV6 { - continue - } - - primaryRoutes = append(primaryRoutes, route) - } + enabledRoutes, err := h.GetEnabledRoutes(&machine) + if err != nil { + return nil, err } + allowedIPs = append(allowedIPs, enabledRoutes...) + + primaryRoutes, err := h.getMachinePrimaryRoutes(&machine) + if err != nil { + return nil, err + } + primaryPrefixes := Routes(primaryRoutes).toPrefixes() + var derp string if machine.HostInfo.NetInfo != nil { derp = fmt.Sprintf("127.3.3.40:%d", machine.HostInfo.NetInfo.PreferredDERP) @@ -733,7 +725,7 @@ func (machine Machine) toNode( DiscoKey: discoKey, Addresses: addrs, AllowedIPs: allowedIPs, - PrimaryRoutes: primaryRoutes, + PrimaryRoutes: primaryPrefixes, Endpoints: machine.Endpoints, DERP: derp, @@ -927,21 +919,66 @@ func (h *Headscale) RegisterMachine(machine Machine, return &machine, nil } -func (machine *Machine) GetAdvertisedRoutes() []netip.Prefix { - return machine.HostInfo.RoutableIPs +// GetAdvertisedRoutes returns the routes that are be advertised by the given machine. +func (h *Headscale) GetAdvertisedRoutes(machine *Machine) ([]netip.Prefix, error) { + routes := []Route{} + + err := h.db. + Preload("Machine"). + Where("machine_id = ? AND advertised = ?", machine.ID, true).Find(&routes).Error + if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) { + log.Error(). + Caller(). + Err(err). + Str("machine", machine.Hostname). + Msg("Could not get advertised routes for machine") + return nil, err + } + + prefixes := []netip.Prefix{} + for _, route := range routes { + prefixes = append(prefixes, netip.Prefix(route.Prefix)) + } + + return prefixes, nil } -func (machine *Machine) GetEnabledRoutes() []netip.Prefix { - return machine.EnabledRoutes +// GetEnabledRoutes returns the routes that are enabled for the machine. +func (h *Headscale) GetEnabledRoutes(machine *Machine) ([]netip.Prefix, error) { + routes := []Route{} + + err := h.db. + Preload("Machine"). + Where("machine_id = ? AND advertised = ? AND enabled = ?", machine.ID, true, true). + Find(&routes).Error + if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) { + log.Error(). + Caller(). + Err(err). + Str("machine", machine.Hostname). + Msg("Could not get enabled routes for machine") + return nil, err + } + + prefixes := []netip.Prefix{} + for _, route := range routes { + prefixes = append(prefixes, netip.Prefix(route.Prefix)) + } + + return prefixes, nil } -func (machine *Machine) IsRoutesEnabled(routeStr string) bool { +func (h *Headscale) IsRoutesEnabled(machine *Machine, routeStr string) bool { route, err := netip.ParsePrefix(routeStr) if err != nil { return false } - enabledRoutes := machine.GetEnabledRoutes() + enabledRoutes, err := h.GetEnabledRoutes(machine) + if err != nil { + log.Error().Err(err).Msg("Could not get enabled routes") + return false + } for _, enabledRoute := range enabledRoutes { if route == enabledRoute { @@ -952,8 +989,7 @@ func (machine *Machine) IsRoutesEnabled(routeStr string) bool { return false } -// EnableNodeRoute enables new routes based on a list of new routes. It will _replace_ the -// previous list of routes. +// EnableRoutes enables new routes based on a list of new routes. func (h *Headscale) EnableRoutes(machine *Machine, routeStrs ...string) error { newRoutes := make([]netip.Prefix, len(routeStrs)) for index, routeStr := range routeStrs { @@ -965,8 +1001,13 @@ func (h *Headscale) EnableRoutes(machine *Machine, routeStrs ...string) error { newRoutes[index] = route } + advertisedRoutes, err := h.GetAdvertisedRoutes(machine) + if err != nil { + return err + } + for _, newRoute := range newRoutes { - if !contains(machine.GetAdvertisedRoutes(), newRoute) { + if !contains(advertisedRoutes, newRoute) { return fmt.Errorf( "route (%s) is not available on node %s: %w", machine.Hostname, @@ -975,52 +1016,77 @@ func (h *Headscale) EnableRoutes(machine *Machine, routeStrs ...string) error { } } - machine.EnabledRoutes = newRoutes + // Separate loop so we don't leave things in a half-updated state + for _, prefix := range newRoutes { + route := Route{} + err := h.db.Preload("Machine"). + Where("machine_id = ? AND prefix = ?", machine.ID, IPPrefix(prefix)). + First(&route).Error + if err == nil { + route.Enabled = true - if err := h.db.Save(machine).Error; err != nil { - return fmt.Errorf("failed enable routes for machine in the database: %w", err) + // Mark already as primary if there is only this node offering this subnet + // (and is not an exit route) + if prefix != ExitRouteV4 && prefix != ExitRouteV6 { + route.IsPrimary = h.isUniquePrefix(route) + } + + err = h.db.Save(&route).Error + if err != nil { + return fmt.Errorf("failed to enable route: %w", err) + } + } else { + return fmt.Errorf("failed to find route: %w", err) + } } return nil } -// Enabled any routes advertised by a machine that match the ACL autoApprovers policy. -func (h *Headscale) EnableAutoApprovedRoutes(machine *Machine) { +// EnableAutoApprovedRoutes enables any routes advertised by a machine that match the ACL autoApprovers policy. +func (h *Headscale) EnableAutoApprovedRoutes(machine *Machine) error { if len(machine.IPAddresses) == 0 { - return // This machine has no IPAddresses, so can't possibly match any autoApprovers ACLs + return nil // This machine has no IPAddresses, so can't possibly match any autoApprovers ACLs } - approvedRoutes := make([]netip.Prefix, 0, len(machine.HostInfo.RoutableIPs)) - thisMachine := []Machine{*machine} + routes := []Route{} + err := h.db. + Preload("Machine"). + Where("machine_id = ? AND advertised = true AND enabled = false", machine.ID).Find(&routes).Error + if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) { + log.Error(). + Caller(). + Err(err). + Str("machine", machine.Hostname). + Msg("Could not get advertised routes for machine") - for _, advertisedRoute := range machine.HostInfo.RoutableIPs { - if contains(machine.EnabledRoutes, advertisedRoute) { - continue // Skip routes that are already enabled for the node - } + return err + } - routeApprovers, err := h.aclPolicy.AutoApprovers.GetRouteApprovers( - advertisedRoute, - ) + approvedRoutes := []Route{} + + for _, advertisedRoute := range routes { + routeApprovers, err := h.aclPolicy.AutoApprovers.GetRouteApprovers(netip.Prefix(advertisedRoute.Prefix)) if err != nil { log.Err(err). Str("advertisedRoute", advertisedRoute.String()). Uint64("machineId", machine.ID). Msg("Failed to resolve autoApprovers for advertised route") - return + return err } for _, approvedAlias := range routeApprovers { if approvedAlias == machine.Namespace.Name { approvedRoutes = append(approvedRoutes, advertisedRoute) } else { - approvedIps, err := expandAlias(thisMachine, *h.aclPolicy, approvedAlias, h.cfg.OIDC.StripEmaildomain) + approvedIps, err := expandAlias([]Machine{*machine}, *h.aclPolicy, approvedAlias, h.cfg.OIDC.StripEmaildomain) if err != nil { log.Err(err). Str("alias", approvedAlias). Msg("Failed to expand alias when processing autoApprovers policy") - return + return err } // approvedIPs should contain all of machine's IPs if it matches the rule, so check for first @@ -1032,20 +1098,33 @@ func (h *Headscale) EnableAutoApprovedRoutes(machine *Machine) { } for _, approvedRoute := range approvedRoutes { - if !contains(machine.EnabledRoutes, approvedRoute) { - log.Info(). - Str("route", approvedRoute.String()). - Uint64("client", machine.ID). - Msg("Enabling autoApproved route for client") - machine.EnabledRoutes = append(machine.EnabledRoutes, approvedRoute) + approvedRoute.Enabled = true + err = h.db.Save(&approvedRoute).Error + if err != nil { + log.Err(err). + Str("approvedRoute", approvedRoute.String()). + Uint64("machineId", machine.ID). + Msg("Failed to enable approved route") + + return err } } + + return nil } -func (machine *Machine) RoutesToProto() *v1.Routes { - availableRoutes := machine.GetAdvertisedRoutes() +func (h *Headscale) RoutesToProto(machine *Machine) *v1.Routes { + availableRoutes, err := h.GetAdvertisedRoutes(machine) + if err != nil { + log.Error().Err(err).Msg("Could not get advertised routes") + return nil + } - enabledRoutes := machine.GetEnabledRoutes() + enabledRoutes, err := h.GetEnabledRoutes(machine) + if err != nil { + log.Error().Err(err).Msg("Could not get enabled routes") + return nil + } return &v1.Routes{ AdvertisedRoutes: ipPrefixToString(availableRoutes), diff --git a/machine_test.go b/machine_test.go index b13ecd0c..c5073233 100644 --- a/machine_test.go +++ b/machine_test.go @@ -1153,9 +1153,14 @@ func (s *Suite) TestAutoApproveRoutes(c *check.C) { app.db.Save(&machine) + err = app.processMachineRoutes(&machine) + c.Assert(err, check.IsNil) + machine0ByID, err := app.GetMachineByID(0) c.Assert(err, check.IsNil) app.EnableAutoApprovedRoutes(machine0ByID) - c.Assert(machine0ByID.GetEnabledRoutes(), check.HasLen, 3) + enabledRoutes, err := app.GetEnabledRoutes(machine0ByID) + c.Assert(err, check.IsNil) + c.Assert(enabledRoutes, check.HasLen, 3) } diff --git a/routes.go b/routes.go index 3ba710be..36d67a90 100644 --- a/routes.go +++ b/routes.go @@ -11,6 +11,11 @@ const ( ErrRouteIsNotAvailable = Error("route is not available") ) +var ( + ExitRouteV4 = netip.MustParsePrefix("0.0.0.0/0") + ExitRouteV6 = netip.MustParsePrefix("::/0") +) + type Route struct { gorm.Model @@ -37,6 +42,18 @@ func (rs Routes) toPrefixes() []netip.Prefix { return prefixes } +// isUniquePrefix returns if there is another machine providing the same route already +func (h *Headscale) isUniquePrefix(route Route) bool { + var count int64 + h.db. + Model(&Route{}). + Where("prefix = ? AND machine_id != ? AND advertised = ? AND enabled = ?", + route.Prefix, + route.MachineID, + true, true).Count(&count) + return count == 0 +} + // getMachinePrimaryRoutes returns the routes that are enabled and marked as primary (for subnet failover) // Exit nodes are not considered for this, as they are never marked as Primary func (h *Headscale) getMachinePrimaryRoutes(m *Machine) ([]Route, error) { diff --git a/routes_test.go b/routes_test.go index 6566f3e5..2560898e 100644 --- a/routes_test.go +++ b/routes_test.go @@ -37,17 +37,17 @@ func (s *Suite) TestGetRoutes(c *check.C) { } app.db.Save(&machine) - advertisedRoutes, err := app.GetAdvertisedNodeRoutes( - "test", - "test_get_route_machine", - ) + err = app.processMachineRoutes(&machine) c.Assert(err, check.IsNil) - c.Assert(len(*advertisedRoutes), check.Equals, 1) - err = app.EnableNodeRoute("test", "test_get_route_machine", "192.168.0.0/24") + advertisedRoutes, err := app.GetAdvertisedRoutes(&machine) + c.Assert(err, check.IsNil) + c.Assert(len(advertisedRoutes), check.Equals, 1) + + err = app.EnableRoutes(&machine, "192.168.0.0/24") c.Assert(err, check.NotNil) - err = app.EnableNodeRoute("test", "test_get_route_machine", "10.0.0.0/24") + err = app.EnableRoutes(&machine, "10.0.0.0/24") c.Assert(err, check.IsNil) } @@ -88,48 +88,124 @@ func (s *Suite) TestGetEnableRoutes(c *check.C) { } app.db.Save(&machine) - availableRoutes, err := app.GetAdvertisedNodeRoutes( - "test", - "test_enable_route_machine", - ) + err = app.processMachineRoutes(&machine) c.Assert(err, check.IsNil) - c.Assert(len(*availableRoutes), check.Equals, 2) - noEnabledRoutes, err := app.GetEnabledNodeRoutes( - "test", - "test_enable_route_machine", - ) + availableRoutes, err := app.GetAdvertisedRoutes(&machine) + c.Assert(err, check.IsNil) + c.Assert(err, check.IsNil) + c.Assert(len(availableRoutes), check.Equals, 2) + + noEnabledRoutes, err := app.GetEnabledRoutes(&machine) c.Assert(err, check.IsNil) c.Assert(len(noEnabledRoutes), check.Equals, 0) - err = app.EnableNodeRoute("test", "test_enable_route_machine", "192.168.0.0/24") + err = app.EnableRoutes(&machine, "192.168.0.0/24") c.Assert(err, check.NotNil) - err = app.EnableNodeRoute("test", "test_enable_route_machine", "10.0.0.0/24") + err = app.EnableRoutes(&machine, "10.0.0.0/24") c.Assert(err, check.IsNil) - enabledRoutes, err := app.GetEnabledNodeRoutes("test", "test_enable_route_machine") + enabledRoutes, err := app.GetEnabledRoutes(&machine) c.Assert(err, check.IsNil) c.Assert(len(enabledRoutes), check.Equals, 1) // Adding it twice will just let it pass through - err = app.EnableNodeRoute("test", "test_enable_route_machine", "10.0.0.0/24") + err = app.EnableRoutes(&machine, "10.0.0.0/24") c.Assert(err, check.IsNil) - enableRoutesAfterDoubleApply, err := app.GetEnabledNodeRoutes( - "test", - "test_enable_route_machine", - ) + enableRoutesAfterDoubleApply, err := app.GetEnabledRoutes(&machine) c.Assert(err, check.IsNil) c.Assert(len(enableRoutesAfterDoubleApply), check.Equals, 1) - err = app.EnableNodeRoute("test", "test_enable_route_machine", "150.0.10.0/25") + err = app.EnableRoutes(&machine, "150.0.10.0/25") c.Assert(err, check.IsNil) - enabledRoutesWithAdditionalRoute, err := app.GetEnabledNodeRoutes( - "test", - "test_enable_route_machine", - ) + enabledRoutesWithAdditionalRoute, err := app.GetEnabledRoutes(&machine) c.Assert(err, check.IsNil) c.Assert(len(enabledRoutesWithAdditionalRoute), check.Equals, 2) } + +func (s *Suite) TestIsUniquePrefix(c *check.C) { + namespace, err := app.CreateNamespace("test") + c.Assert(err, check.IsNil) + + pak, err := app.CreatePreAuthKey(namespace.Name, false, false, nil, nil) + c.Assert(err, check.IsNil) + + _, err = app.GetMachine("test", "test_enable_route_machine") + c.Assert(err, check.NotNil) + + route, err := netip.ParsePrefix( + "10.0.0.0/24", + ) + c.Assert(err, check.IsNil) + + route2, err := netip.ParsePrefix( + "150.0.10.0/25", + ) + c.Assert(err, check.IsNil) + + hostInfo1 := tailcfg.Hostinfo{ + RoutableIPs: []netip.Prefix{route, route2}, + } + machine1 := Machine{ + ID: 0, + MachineKey: "foo", + NodeKey: "bar", + DiscoKey: "faa", + Hostname: "test_enable_route_machine", + NamespaceID: namespace.ID, + RegisterMethod: RegisterMethodAuthKey, + AuthKeyID: uint(pak.ID), + HostInfo: HostInfo(hostInfo1), + } + app.db.Save(&machine1) + + err = app.processMachineRoutes(&machine1) + c.Assert(err, check.IsNil) + + err = app.EnableRoutes(&machine1, route.String()) + c.Assert(err, check.IsNil) + + err = app.EnableRoutes(&machine1, route2.String()) + c.Assert(err, check.IsNil) + + hostInfo2 := tailcfg.Hostinfo{ + RoutableIPs: []netip.Prefix{route2}, + } + machine2 := Machine{ + ID: 0, + MachineKey: "foo", + NodeKey: "bar", + DiscoKey: "faa", + Hostname: "test_enable_route_machine", + NamespaceID: namespace.ID, + RegisterMethod: RegisterMethodAuthKey, + AuthKeyID: uint(pak.ID), + HostInfo: HostInfo(hostInfo2), + } + app.db.Save(&machine2) + + err = app.processMachineRoutes(&machine2) + c.Assert(err, check.IsNil) + + err = app.EnableRoutes(&machine2, route2.String()) + c.Assert(err, check.IsNil) + + enabledRoutes1, err := app.GetEnabledRoutes(&machine1) + c.Assert(err, check.IsNil) + c.Assert(len(enabledRoutes1), check.Equals, 2) + + enabledRoutes2, err := app.GetEnabledRoutes(&machine2) + c.Assert(err, check.IsNil) + c.Assert(len(enabledRoutes2), check.Equals, 1) + + routes, err := app.getMachinePrimaryRoutes(&machine1) + c.Assert(err, check.IsNil) + c.Assert(len(routes), check.Equals, 2) + + routes, err = app.getMachinePrimaryRoutes(&machine2) + c.Assert(err, check.IsNil) + c.Assert(len(routes), check.Equals, 0) +}