diff --git a/hscontrol/db/machine.go b/hscontrol/db/machine.go index 849ce0c5..f19b204d 100644 --- a/hscontrol/db/machine.go +++ b/hscontrol/db/machine.go @@ -48,7 +48,6 @@ func (hsdb *HSDatabase) ListPeers(machine *types.Machine) (types.Machines, error Preload("Routes"). Where("node_key <> ?", machine.NodeKey).Find(&machines).Error; err != nil { - return types.Machines{}, err } @@ -70,7 +69,6 @@ func (hsdb *HSDatabase) ListMachines() ([]types.Machine, error) { Preload("User"). Preload("Routes"). Find(&machines).Error; err != nil { - return nil, err } @@ -85,7 +83,6 @@ func (hsdb *HSDatabase) ListMachinesByGivenName(givenName string) (types.Machine Preload("User"). Preload("Routes"). Where("given_name = ?", givenName).Find(&machines).Error; err != nil { - return nil, err } @@ -129,34 +126,34 @@ func (hsdb *HSDatabase) GetMachineByGivenName( // GetMachineByID finds a Machine by ID and returns the Machine struct. func (hsdb *HSDatabase) GetMachineByID(id uint64) (*types.Machine, error) { - m := types.Machine{} + mach := types.Machine{} if result := hsdb.db. Preload("AuthKey"). Preload("AuthKey.User"). Preload("User"). Preload("Routes"). - Find(&types.Machine{ID: id}).First(&m); result.Error != nil { + Find(&types.Machine{ID: id}).First(&mach); result.Error != nil { return nil, result.Error } - return &m, nil + return &mach, nil } // GetMachineByMachineKey finds a Machine by its MachineKey and returns the Machine struct. func (hsdb *HSDatabase) GetMachineByMachineKey( machineKey key.MachinePublic, ) (*types.Machine, error) { - m := types.Machine{} + mach := types.Machine{} if result := hsdb.db. Preload("AuthKey"). Preload("AuthKey.User"). Preload("User"). Preload("Routes"). - First(&m, "machine_key = ?", util.MachinePublicKeyStripPrefix(machineKey)); result.Error != nil { + First(&mach, "machine_key = ?", util.MachinePublicKeyStripPrefix(machineKey)); result.Error != nil { return nil, result.Error } - return &m, nil + return &mach, nil } // GetMachineByNodeKey finds a Machine by its current NodeKey. diff --git a/hscontrol/db/routes_test.go b/hscontrol/db/routes_test.go index d281452d..4e91a2cb 100644 --- a/hscontrol/db/routes_test.go +++ b/hscontrol/db/routes_test.go @@ -4,12 +4,10 @@ import ( "net/netip" "time" - "github.com/juanfont/headscale/hscontrol/policy" "github.com/juanfont/headscale/hscontrol/types" "github.com/juanfont/headscale/hscontrol/util" "gopkg.in/check.v1" "tailscale.com/tailcfg" - "tailscale.com/types/key" ) func (s *Suite) TestGetRoutes(c *check.C) { @@ -365,144 +363,6 @@ func (s *Suite) TestSubnetFailover(c *check.C) { c.Assert(channelUpdates, check.Equals, int32(6)) } -// TestAllowedIPRoutes tests that the AllowedIPs are correctly set for a node, -// including both the primary routes the node is responsible for, and the -// exit node routes if enabled. -func (s *Suite) TestAllowedIPRoutes(c *check.C) { - user, err := db.CreateUser("test") - c.Assert(err, check.IsNil) - - pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil) - c.Assert(err, check.IsNil) - - _, err = db.GetMachine("test", "test_enable_route_machine") - c.Assert(err, check.NotNil) - - prefix, err := netip.ParsePrefix( - "10.0.0.0/24", - ) - c.Assert(err, check.IsNil) - - prefix2, err := netip.ParsePrefix( - "150.0.10.0/25", - ) - c.Assert(err, check.IsNil) - - prefixExitNodeV4, err := netip.ParsePrefix( - "0.0.0.0/0", - ) - c.Assert(err, check.IsNil) - - prefixExitNodeV6, err := netip.ParsePrefix( - "::/0", - ) - c.Assert(err, check.IsNil) - - hostInfo1 := tailcfg.Hostinfo{ - RoutableIPs: []netip.Prefix{prefix, prefix2, prefixExitNodeV4, prefixExitNodeV6}, - } - - nodeKey := key.NewNode() - discoKey := key.NewDisco() - machineKey := key.NewMachine() - - now := time.Now() - machine1 := types.Machine{ - ID: 1, - MachineKey: util.MachinePublicKeyStripPrefix(machineKey.Public()), - NodeKey: util.NodePublicKeyStripPrefix(nodeKey.Public()), - DiscoKey: util.DiscoPublicKeyStripPrefix(discoKey.Public()), - Hostname: "test_enable_route_machine", - UserID: user.ID, - RegisterMethod: util.RegisterMethodAuthKey, - AuthKeyID: uint(pak.ID), - HostInfo: types.HostInfo(hostInfo1), - LastSeen: &now, - } - db.db.Save(&machine1) - - err = db.ProcessMachineRoutes(&machine1) - c.Assert(err, check.IsNil) - - err = db.enableRoutes(&machine1, prefix.String()) - c.Assert(err, check.IsNil) - - // We do not enable this one on purpose to test that it is not enabled - // err = db.enableRoutes(&machine1, prefix2.String()) - // c.Assert(err, check.IsNil) - - routes, err := db.GetMachineRoutes(&machine1) - c.Assert(err, check.IsNil) - for _, route := range routes { - if route.IsExitRoute() { - err = db.EnableRoute(uint64(route.ID)) - c.Assert(err, check.IsNil) - - // We only enable one exit route, so we can test that both are enabled - break - } - } - - err = db.HandlePrimarySubnetFailover() - c.Assert(err, check.IsNil) - - enabledRoutes1, err := db.GetEnabledRoutes(&machine1) - c.Assert(err, check.IsNil) - c.Assert(len(enabledRoutes1), check.Equals, 3) - - peer, err := db.TailNode(machine1, &policy.ACLPolicy{}, nil) - c.Assert(err, check.IsNil) - - c.Assert(len(peer.AllowedIPs), check.Equals, 3) - - foundExitNodeV4 := false - foundExitNodeV6 := false - for _, allowedIP := range peer.AllowedIPs { - if allowedIP == prefixExitNodeV4 { - foundExitNodeV4 = true - } - if allowedIP == prefixExitNodeV6 { - foundExitNodeV6 = true - } - } - - c.Assert(foundExitNodeV4, check.Equals, true) - c.Assert(foundExitNodeV6, check.Equals, true) - - // Now we disable only one of the exit routes - // and we see if both are disabled - var exitRouteV4 types.Route - for _, route := range routes { - if route.IsExitRoute() && netip.Prefix(route.Prefix) == prefixExitNodeV4 { - exitRouteV4 = route - - break - } - } - - err = db.DisableRoute(uint64(exitRouteV4.ID)) - c.Assert(err, check.IsNil) - - enabledRoutes1, err = db.GetEnabledRoutes(&machine1) - c.Assert(err, check.IsNil) - c.Assert(len(enabledRoutes1), check.Equals, 1) - - // and now we delete only one of the exit routes - // and we check if both are deleted - routes, err = db.GetMachineRoutes(&machine1) - c.Assert(err, check.IsNil) - c.Assert(len(routes), check.Equals, 4) - - err = db.DeleteRoute(uint64(exitRouteV4.ID)) - c.Assert(err, check.IsNil) - - routes, err = db.GetMachineRoutes(&machine1) - c.Assert(err, check.IsNil) - c.Assert(len(routes), check.Equals, 2) - - c.Assert(channelUpdates, check.Equals, int32(2)) -} - func (s *Suite) TestDeleteRoutes(c *check.C) { user, err := db.CreateUser("test") c.Assert(err, check.IsNil) diff --git a/hscontrol/mapper/tail_test.go b/hscontrol/mapper/tail_test.go index a8a3f40b..3ffff7d3 100644 --- a/hscontrol/mapper/tail_test.go +++ b/hscontrol/mapper/tail_test.go @@ -99,6 +99,12 @@ func TestTailNode(t *testing.T) { Enabled: true, IsPrimary: true, }, + { + Prefix: types.IPPrefix(netip.MustParsePrefix("172.0.0.0/10")), + Advertised: true, + Enabled: false, + IsPrimary: true, + }, }, CreatedAt: created, },