mirror of
https://github.com/juanfont/headscale.git
synced 2025-08-12 00:17:35 +00:00
Split code into modules
This is a massive commit that restructures the code into modules: db/ All functions related to modifying the Database types/ All type definitions and methods that can be exclusivly used on these types without dependencies policy/ All Policy related code, now without dependencies on the Database. policy/matcher/ Dedicated code to match machines in a list of FilterRules Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>
This commit is contained in:

committed by
Kristoffer Dalby

parent
14e29a7bee
commit
feb15365b5
480
hscontrol/db/acls_test.go
Normal file
480
hscontrol/db/acls_test.go
Normal file
@@ -0,0 +1,480 @@
|
||||
package db
|
||||
|
||||
import (
|
||||
"net/netip"
|
||||
|
||||
"github.com/juanfont/headscale/hscontrol/policy"
|
||||
"github.com/juanfont/headscale/hscontrol/types"
|
||||
"github.com/juanfont/headscale/hscontrol/util"
|
||||
"gopkg.in/check.v1"
|
||||
"tailscale.com/envknob"
|
||||
"tailscale.com/tailcfg"
|
||||
)
|
||||
|
||||
// TODO(kradalby):
|
||||
// Convert these tests to being non-database dependent and table driven. They are
|
||||
// very verbose, and dont really need the database.
|
||||
|
||||
func (s *Suite) TestSshRules(c *check.C) {
|
||||
envknob.Setenv("HEADSCALE_EXPERIMENTAL_FEATURE_SSH", "1")
|
||||
|
||||
user, err := db.CreateUser("user1")
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
_, err = db.GetMachine("user1", "testmachine")
|
||||
c.Assert(err, check.NotNil)
|
||||
hostInfo := tailcfg.Hostinfo{
|
||||
OS: "centos",
|
||||
Hostname: "testmachine",
|
||||
RequestTags: []string{"tag:test"},
|
||||
}
|
||||
|
||||
machine := types.Machine{
|
||||
ID: 0,
|
||||
MachineKey: "foo",
|
||||
NodeKey: "bar",
|
||||
DiscoKey: "faa",
|
||||
Hostname: "testmachine",
|
||||
IPAddresses: types.MachineAddresses{netip.MustParseAddr("100.64.0.1")},
|
||||
UserID: user.ID,
|
||||
RegisterMethod: util.RegisterMethodAuthKey,
|
||||
AuthKeyID: uint(pak.ID),
|
||||
HostInfo: types.HostInfo(hostInfo),
|
||||
}
|
||||
err = db.MachineSave(&machine)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
aclPolicy := &policy.ACLPolicy{
|
||||
Groups: policy.Groups{
|
||||
"group:test": []string{"user1"},
|
||||
},
|
||||
Hosts: policy.Hosts{
|
||||
"client": netip.PrefixFrom(netip.MustParseAddr("100.64.99.42"), 32),
|
||||
},
|
||||
ACLs: []policy.ACL{
|
||||
{
|
||||
Action: "accept",
|
||||
Sources: []string{"*"},
|
||||
Destinations: []string{"*:*"},
|
||||
},
|
||||
},
|
||||
SSHs: []policy.SSH{
|
||||
{
|
||||
Action: "accept",
|
||||
Sources: []string{"group:test"},
|
||||
Destinations: []string{"client"},
|
||||
Users: []string{"autogroup:nonroot"},
|
||||
},
|
||||
{
|
||||
Action: "accept",
|
||||
Sources: []string{"*"},
|
||||
Destinations: []string{"client"},
|
||||
Users: []string{"autogroup:nonroot"},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
_, sshPolicy, err := policy.GenerateFilterRules(aclPolicy, types.Machines{}, false)
|
||||
|
||||
c.Assert(err, check.IsNil)
|
||||
c.Assert(sshPolicy, check.NotNil)
|
||||
c.Assert(sshPolicy.Rules, check.HasLen, 2)
|
||||
c.Assert(sshPolicy.Rules[0].SSHUsers, check.HasLen, 1)
|
||||
c.Assert(sshPolicy.Rules[0].Principals, check.HasLen, 1)
|
||||
c.Assert(sshPolicy.Rules[0].Principals[0].UserLogin, check.Matches, "user1")
|
||||
|
||||
c.Assert(sshPolicy.Rules[1].SSHUsers, check.HasLen, 1)
|
||||
c.Assert(sshPolicy.Rules[1].Principals, check.HasLen, 1)
|
||||
c.Assert(sshPolicy.Rules[1].Principals[0].NodeIP, check.Matches, "*")
|
||||
}
|
||||
|
||||
// this test should validate that we can expand a group in a TagOWner section and
|
||||
// match properly the IP's of the related hosts. The owner is valid and the tag is also valid.
|
||||
// the tag is matched in the Sources section.
|
||||
func (s *Suite) TestValidExpandTagOwnersInSources(c *check.C) {
|
||||
user, err := db.CreateUser("user1")
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
_, err = db.GetMachine("user1", "testmachine")
|
||||
c.Assert(err, check.NotNil)
|
||||
hostInfo := tailcfg.Hostinfo{
|
||||
OS: "centos",
|
||||
Hostname: "testmachine",
|
||||
RequestTags: []string{"tag:test"},
|
||||
}
|
||||
|
||||
machine := types.Machine{
|
||||
ID: 0,
|
||||
MachineKey: "foo",
|
||||
NodeKey: "bar",
|
||||
DiscoKey: "faa",
|
||||
Hostname: "testmachine",
|
||||
IPAddresses: types.MachineAddresses{netip.MustParseAddr("100.64.0.1")},
|
||||
UserID: user.ID,
|
||||
RegisterMethod: util.RegisterMethodAuthKey,
|
||||
AuthKeyID: uint(pak.ID),
|
||||
HostInfo: types.HostInfo(hostInfo),
|
||||
}
|
||||
err = db.MachineSave(&machine)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
pol := &policy.ACLPolicy{
|
||||
Groups: policy.Groups{"group:test": []string{"user1", "user2"}},
|
||||
TagOwners: policy.TagOwners{"tag:test": []string{"user3", "group:test"}},
|
||||
ACLs: []policy.ACL{
|
||||
{
|
||||
Action: "accept",
|
||||
Sources: []string{"tag:test"},
|
||||
Destinations: []string{"*:*"},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
machines, err := db.ListMachines()
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
rules, _, err := policy.GenerateFilterRules(pol, machines, false)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
c.Assert(rules, check.HasLen, 1)
|
||||
c.Assert(rules[0].SrcIPs, check.HasLen, 1)
|
||||
c.Assert(rules[0].SrcIPs[0], check.Equals, "100.64.0.1/32")
|
||||
}
|
||||
|
||||
// this test should validate that we can expand a group in a TagOWner section and
|
||||
// match properly the IP's of the related hosts. The owner is valid and the tag is also valid.
|
||||
// the tag is matched in the Destinations section.
|
||||
func (s *Suite) TestValidExpandTagOwnersInDestinations(c *check.C) {
|
||||
user, err := db.CreateUser("user1")
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
_, err = db.GetMachine("user1", "testmachine")
|
||||
c.Assert(err, check.NotNil)
|
||||
hostInfo := tailcfg.Hostinfo{
|
||||
OS: "centos",
|
||||
Hostname: "testmachine",
|
||||
RequestTags: []string{"tag:test"},
|
||||
}
|
||||
|
||||
machine := types.Machine{
|
||||
ID: 1,
|
||||
MachineKey: "12345",
|
||||
NodeKey: "bar",
|
||||
DiscoKey: "faa",
|
||||
Hostname: "testmachine",
|
||||
IPAddresses: types.MachineAddresses{netip.MustParseAddr("100.64.0.1")},
|
||||
UserID: user.ID,
|
||||
RegisterMethod: util.RegisterMethodAuthKey,
|
||||
AuthKeyID: uint(pak.ID),
|
||||
HostInfo: types.HostInfo(hostInfo),
|
||||
}
|
||||
err = db.MachineSave(&machine)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
pol := &policy.ACLPolicy{
|
||||
Groups: policy.Groups{"group:test": []string{"user1", "user2"}},
|
||||
TagOwners: policy.TagOwners{"tag:test": []string{"user3", "group:test"}},
|
||||
ACLs: []policy.ACL{
|
||||
{
|
||||
Action: "accept",
|
||||
Sources: []string{"*"},
|
||||
Destinations: []string{"tag:test:*"},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
machines, err := db.ListMachines()
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
rules, _, err := policy.GenerateFilterRules(pol, machines, false)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
c.Assert(rules, check.HasLen, 1)
|
||||
c.Assert(rules[0].DstPorts, check.HasLen, 1)
|
||||
c.Assert(rules[0].DstPorts[0].IP, check.Equals, "100.64.0.1/32")
|
||||
}
|
||||
|
||||
// need a test with:
|
||||
// tag on a host that isn't owned by a tag owners. So the user
|
||||
// of the host should be valid.
|
||||
func (s *Suite) TestInvalidTagValidUser(c *check.C) {
|
||||
user, err := db.CreateUser("user1")
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
_, err = db.GetMachine("user1", "testmachine")
|
||||
c.Assert(err, check.NotNil)
|
||||
hostInfo := tailcfg.Hostinfo{
|
||||
OS: "centos",
|
||||
Hostname: "testmachine",
|
||||
RequestTags: []string{"tag:foo"},
|
||||
}
|
||||
|
||||
machine := types.Machine{
|
||||
ID: 1,
|
||||
MachineKey: "12345",
|
||||
NodeKey: "bar",
|
||||
DiscoKey: "faa",
|
||||
Hostname: "testmachine",
|
||||
IPAddresses: types.MachineAddresses{netip.MustParseAddr("100.64.0.1")},
|
||||
UserID: user.ID,
|
||||
RegisterMethod: util.RegisterMethodAuthKey,
|
||||
AuthKeyID: uint(pak.ID),
|
||||
HostInfo: types.HostInfo(hostInfo),
|
||||
}
|
||||
err = db.MachineSave(&machine)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
pol := &policy.ACLPolicy{
|
||||
TagOwners: policy.TagOwners{"tag:test": []string{"user1"}},
|
||||
ACLs: []policy.ACL{
|
||||
{
|
||||
Action: "accept",
|
||||
Sources: []string{"user1"},
|
||||
Destinations: []string{"*:*"},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
machines, err := db.ListMachines()
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
rules, _, err := policy.GenerateFilterRules(pol, machines, false)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
c.Assert(rules, check.HasLen, 1)
|
||||
c.Assert(rules[0].SrcIPs, check.HasLen, 1)
|
||||
c.Assert(rules[0].SrcIPs[0], check.Equals, "100.64.0.1/32")
|
||||
}
|
||||
|
||||
// tag on a host is owned by a tag owner, the tag is valid.
|
||||
// an ACL rule is matching the tag to a user. It should not be valid since the
|
||||
// host should be tied to the tag now.
|
||||
func (s *Suite) TestValidTagInvalidUser(c *check.C) {
|
||||
user, err := db.CreateUser("user1")
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
_, err = db.GetMachine("user1", "webserver")
|
||||
c.Assert(err, check.NotNil)
|
||||
hostInfo := tailcfg.Hostinfo{
|
||||
OS: "centos",
|
||||
Hostname: "webserver",
|
||||
RequestTags: []string{"tag:webapp"},
|
||||
}
|
||||
|
||||
machine := types.Machine{
|
||||
ID: 1,
|
||||
MachineKey: "12345",
|
||||
NodeKey: "bar",
|
||||
DiscoKey: "faa",
|
||||
Hostname: "webserver",
|
||||
IPAddresses: types.MachineAddresses{netip.MustParseAddr("100.64.0.1")},
|
||||
UserID: user.ID,
|
||||
RegisterMethod: util.RegisterMethodAuthKey,
|
||||
AuthKeyID: uint(pak.ID),
|
||||
HostInfo: types.HostInfo(hostInfo),
|
||||
}
|
||||
err = db.MachineSave(&machine)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
_, err = db.GetMachine("user1", "user")
|
||||
hostInfo2 := tailcfg.Hostinfo{
|
||||
OS: "debian",
|
||||
Hostname: "Hostname",
|
||||
}
|
||||
c.Assert(err, check.NotNil)
|
||||
machine = types.Machine{
|
||||
ID: 2,
|
||||
MachineKey: "56789",
|
||||
NodeKey: "bar2",
|
||||
DiscoKey: "faab",
|
||||
Hostname: "user",
|
||||
IPAddresses: types.MachineAddresses{netip.MustParseAddr("100.64.0.2")},
|
||||
UserID: user.ID,
|
||||
RegisterMethod: util.RegisterMethodAuthKey,
|
||||
AuthKeyID: uint(pak.ID),
|
||||
HostInfo: types.HostInfo(hostInfo2),
|
||||
}
|
||||
err = db.MachineSave(&machine)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
pol := &policy.ACLPolicy{
|
||||
TagOwners: policy.TagOwners{"tag:webapp": []string{"user1"}},
|
||||
ACLs: []policy.ACL{
|
||||
{
|
||||
Action: "accept",
|
||||
Sources: []string{"user1"},
|
||||
Destinations: []string{"tag:webapp:80,443"},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
machines, err := db.ListMachines()
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
rules, _, err := policy.GenerateFilterRules(pol, machines, false)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
c.Assert(rules, check.HasLen, 1)
|
||||
c.Assert(rules[0].SrcIPs, check.HasLen, 1)
|
||||
c.Assert(rules[0].SrcIPs[0], check.Equals, "100.64.0.2/32")
|
||||
c.Assert(rules[0].DstPorts, check.HasLen, 2)
|
||||
c.Assert(rules[0].DstPorts[0].Ports.First, check.Equals, uint16(80))
|
||||
c.Assert(rules[0].DstPorts[0].Ports.Last, check.Equals, uint16(80))
|
||||
c.Assert(rules[0].DstPorts[0].IP, check.Equals, "100.64.0.1/32")
|
||||
c.Assert(rules[0].DstPorts[1].Ports.First, check.Equals, uint16(443))
|
||||
c.Assert(rules[0].DstPorts[1].Ports.Last, check.Equals, uint16(443))
|
||||
c.Assert(rules[0].DstPorts[1].IP, check.Equals, "100.64.0.1/32")
|
||||
}
|
||||
|
||||
func (s *Suite) TestPortUser(c *check.C) {
|
||||
user, err := db.CreateUser("testuser")
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
_, err = db.GetMachine("testuser", "testmachine")
|
||||
c.Assert(err, check.NotNil)
|
||||
ips, _ := db.getAvailableIPs()
|
||||
machine := types.Machine{
|
||||
ID: 0,
|
||||
MachineKey: "12345",
|
||||
NodeKey: "bar",
|
||||
DiscoKey: "faa",
|
||||
Hostname: "testmachine",
|
||||
UserID: user.ID,
|
||||
RegisterMethod: util.RegisterMethodAuthKey,
|
||||
IPAddresses: ips,
|
||||
AuthKeyID: uint(pak.ID),
|
||||
}
|
||||
err = db.MachineSave(&machine)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
acl := []byte(`
|
||||
{
|
||||
"hosts": {
|
||||
"host-1": "100.100.100.100",
|
||||
"subnet-1": "100.100.101.100/24",
|
||||
},
|
||||
|
||||
"acls": [
|
||||
{
|
||||
"action": "accept",
|
||||
"src": [
|
||||
"testuser",
|
||||
],
|
||||
"dst": [
|
||||
"host-1:*",
|
||||
],
|
||||
},
|
||||
],
|
||||
}
|
||||
`)
|
||||
pol, err := policy.LoadACLPolicyFromBytes(acl, "hujson")
|
||||
c.Assert(err, check.IsNil)
|
||||
c.Assert(pol, check.NotNil)
|
||||
|
||||
machines, err := db.ListMachines()
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
rules, _, err := policy.GenerateFilterRules(pol, machines, false)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
c.Assert(err, check.IsNil)
|
||||
c.Assert(rules, check.NotNil)
|
||||
|
||||
c.Assert(rules, check.HasLen, 1)
|
||||
c.Assert(rules[0].DstPorts, check.HasLen, 1)
|
||||
c.Assert(rules[0].DstPorts[0].Ports.First, check.Equals, uint16(0))
|
||||
c.Assert(rules[0].DstPorts[0].Ports.Last, check.Equals, uint16(65535))
|
||||
c.Assert(rules[0].SrcIPs, check.HasLen, 1)
|
||||
c.Assert(rules[0].SrcIPs[0], check.Not(check.Equals), "not an ip")
|
||||
c.Assert(len(ips), check.Equals, 1)
|
||||
c.Assert(rules[0].SrcIPs[0], check.Equals, ips[0].String()+"/32")
|
||||
}
|
||||
|
||||
func (s *Suite) TestPortGroup(c *check.C) {
|
||||
user, err := db.CreateUser("testuser")
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
_, err = db.GetMachine("testuser", "testmachine")
|
||||
c.Assert(err, check.NotNil)
|
||||
ips, _ := db.getAvailableIPs()
|
||||
machine := types.Machine{
|
||||
ID: 0,
|
||||
MachineKey: "foo",
|
||||
NodeKey: "bar",
|
||||
DiscoKey: "faa",
|
||||
Hostname: "testmachine",
|
||||
UserID: user.ID,
|
||||
RegisterMethod: util.RegisterMethodAuthKey,
|
||||
IPAddresses: ips,
|
||||
AuthKeyID: uint(pak.ID),
|
||||
}
|
||||
err = db.MachineSave(&machine)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
acl := []byte(`
|
||||
{
|
||||
"groups": {
|
||||
"group:example": [
|
||||
"testuser",
|
||||
],
|
||||
},
|
||||
|
||||
"hosts": {
|
||||
"host-1": "100.100.100.100",
|
||||
"subnet-1": "100.100.101.100/24",
|
||||
},
|
||||
|
||||
"acls": [
|
||||
{
|
||||
"action": "accept",
|
||||
"src": [
|
||||
"group:example",
|
||||
],
|
||||
"dst": [
|
||||
"host-1:*",
|
||||
],
|
||||
},
|
||||
],
|
||||
}
|
||||
`)
|
||||
pol, err := policy.LoadACLPolicyFromBytes(acl, "hujson")
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
machines, err := db.ListMachines()
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
rules, _, err := policy.GenerateFilterRules(pol, machines, false)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
c.Assert(rules, check.NotNil)
|
||||
|
||||
c.Assert(rules, check.HasLen, 1)
|
||||
c.Assert(rules[0].DstPorts, check.HasLen, 1)
|
||||
c.Assert(rules[0].DstPorts[0].Ports.First, check.Equals, uint16(0))
|
||||
c.Assert(rules[0].DstPorts[0].Ports.Last, check.Equals, uint16(65535))
|
||||
c.Assert(rules[0].SrcIPs, check.HasLen, 1)
|
||||
c.Assert(rules[0].SrcIPs[0], check.Not(check.Equals), "not an ip")
|
||||
c.Assert(len(ips), check.Equals, 1)
|
||||
c.Assert(rules[0].SrcIPs[0], check.Equals, ips[0].String()+"/32")
|
||||
}
|
99
hscontrol/db/addresses.go
Normal file
99
hscontrol/db/addresses.go
Normal file
@@ -0,0 +1,99 @@
|
||||
// Codehere is mostly taken from github.com/tailscale/tailscale
|
||||
// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package db
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/netip"
|
||||
|
||||
"github.com/juanfont/headscale/hscontrol/types"
|
||||
"github.com/juanfont/headscale/hscontrol/util"
|
||||
"go4.org/netipx"
|
||||
)
|
||||
|
||||
var ErrCouldNotAllocateIP = errors.New("could not find any suitable IP")
|
||||
|
||||
func (hsdb *HSDatabase) getAvailableIPs() (types.MachineAddresses, error) {
|
||||
var ips types.MachineAddresses
|
||||
var err error
|
||||
for _, ipPrefix := range hsdb.ipPrefixes {
|
||||
var ip *netip.Addr
|
||||
ip, err = hsdb.getAvailableIP(ipPrefix)
|
||||
if err != nil {
|
||||
return ips, err
|
||||
}
|
||||
ips = append(ips, *ip)
|
||||
}
|
||||
|
||||
return ips, err
|
||||
}
|
||||
|
||||
func (hsdb *HSDatabase) getAvailableIP(ipPrefix netip.Prefix) (*netip.Addr, error) {
|
||||
usedIps, err := hsdb.getUsedIPs()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
ipPrefixNetworkAddress, ipPrefixBroadcastAddress := util.GetIPPrefixEndpoints(ipPrefix)
|
||||
|
||||
// Get the first IP in our prefix
|
||||
ip := ipPrefixNetworkAddress.Next()
|
||||
|
||||
for {
|
||||
if !ipPrefix.Contains(ip) {
|
||||
return nil, ErrCouldNotAllocateIP
|
||||
}
|
||||
|
||||
switch {
|
||||
case ip.Compare(ipPrefixBroadcastAddress) == 0:
|
||||
fallthrough
|
||||
case usedIps.Contains(ip):
|
||||
fallthrough
|
||||
case ip == netip.Addr{} || ip.IsLoopback():
|
||||
ip = ip.Next()
|
||||
|
||||
continue
|
||||
|
||||
default:
|
||||
return &ip, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (hsdb *HSDatabase) getUsedIPs() (*netipx.IPSet, error) {
|
||||
// FIXME: This really deserves a better data model,
|
||||
// but this was quick to get running and it should be enough
|
||||
// to begin experimenting with a dual stack tailnet.
|
||||
var addressesSlices []string
|
||||
hsdb.db.Model(&types.Machine{}).Pluck("ip_addresses", &addressesSlices)
|
||||
|
||||
var ips netipx.IPSetBuilder
|
||||
for _, slice := range addressesSlices {
|
||||
var machineAddresses types.MachineAddresses
|
||||
err := machineAddresses.Scan(slice)
|
||||
if err != nil {
|
||||
return &netipx.IPSet{}, fmt.Errorf(
|
||||
"failed to read ip from database: %w",
|
||||
err,
|
||||
)
|
||||
}
|
||||
|
||||
for _, ip := range machineAddresses {
|
||||
ips.Add(ip)
|
||||
}
|
||||
}
|
||||
|
||||
ipSet, err := ips.IPSet()
|
||||
if err != nil {
|
||||
return &netipx.IPSet{}, fmt.Errorf(
|
||||
"failed to build IP Set: %w",
|
||||
err,
|
||||
)
|
||||
}
|
||||
|
||||
return ipSet, nil
|
||||
}
|
197
hscontrol/db/addresses_test.go
Normal file
197
hscontrol/db/addresses_test.go
Normal file
@@ -0,0 +1,197 @@
|
||||
package db
|
||||
|
||||
import (
|
||||
"net/netip"
|
||||
|
||||
"github.com/juanfont/headscale/hscontrol/types"
|
||||
"github.com/juanfont/headscale/hscontrol/util"
|
||||
"go4.org/netipx"
|
||||
"gopkg.in/check.v1"
|
||||
)
|
||||
|
||||
func (s *Suite) TestGetAvailableIp(c *check.C) {
|
||||
ips, err := db.getAvailableIPs()
|
||||
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
expected := netip.MustParseAddr("10.27.0.1")
|
||||
|
||||
c.Assert(len(ips), check.Equals, 1)
|
||||
c.Assert(ips[0].String(), check.Equals, expected.String())
|
||||
}
|
||||
|
||||
func (s *Suite) TestGetUsedIps(c *check.C) {
|
||||
ips, err := db.getAvailableIPs()
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
user, err := db.CreateUser("test-ip")
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
_, err = db.GetMachine("test", "testmachine")
|
||||
c.Assert(err, check.NotNil)
|
||||
|
||||
machine := types.Machine{
|
||||
ID: 0,
|
||||
MachineKey: "foo",
|
||||
NodeKey: "bar",
|
||||
DiscoKey: "faa",
|
||||
Hostname: "testmachine",
|
||||
UserID: user.ID,
|
||||
RegisterMethod: util.RegisterMethodAuthKey,
|
||||
AuthKeyID: uint(pak.ID),
|
||||
IPAddresses: ips,
|
||||
}
|
||||
db.db.Save(&machine)
|
||||
|
||||
usedIps, err := db.getUsedIPs()
|
||||
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
expected := netip.MustParseAddr("10.27.0.1")
|
||||
expectedIPSetBuilder := netipx.IPSetBuilder{}
|
||||
expectedIPSetBuilder.Add(expected)
|
||||
expectedIPSet, _ := expectedIPSetBuilder.IPSet()
|
||||
|
||||
c.Assert(usedIps.Equal(expectedIPSet), check.Equals, true)
|
||||
c.Assert(usedIps.Contains(expected), check.Equals, true)
|
||||
|
||||
machine1, err := db.GetMachineByID(0)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
c.Assert(len(machine1.IPAddresses), check.Equals, 1)
|
||||
c.Assert(machine1.IPAddresses[0], check.Equals, expected)
|
||||
|
||||
c.Assert(channelUpdates, check.Equals, int32(0))
|
||||
}
|
||||
|
||||
func (s *Suite) TestGetMultiIp(c *check.C) {
|
||||
user, err := db.CreateUser("test-ip-multi")
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
for index := 1; index <= 350; index++ {
|
||||
db.ipAllocationMutex.Lock()
|
||||
|
||||
ips, err := db.getAvailableIPs()
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
_, err = db.GetMachine("test", "testmachine")
|
||||
c.Assert(err, check.NotNil)
|
||||
|
||||
machine := types.Machine{
|
||||
ID: uint64(index),
|
||||
MachineKey: "foo",
|
||||
NodeKey: "bar",
|
||||
DiscoKey: "faa",
|
||||
Hostname: "testmachine",
|
||||
UserID: user.ID,
|
||||
RegisterMethod: util.RegisterMethodAuthKey,
|
||||
AuthKeyID: uint(pak.ID),
|
||||
IPAddresses: ips,
|
||||
}
|
||||
db.db.Save(&machine)
|
||||
|
||||
db.ipAllocationMutex.Unlock()
|
||||
}
|
||||
|
||||
usedIps, err := db.getUsedIPs()
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
expected0 := netip.MustParseAddr("10.27.0.1")
|
||||
expected9 := netip.MustParseAddr("10.27.0.10")
|
||||
expected300 := netip.MustParseAddr("10.27.0.45")
|
||||
|
||||
notExpectedIPSetBuilder := netipx.IPSetBuilder{}
|
||||
notExpectedIPSetBuilder.Add(expected0)
|
||||
notExpectedIPSetBuilder.Add(expected9)
|
||||
notExpectedIPSetBuilder.Add(expected300)
|
||||
notExpectedIPSet, err := notExpectedIPSetBuilder.IPSet()
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
// We actually expect it to be a lot larger
|
||||
c.Assert(usedIps.Equal(notExpectedIPSet), check.Equals, false)
|
||||
|
||||
c.Assert(usedIps.Contains(expected0), check.Equals, true)
|
||||
c.Assert(usedIps.Contains(expected9), check.Equals, true)
|
||||
c.Assert(usedIps.Contains(expected300), check.Equals, true)
|
||||
|
||||
// Check that we can read back the IPs
|
||||
machine1, err := db.GetMachineByID(1)
|
||||
c.Assert(err, check.IsNil)
|
||||
c.Assert(len(machine1.IPAddresses), check.Equals, 1)
|
||||
c.Assert(
|
||||
machine1.IPAddresses[0],
|
||||
check.Equals,
|
||||
netip.MustParseAddr("10.27.0.1"),
|
||||
)
|
||||
|
||||
machine50, err := db.GetMachineByID(50)
|
||||
c.Assert(err, check.IsNil)
|
||||
c.Assert(len(machine50.IPAddresses), check.Equals, 1)
|
||||
c.Assert(
|
||||
machine50.IPAddresses[0],
|
||||
check.Equals,
|
||||
netip.MustParseAddr("10.27.0.50"),
|
||||
)
|
||||
|
||||
expectedNextIP := netip.MustParseAddr("10.27.1.95")
|
||||
nextIP, err := db.getAvailableIPs()
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
c.Assert(len(nextIP), check.Equals, 1)
|
||||
c.Assert(nextIP[0].String(), check.Equals, expectedNextIP.String())
|
||||
|
||||
// If we call get Available again, we should receive
|
||||
// the same IP, as it has not been reserved.
|
||||
nextIP2, err := db.getAvailableIPs()
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
c.Assert(len(nextIP2), check.Equals, 1)
|
||||
c.Assert(nextIP2[0].String(), check.Equals, expectedNextIP.String())
|
||||
|
||||
c.Assert(channelUpdates, check.Equals, int32(0))
|
||||
}
|
||||
|
||||
func (s *Suite) TestGetAvailableIpMachineWithoutIP(c *check.C) {
|
||||
ips, err := db.getAvailableIPs()
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
expected := netip.MustParseAddr("10.27.0.1")
|
||||
|
||||
c.Assert(len(ips), check.Equals, 1)
|
||||
c.Assert(ips[0].String(), check.Equals, expected.String())
|
||||
|
||||
user, err := db.CreateUser("test-ip")
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
_, err = db.GetMachine("test", "testmachine")
|
||||
c.Assert(err, check.NotNil)
|
||||
|
||||
machine := types.Machine{
|
||||
ID: 0,
|
||||
MachineKey: "foo",
|
||||
NodeKey: "bar",
|
||||
DiscoKey: "faa",
|
||||
Hostname: "testmachine",
|
||||
UserID: user.ID,
|
||||
RegisterMethod: util.RegisterMethodAuthKey,
|
||||
AuthKeyID: uint(pak.ID),
|
||||
}
|
||||
db.db.Save(&machine)
|
||||
|
||||
ips2, err := db.getAvailableIPs()
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
c.Assert(len(ips2), check.Equals, 1)
|
||||
c.Assert(ips2[0].String(), check.Equals, expected.String())
|
||||
|
||||
c.Assert(channelUpdates, check.Equals, int32(0))
|
||||
}
|
125
hscontrol/db/api_key.go
Normal file
125
hscontrol/db/api_key.go
Normal file
@@ -0,0 +1,125 @@
|
||||
package db
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/juanfont/headscale/hscontrol/types"
|
||||
"github.com/juanfont/headscale/hscontrol/util"
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
)
|
||||
|
||||
const (
|
||||
apiPrefixLength = 7
|
||||
apiKeyLength = 32
|
||||
)
|
||||
|
||||
var ErrAPIKeyFailedToParse = errors.New("failed to parse ApiKey")
|
||||
|
||||
// CreateAPIKey creates a new ApiKey in a user, and returns it.
|
||||
func (hsdb *HSDatabase) CreateAPIKey(
|
||||
expiration *time.Time,
|
||||
) (string, *types.APIKey, error) {
|
||||
prefix, err := util.GenerateRandomStringURLSafe(apiPrefixLength)
|
||||
if err != nil {
|
||||
return "", nil, err
|
||||
}
|
||||
|
||||
toBeHashed, err := util.GenerateRandomStringURLSafe(apiKeyLength)
|
||||
if err != nil {
|
||||
return "", nil, err
|
||||
}
|
||||
|
||||
// Key to return to user, this will only be visible _once_
|
||||
keyStr := prefix + "." + toBeHashed
|
||||
|
||||
hash, err := bcrypt.GenerateFromPassword([]byte(toBeHashed), bcrypt.DefaultCost)
|
||||
if err != nil {
|
||||
return "", nil, err
|
||||
}
|
||||
|
||||
key := types.APIKey{
|
||||
Prefix: prefix,
|
||||
Hash: hash,
|
||||
Expiration: expiration,
|
||||
}
|
||||
|
||||
if err := hsdb.db.Save(&key).Error; err != nil {
|
||||
return "", nil, fmt.Errorf("failed to save API key to database: %w", err)
|
||||
}
|
||||
|
||||
return keyStr, &key, nil
|
||||
}
|
||||
|
||||
// ListAPIKeys returns the list of ApiKeys for a user.
|
||||
func (hsdb *HSDatabase) ListAPIKeys() ([]types.APIKey, error) {
|
||||
keys := []types.APIKey{}
|
||||
if err := hsdb.db.Find(&keys).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return keys, nil
|
||||
}
|
||||
|
||||
// GetAPIKey returns a ApiKey for a given key.
|
||||
func (hsdb *HSDatabase) GetAPIKey(prefix string) (*types.APIKey, error) {
|
||||
key := types.APIKey{}
|
||||
if result := hsdb.db.First(&key, "prefix = ?", prefix); result.Error != nil {
|
||||
return nil, result.Error
|
||||
}
|
||||
|
||||
return &key, nil
|
||||
}
|
||||
|
||||
// GetAPIKeyByID returns a ApiKey for a given id.
|
||||
func (hsdb *HSDatabase) GetAPIKeyByID(id uint64) (*types.APIKey, error) {
|
||||
key := types.APIKey{}
|
||||
if result := hsdb.db.Find(&types.APIKey{ID: id}).First(&key); result.Error != nil {
|
||||
return nil, result.Error
|
||||
}
|
||||
|
||||
return &key, nil
|
||||
}
|
||||
|
||||
// DestroyAPIKey destroys a ApiKey. Returns error if the ApiKey
|
||||
// does not exist.
|
||||
func (hsdb *HSDatabase) DestroyAPIKey(key types.APIKey) error {
|
||||
if result := hsdb.db.Unscoped().Delete(key); result.Error != nil {
|
||||
return result.Error
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ExpireAPIKey marks a ApiKey as expired.
|
||||
func (hsdb *HSDatabase) ExpireAPIKey(key *types.APIKey) error {
|
||||
if err := hsdb.db.Model(&key).Update("Expiration", time.Now()).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (hsdb *HSDatabase) ValidateAPIKey(keyStr string) (bool, error) {
|
||||
prefix, hash, found := strings.Cut(keyStr, ".")
|
||||
if !found {
|
||||
return false, ErrAPIKeyFailedToParse
|
||||
}
|
||||
|
||||
key, err := hsdb.GetAPIKey(prefix)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("failed to validate api key: %w", err)
|
||||
}
|
||||
|
||||
if key.Expiration.Before(time.Now()) {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
if err := bcrypt.CompareHashAndPassword(key.Hash, []byte(hash)); err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
return true, nil
|
||||
}
|
97
hscontrol/db/api_key_test.go
Normal file
97
hscontrol/db/api_key_test.go
Normal file
@@ -0,0 +1,97 @@
|
||||
package db
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"gopkg.in/check.v1"
|
||||
)
|
||||
|
||||
func (*Suite) TestCreateAPIKey(c *check.C) {
|
||||
apiKeyStr, apiKey, err := db.CreateAPIKey(nil)
|
||||
c.Assert(err, check.IsNil)
|
||||
c.Assert(apiKey, check.NotNil)
|
||||
|
||||
// Did we get a valid key?
|
||||
c.Assert(apiKey.Prefix, check.NotNil)
|
||||
c.Assert(apiKey.Hash, check.NotNil)
|
||||
c.Assert(apiKeyStr, check.Not(check.Equals), "")
|
||||
|
||||
_, err = db.ListAPIKeys()
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
keys, err := db.ListAPIKeys()
|
||||
c.Assert(err, check.IsNil)
|
||||
c.Assert(len(keys), check.Equals, 1)
|
||||
|
||||
c.Assert(channelUpdates, check.Equals, int32(0))
|
||||
}
|
||||
|
||||
func (*Suite) TestAPIKeyDoesNotExist(c *check.C) {
|
||||
key, err := db.GetAPIKey("does-not-exist")
|
||||
c.Assert(err, check.NotNil)
|
||||
c.Assert(key, check.IsNil)
|
||||
}
|
||||
|
||||
func (*Suite) TestValidateAPIKeyOk(c *check.C) {
|
||||
nowPlus2 := time.Now().Add(2 * time.Hour)
|
||||
apiKeyStr, apiKey, err := db.CreateAPIKey(&nowPlus2)
|
||||
c.Assert(err, check.IsNil)
|
||||
c.Assert(apiKey, check.NotNil)
|
||||
|
||||
valid, err := db.ValidateAPIKey(apiKeyStr)
|
||||
c.Assert(err, check.IsNil)
|
||||
c.Assert(valid, check.Equals, true)
|
||||
|
||||
c.Assert(channelUpdates, check.Equals, int32(0))
|
||||
}
|
||||
|
||||
func (*Suite) TestValidateAPIKeyNotOk(c *check.C) {
|
||||
nowMinus2 := time.Now().Add(time.Duration(-2) * time.Hour)
|
||||
apiKeyStr, apiKey, err := db.CreateAPIKey(&nowMinus2)
|
||||
c.Assert(err, check.IsNil)
|
||||
c.Assert(apiKey, check.NotNil)
|
||||
|
||||
valid, err := db.ValidateAPIKey(apiKeyStr)
|
||||
c.Assert(err, check.IsNil)
|
||||
c.Assert(valid, check.Equals, false)
|
||||
|
||||
now := time.Now()
|
||||
apiKeyStrNow, apiKey, err := db.CreateAPIKey(&now)
|
||||
c.Assert(err, check.IsNil)
|
||||
c.Assert(apiKey, check.NotNil)
|
||||
|
||||
validNow, err := db.ValidateAPIKey(apiKeyStrNow)
|
||||
c.Assert(err, check.IsNil)
|
||||
c.Assert(validNow, check.Equals, false)
|
||||
|
||||
validSilly, err := db.ValidateAPIKey("nota.validkey")
|
||||
c.Assert(err, check.NotNil)
|
||||
c.Assert(validSilly, check.Equals, false)
|
||||
|
||||
validWithErr, err := db.ValidateAPIKey("produceerrorkey")
|
||||
c.Assert(err, check.NotNil)
|
||||
c.Assert(validWithErr, check.Equals, false)
|
||||
|
||||
c.Assert(channelUpdates, check.Equals, int32(0))
|
||||
}
|
||||
|
||||
func (*Suite) TestExpireAPIKey(c *check.C) {
|
||||
nowPlus2 := time.Now().Add(2 * time.Hour)
|
||||
apiKeyStr, apiKey, err := db.CreateAPIKey(&nowPlus2)
|
||||
c.Assert(err, check.IsNil)
|
||||
c.Assert(apiKey, check.NotNil)
|
||||
|
||||
valid, err := db.ValidateAPIKey(apiKeyStr)
|
||||
c.Assert(err, check.IsNil)
|
||||
c.Assert(valid, check.Equals, true)
|
||||
|
||||
err = db.ExpireAPIKey(apiKey)
|
||||
c.Assert(err, check.IsNil)
|
||||
c.Assert(apiKey.Expiration, check.NotNil)
|
||||
|
||||
notValid, err := db.ValidateAPIKey(apiKeyStr)
|
||||
c.Assert(err, check.IsNil)
|
||||
c.Assert(notValid, check.Equals, false)
|
||||
|
||||
c.Assert(channelUpdates, check.Equals, int32(0))
|
||||
}
|
361
hscontrol/db/db.go
Normal file
361
hscontrol/db/db.go
Normal file
@@ -0,0 +1,361 @@
|
||||
package db
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/glebarez/sqlite"
|
||||
"github.com/juanfont/headscale/hscontrol/types"
|
||||
"github.com/juanfont/headscale/hscontrol/util"
|
||||
"github.com/rs/zerolog/log"
|
||||
"gorm.io/driver/postgres"
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/logger"
|
||||
)
|
||||
|
||||
const (
|
||||
dbVersion = "1"
|
||||
Postgres = "postgres"
|
||||
Sqlite = "sqlite3"
|
||||
)
|
||||
|
||||
var (
|
||||
errValueNotFound = errors.New("not found")
|
||||
errDatabaseNotSupported = errors.New("database type not supported")
|
||||
)
|
||||
|
||||
// KV is a key-value store in a psql table. For future use...
|
||||
// TODO(kradalby): Is this used for anything?
|
||||
type KV struct {
|
||||
Key string
|
||||
Value string
|
||||
}
|
||||
|
||||
type HSDatabase struct {
|
||||
db *gorm.DB
|
||||
notifyStateChan chan<- struct{}
|
||||
notifyPolicyChan chan<- struct{}
|
||||
|
||||
ipAllocationMutex sync.Mutex
|
||||
|
||||
ipPrefixes []netip.Prefix
|
||||
baseDomain string
|
||||
stripEmailDomain bool
|
||||
}
|
||||
|
||||
// TODO(kradalby): assemble this struct from toptions or something typed
|
||||
// rather than arguments.
|
||||
func NewHeadscaleDatabase(
|
||||
dbType, connectionAddr string,
|
||||
stripEmailDomain, debug bool,
|
||||
notifyStateChan chan<- struct{},
|
||||
notifyPolicyChan chan<- struct{},
|
||||
ipPrefixes []netip.Prefix,
|
||||
baseDomain string,
|
||||
) (*HSDatabase, error) {
|
||||
dbConn, err := openDB(dbType, connectionAddr, debug)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
db := HSDatabase{
|
||||
db: dbConn,
|
||||
notifyStateChan: notifyStateChan,
|
||||
notifyPolicyChan: notifyPolicyChan,
|
||||
|
||||
ipPrefixes: ipPrefixes,
|
||||
baseDomain: baseDomain,
|
||||
stripEmailDomain: stripEmailDomain,
|
||||
}
|
||||
|
||||
log.Debug().Msgf("database %#v", dbConn)
|
||||
|
||||
if dbType == Postgres {
|
||||
dbConn.Exec(`create extension if not exists "uuid-ossp";`)
|
||||
}
|
||||
|
||||
_ = dbConn.Migrator().RenameTable("namespaces", "users")
|
||||
|
||||
err = dbConn.AutoMigrate(types.User{})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
_ = dbConn.Migrator().RenameColumn(&types.Machine{}, "namespace_id", "user_id")
|
||||
_ = dbConn.Migrator().RenameColumn(&types.PreAuthKey{}, "namespace_id", "user_id")
|
||||
|
||||
_ = dbConn.Migrator().RenameColumn(&types.Machine{}, "ip_address", "ip_addresses")
|
||||
_ = dbConn.Migrator().RenameColumn(&types.Machine{}, "name", "hostname")
|
||||
|
||||
// GivenName is used as the primary source of DNS names, make sure
|
||||
// the field is populated and normalized if it was not when the
|
||||
// machine was registered.
|
||||
_ = dbConn.Migrator().RenameColumn(&types.Machine{}, "nickname", "given_name")
|
||||
|
||||
// If the Machine table has a column for registered,
|
||||
// find all occourences of "false" and drop them. Then
|
||||
// remove the column.
|
||||
if dbConn.Migrator().HasColumn(&types.Machine{}, "registered") {
|
||||
log.Info().
|
||||
Msg(`Database has legacy "registered" column in machine, removing...`)
|
||||
|
||||
machines := types.Machines{}
|
||||
if err := dbConn.Not("registered").Find(&machines).Error; err != nil {
|
||||
log.Error().Err(err).Msg("Error accessing db")
|
||||
}
|
||||
|
||||
for _, machine := range machines {
|
||||
log.Info().
|
||||
Str("machine", machine.Hostname).
|
||||
Str("machine_key", machine.MachineKey).
|
||||
Msg("Deleting unregistered machine")
|
||||
if err := dbConn.Delete(&types.Machine{}, machine.ID).Error; err != nil {
|
||||
log.Error().
|
||||
Err(err).
|
||||
Str("machine", machine.Hostname).
|
||||
Str("machine_key", machine.MachineKey).
|
||||
Msg("Error deleting unregistered machine")
|
||||
}
|
||||
}
|
||||
|
||||
err := dbConn.Migrator().DropColumn(&types.Machine{}, "registered")
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("Error dropping registered column")
|
||||
}
|
||||
}
|
||||
|
||||
err = dbConn.AutoMigrate(&types.Route{})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if dbConn.Migrator().HasColumn(&types.Machine{}, "enabled_routes") {
|
||||
log.Info().Msgf("Database has legacy enabled_routes column in machine, migrating...")
|
||||
|
||||
type MachineAux struct {
|
||||
ID uint64
|
||||
EnabledRoutes types.IPPrefixes
|
||||
}
|
||||
|
||||
machinesAux := []MachineAux{}
|
||||
err := dbConn.Table("machines").Select("id, enabled_routes").Scan(&machinesAux).Error
|
||||
if err != nil {
|
||||
log.Fatal().Err(err).Msg("Error accessing db")
|
||||
}
|
||||
for _, machine := range machinesAux {
|
||||
for _, prefix := range machine.EnabledRoutes {
|
||||
if err != nil {
|
||||
log.Error().
|
||||
Err(err).
|
||||
Str("enabled_route", prefix.String()).
|
||||
Msg("Error parsing enabled_route")
|
||||
|
||||
continue
|
||||
}
|
||||
|
||||
err = dbConn.Preload("Machine").
|
||||
Where("machine_id = ? AND prefix = ?", machine.ID, types.IPPrefix(prefix)).
|
||||
First(&types.Route{}).
|
||||
Error
|
||||
if err == nil {
|
||||
log.Info().
|
||||
Str("enabled_route", prefix.String()).
|
||||
Msg("Route already migrated to new table, skipping")
|
||||
|
||||
continue
|
||||
}
|
||||
|
||||
route := types.Route{
|
||||
MachineID: machine.ID,
|
||||
Advertised: true,
|
||||
Enabled: true,
|
||||
Prefix: types.IPPrefix(prefix),
|
||||
}
|
||||
if err := dbConn.Create(&route).Error; err != nil {
|
||||
log.Error().Err(err).Msg("Error creating route")
|
||||
} else {
|
||||
log.Info().
|
||||
Uint64("machine_id", route.MachineID).
|
||||
Str("prefix", prefix.String()).
|
||||
Msg("Route migrated")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
err = dbConn.Migrator().DropColumn(&types.Machine{}, "enabled_routes")
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("Error dropping enabled_routes column")
|
||||
}
|
||||
}
|
||||
|
||||
err = dbConn.AutoMigrate(&types.Machine{})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if dbConn.Migrator().HasColumn(&types.Machine{}, "given_name") {
|
||||
machines := types.Machines{}
|
||||
if err := dbConn.Find(&machines).Error; err != nil {
|
||||
log.Error().Err(err).Msg("Error accessing db")
|
||||
}
|
||||
|
||||
for item, machine := range machines {
|
||||
if machine.GivenName == "" {
|
||||
normalizedHostname, err := util.NormalizeToFQDNRules(
|
||||
machine.Hostname,
|
||||
stripEmailDomain,
|
||||
)
|
||||
if err != nil {
|
||||
log.Error().
|
||||
Caller().
|
||||
Str("hostname", machine.Hostname).
|
||||
Err(err).
|
||||
Msg("Failed to normalize machine hostname in DB migration")
|
||||
}
|
||||
|
||||
err = db.RenameMachine(&machines[item], normalizedHostname)
|
||||
if err != nil {
|
||||
log.Error().
|
||||
Caller().
|
||||
Str("hostname", machine.Hostname).
|
||||
Err(err).
|
||||
Msg("Failed to save normalized machine name in DB migration")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
err = dbConn.AutoMigrate(&KV{})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
err = dbConn.AutoMigrate(&types.PreAuthKey{})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
err = dbConn.AutoMigrate(&types.PreAuthKeyACLTag{})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
_ = dbConn.Migrator().DropTable("shared_machines")
|
||||
|
||||
err = dbConn.AutoMigrate(&types.APIKey{})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// TODO(kradalby): is this needed?
|
||||
err = db.setValue("db_version", dbVersion)
|
||||
|
||||
return &db, err
|
||||
}
|
||||
|
||||
func openDB(dbType, connectionAddr string, debug bool) (*gorm.DB, error) {
|
||||
log.Debug().Str("type", dbType).Str("connection", connectionAddr).Msg("opening database")
|
||||
|
||||
var dbLogger logger.Interface
|
||||
if debug {
|
||||
dbLogger = logger.Default
|
||||
} else {
|
||||
dbLogger = logger.Default.LogMode(logger.Silent)
|
||||
}
|
||||
|
||||
switch dbType {
|
||||
case Sqlite:
|
||||
db, err := gorm.Open(
|
||||
sqlite.Open(connectionAddr+"?_synchronous=1&_journal_mode=WAL"),
|
||||
&gorm.Config{
|
||||
DisableForeignKeyConstraintWhenMigrating: true,
|
||||
Logger: dbLogger,
|
||||
},
|
||||
)
|
||||
|
||||
db.Exec("PRAGMA foreign_keys=ON")
|
||||
|
||||
// The pure Go SQLite library does not handle locking in
|
||||
// the same way as the C based one and we cant use the gorm
|
||||
// connection pool as of 2022/02/23.
|
||||
sqlDB, _ := db.DB()
|
||||
sqlDB.SetMaxIdleConns(1)
|
||||
sqlDB.SetMaxOpenConns(1)
|
||||
sqlDB.SetConnMaxIdleTime(time.Hour)
|
||||
|
||||
return db, err
|
||||
|
||||
case Postgres:
|
||||
return gorm.Open(postgres.Open(connectionAddr), &gorm.Config{
|
||||
DisableForeignKeyConstraintWhenMigrating: true,
|
||||
Logger: dbLogger,
|
||||
})
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf(
|
||||
"database of type %s is not supported: %w",
|
||||
dbType,
|
||||
errDatabaseNotSupported,
|
||||
)
|
||||
}
|
||||
|
||||
func (hsdb *HSDatabase) notifyStateChange() {
|
||||
hsdb.notifyStateChan <- struct{}{}
|
||||
}
|
||||
|
||||
// getValue returns the value for the given key in KV.
|
||||
func (hsdb *HSDatabase) getValue(key string) (string, error) {
|
||||
var row KV
|
||||
if result := hsdb.db.First(&row, "key = ?", key); errors.Is(
|
||||
result.Error,
|
||||
gorm.ErrRecordNotFound,
|
||||
) {
|
||||
return "", errValueNotFound
|
||||
}
|
||||
|
||||
return row.Value, nil
|
||||
}
|
||||
|
||||
// setValue sets value for the given key in KV.
|
||||
func (hsdb *HSDatabase) setValue(key string, value string) error {
|
||||
keyValue := KV{
|
||||
Key: key,
|
||||
Value: value,
|
||||
}
|
||||
|
||||
if _, err := hsdb.getValue(key); err == nil {
|
||||
hsdb.db.Model(&keyValue).Where("key = ?", key).Update("value", value)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := hsdb.db.Create(keyValue).Error; err != nil {
|
||||
return fmt.Errorf("failed to create key value pair in the database: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (hsdb *HSDatabase) PingDB(ctx context.Context) error {
|
||||
ctx, cancel := context.WithTimeout(ctx, time.Second)
|
||||
defer cancel()
|
||||
sqlDB, err := hsdb.db.DB()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return sqlDB.PingContext(ctx)
|
||||
}
|
||||
|
||||
func (hsdb *HSDatabase) Close() error {
|
||||
db, err := hsdb.db.DB()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return db.Close()
|
||||
}
|
986
hscontrol/db/machine.go
Normal file
986
hscontrol/db/machine.go
Normal file
@@ -0,0 +1,986 @@
|
||||
package db
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/juanfont/headscale/hscontrol/policy"
|
||||
"github.com/juanfont/headscale/hscontrol/types"
|
||||
"github.com/juanfont/headscale/hscontrol/util"
|
||||
"github.com/patrickmn/go-cache"
|
||||
"github.com/rs/zerolog/log"
|
||||
"github.com/samber/lo"
|
||||
"gorm.io/gorm"
|
||||
"tailscale.com/tailcfg"
|
||||
"tailscale.com/types/key"
|
||||
)
|
||||
|
||||
const (
|
||||
MachineGivenNameHashLength = 8
|
||||
MachineGivenNameTrimSize = 2
|
||||
MaxHostnameLength = 255
|
||||
)
|
||||
|
||||
var (
|
||||
ErrMachineNotFound = errors.New("machine not found")
|
||||
ErrMachineRouteIsNotAvailable = errors.New("route is not available on machine")
|
||||
ErrMachineNotFoundRegistrationCache = errors.New(
|
||||
"machine not found in registration cache",
|
||||
)
|
||||
ErrCouldNotConvertMachineInterface = errors.New("failed to convert machine interface")
|
||||
ErrHostnameTooLong = errors.New("hostname too long")
|
||||
ErrDifferentRegisteredUser = errors.New(
|
||||
"machine was previously registered with a different user",
|
||||
)
|
||||
)
|
||||
|
||||
// filterMachinesByACL wrapper function to not have devs pass around locks and maps
|
||||
// related to the application outside of tests.
|
||||
func (hsdb *HSDatabase) filterMachinesByACL(
|
||||
aclRules []tailcfg.FilterRule,
|
||||
currentMachine *types.Machine, peers types.Machines,
|
||||
) types.Machines {
|
||||
return policy.FilterMachinesByACL(currentMachine, peers, aclRules)
|
||||
}
|
||||
|
||||
func (hsdb *HSDatabase) ListPeers(machine *types.Machine) (types.Machines, error) {
|
||||
log.Trace().
|
||||
Caller().
|
||||
Str("machine", machine.Hostname).
|
||||
Msg("Finding direct peers")
|
||||
|
||||
machines := types.Machines{}
|
||||
if err := hsdb.db.Preload("AuthKey").Preload("AuthKey.User").Preload("User").Where("node_key <> ?",
|
||||
machine.NodeKey).Find(&machines).Error; err != nil {
|
||||
log.Error().Err(err).Msg("Error accessing db")
|
||||
|
||||
return types.Machines{}, err
|
||||
}
|
||||
|
||||
sort.Slice(machines, func(i, j int) bool { return machines[i].ID < machines[j].ID })
|
||||
|
||||
log.Trace().
|
||||
Caller().
|
||||
Str("machine", machine.Hostname).
|
||||
Msgf("Found peers: %s", machines.String())
|
||||
|
||||
return machines, nil
|
||||
}
|
||||
|
||||
func (hsdb *HSDatabase) getPeers(
|
||||
aclRules []tailcfg.FilterRule,
|
||||
machine *types.Machine,
|
||||
) (types.Machines, error) {
|
||||
var peers types.Machines
|
||||
var err error
|
||||
|
||||
// If ACLs rules are defined, filter visible host list with the ACLs
|
||||
// else use the classic user scope
|
||||
if len(aclRules) > 0 {
|
||||
var machines []types.Machine
|
||||
machines, err = hsdb.ListMachines()
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("Error retrieving list of machines")
|
||||
|
||||
return types.Machines{}, err
|
||||
}
|
||||
peers = hsdb.filterMachinesByACL(aclRules, machine, machines)
|
||||
} else {
|
||||
peers, err = hsdb.ListPeers(machine)
|
||||
if err != nil {
|
||||
log.Error().
|
||||
Caller().
|
||||
Err(err).
|
||||
Msg("Cannot fetch peers")
|
||||
|
||||
return types.Machines{}, err
|
||||
}
|
||||
}
|
||||
|
||||
sort.Slice(peers, func(i, j int) bool { return peers[i].ID < peers[j].ID })
|
||||
|
||||
log.Trace().
|
||||
Caller().
|
||||
Str("self", machine.Hostname).
|
||||
Str("peers", peers.String()).
|
||||
Msg("Peers returned to caller")
|
||||
|
||||
return peers, nil
|
||||
}
|
||||
|
||||
func (hsdb *HSDatabase) GetValidPeers(
|
||||
aclRules []tailcfg.FilterRule,
|
||||
machine *types.Machine,
|
||||
) (types.Machines, error) {
|
||||
validPeers := make(types.Machines, 0)
|
||||
|
||||
peers, err := hsdb.getPeers(aclRules, machine)
|
||||
if err != nil {
|
||||
return types.Machines{}, err
|
||||
}
|
||||
|
||||
for _, peer := range peers {
|
||||
if !peer.IsExpired() {
|
||||
validPeers = append(validPeers, peer)
|
||||
}
|
||||
}
|
||||
|
||||
return validPeers, nil
|
||||
}
|
||||
|
||||
func (hsdb *HSDatabase) ListMachines() ([]types.Machine, error) {
|
||||
machines := []types.Machine{}
|
||||
if err := hsdb.db.Preload("AuthKey").Preload("AuthKey.User").Preload("User").Find(&machines).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return machines, nil
|
||||
}
|
||||
|
||||
func (hsdb *HSDatabase) ListMachinesByGivenName(givenName string) (types.Machines, error) {
|
||||
machines := types.Machines{}
|
||||
if err := hsdb.db.Preload("AuthKey").Preload("AuthKey.User").Preload("User").Where("given_name = ?", givenName).Find(&machines).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return machines, nil
|
||||
}
|
||||
|
||||
// GetMachine finds a Machine by name and user and returns the Machine struct.
|
||||
func (hsdb *HSDatabase) GetMachine(user string, name string) (*types.Machine, error) {
|
||||
machines, err := hsdb.ListMachinesByUser(user)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for _, m := range machines {
|
||||
if m.Hostname == name {
|
||||
return &m, nil
|
||||
}
|
||||
}
|
||||
|
||||
return nil, ErrMachineNotFound
|
||||
}
|
||||
|
||||
// GetMachineByGivenName finds a Machine by given name and user and returns the Machine struct.
|
||||
func (hsdb *HSDatabase) GetMachineByGivenName(
|
||||
user string,
|
||||
givenName string,
|
||||
) (*types.Machine, error) {
|
||||
machines, err := hsdb.ListMachinesByUser(user)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for _, m := range machines {
|
||||
if m.GivenName == givenName {
|
||||
return &m, nil
|
||||
}
|
||||
}
|
||||
|
||||
return nil, ErrMachineNotFound
|
||||
}
|
||||
|
||||
// GetMachineByID finds a Machine by ID and returns the Machine struct.
|
||||
func (hsdb *HSDatabase) GetMachineByID(id uint64) (*types.Machine, error) {
|
||||
m := types.Machine{}
|
||||
if result := hsdb.db.Preload("AuthKey").Preload("User").Find(&types.Machine{ID: id}).First(&m); result.Error != nil {
|
||||
return nil, result.Error
|
||||
}
|
||||
|
||||
return &m, nil
|
||||
}
|
||||
|
||||
// GetMachineByMachineKey finds a Machine by its MachineKey and returns the Machine struct.
|
||||
func (hsdb *HSDatabase) GetMachineByMachineKey(
|
||||
machineKey key.MachinePublic,
|
||||
) (*types.Machine, error) {
|
||||
m := types.Machine{}
|
||||
if result := hsdb.db.Preload("AuthKey").Preload("User").First(&m, "machine_key = ?", util.MachinePublicKeyStripPrefix(machineKey)); result.Error != nil {
|
||||
return nil, result.Error
|
||||
}
|
||||
|
||||
return &m, nil
|
||||
}
|
||||
|
||||
// GetMachineByNodeKey finds a Machine by its current NodeKey.
|
||||
func (hsdb *HSDatabase) GetMachineByNodeKey(
|
||||
nodeKey key.NodePublic,
|
||||
) (*types.Machine, error) {
|
||||
machine := types.Machine{}
|
||||
if result := hsdb.db.Preload("AuthKey").Preload("User").First(&machine, "node_key = ?",
|
||||
util.NodePublicKeyStripPrefix(nodeKey)); result.Error != nil {
|
||||
return nil, result.Error
|
||||
}
|
||||
|
||||
return &machine, nil
|
||||
}
|
||||
|
||||
// GetMachineByAnyNodeKey finds a Machine by its MachineKey, its current NodeKey or the old one, and returns the Machine struct.
|
||||
func (hsdb *HSDatabase) GetMachineByAnyKey(
|
||||
machineKey key.MachinePublic, nodeKey key.NodePublic, oldNodeKey key.NodePublic,
|
||||
) (*types.Machine, error) {
|
||||
machine := types.Machine{}
|
||||
if result := hsdb.db.Preload("AuthKey").Preload("User").First(&machine, "machine_key = ? OR node_key = ? OR node_key = ?",
|
||||
util.MachinePublicKeyStripPrefix(machineKey),
|
||||
util.NodePublicKeyStripPrefix(nodeKey),
|
||||
util.NodePublicKeyStripPrefix(oldNodeKey)); result.Error != nil {
|
||||
return nil, result.Error
|
||||
}
|
||||
|
||||
return &machine, nil
|
||||
}
|
||||
|
||||
// TODO(kradalby): rename this, it sounds like a mix of getting and setting to db
|
||||
// UpdateMachineFromDatabase takes a Machine struct pointer (typically already loaded from database
|
||||
// and updates it with the latest data from the database.
|
||||
func (hsdb *HSDatabase) UpdateMachineFromDatabase(machine *types.Machine) error {
|
||||
if result := hsdb.db.Find(machine).First(&machine); result.Error != nil {
|
||||
return result.Error
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// SetTags takes a Machine struct pointer and update the forced tags.
|
||||
func (hsdb *HSDatabase) SetTags(
|
||||
machine *types.Machine,
|
||||
tags []string,
|
||||
) error {
|
||||
newTags := []string{}
|
||||
for _, tag := range tags {
|
||||
if !util.StringOrPrefixListContains(newTags, tag) {
|
||||
newTags = append(newTags, tag)
|
||||
}
|
||||
}
|
||||
machine.ForcedTags = newTags
|
||||
|
||||
hsdb.notifyPolicyChan <- struct{}{}
|
||||
hsdb.notifyStateChange()
|
||||
|
||||
if err := hsdb.db.Save(machine).Error; err != nil {
|
||||
return fmt.Errorf("failed to update tags for machine in the database: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ExpireMachine takes a Machine struct and sets the expire field to now.
|
||||
func (hsdb *HSDatabase) ExpireMachine(machine *types.Machine) error {
|
||||
now := time.Now()
|
||||
machine.Expiry = &now
|
||||
|
||||
hsdb.notifyStateChange()
|
||||
|
||||
if err := hsdb.db.Save(machine).Error; err != nil {
|
||||
return fmt.Errorf("failed to expire machine in the database: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// RenameMachine takes a Machine struct and a new GivenName for the machines
|
||||
// and renames it.
|
||||
func (hsdb *HSDatabase) RenameMachine(machine *types.Machine, newName string) error {
|
||||
err := util.CheckForFQDNRules(
|
||||
newName,
|
||||
)
|
||||
if err != nil {
|
||||
log.Error().
|
||||
Caller().
|
||||
Str("func", "RenameMachine").
|
||||
Str("machine", machine.Hostname).
|
||||
Str("newName", newName).
|
||||
Err(err)
|
||||
|
||||
return err
|
||||
}
|
||||
machine.GivenName = newName
|
||||
|
||||
hsdb.notifyStateChange()
|
||||
|
||||
if err := hsdb.db.Save(machine).Error; err != nil {
|
||||
return fmt.Errorf("failed to rename machine in the database: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// RefreshMachine takes a Machine struct and a new expiry time.
|
||||
func (hsdb *HSDatabase) RefreshMachine(machine *types.Machine, expiry time.Time) error {
|
||||
now := time.Now()
|
||||
|
||||
machine.LastSuccessfulUpdate = &now
|
||||
machine.Expiry = &expiry
|
||||
|
||||
hsdb.notifyStateChange()
|
||||
|
||||
if err := hsdb.db.Save(machine).Error; err != nil {
|
||||
return fmt.Errorf(
|
||||
"failed to refresh machine (update expiration) in the database: %w",
|
||||
err,
|
||||
)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// DeleteMachine softs deletes a Machine from the database.
|
||||
func (hsdb *HSDatabase) DeleteMachine(machine *types.Machine) error {
|
||||
err := hsdb.DeleteMachineRoutes(machine)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := hsdb.db.Delete(&machine).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (hsdb *HSDatabase) TouchMachine(machine *types.Machine) error {
|
||||
return hsdb.db.Updates(types.Machine{
|
||||
ID: machine.ID,
|
||||
LastSeen: machine.LastSeen,
|
||||
LastSuccessfulUpdate: machine.LastSuccessfulUpdate,
|
||||
}).Error
|
||||
}
|
||||
|
||||
// HardDeleteMachine hard deletes a Machine from the database.
|
||||
func (hsdb *HSDatabase) HardDeleteMachine(machine *types.Machine) error {
|
||||
err := hsdb.DeleteMachineRoutes(machine)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := hsdb.db.Unscoped().Delete(&machine).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (hsdb *HSDatabase) IsOutdated(machine *types.Machine, lastChange time.Time) bool {
|
||||
if err := hsdb.UpdateMachineFromDatabase(machine); err != nil {
|
||||
// It does not seem meaningful to propagate this error as the end result
|
||||
// will have to be that the machine has to be considered outdated.
|
||||
return true
|
||||
}
|
||||
|
||||
// Get the last update from all headscale users to compare with our nodes
|
||||
// last update.
|
||||
// TODO(kradalby): Only request updates from users where we can talk to nodes
|
||||
// This would mostly be for a bit of performance, and can be calculated based on
|
||||
// ACLs.
|
||||
lastUpdate := machine.CreatedAt
|
||||
if machine.LastSuccessfulUpdate != nil {
|
||||
lastUpdate = *machine.LastSuccessfulUpdate
|
||||
}
|
||||
log.Trace().
|
||||
Caller().
|
||||
Str("machine", machine.Hostname).
|
||||
Time("last_successful_update", lastChange).
|
||||
Time("last_state_change", lastUpdate).
|
||||
Msgf("Checking if %s is missing updates", machine.Hostname)
|
||||
|
||||
return lastUpdate.Before(lastChange)
|
||||
}
|
||||
|
||||
func (hsdb *HSDatabase) RegisterMachineFromAuthCallback(
|
||||
cache *cache.Cache,
|
||||
nodeKeyStr string,
|
||||
userName string,
|
||||
machineExpiry *time.Time,
|
||||
registrationMethod string,
|
||||
) (*types.Machine, error) {
|
||||
nodeKey := key.NodePublic{}
|
||||
err := nodeKey.UnmarshalText([]byte(nodeKeyStr))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
log.Debug().
|
||||
Str("nodeKey", nodeKey.ShortString()).
|
||||
Str("userName", userName).
|
||||
Str("registrationMethod", registrationMethod).
|
||||
Str("expiresAt", fmt.Sprintf("%v", machineExpiry)).
|
||||
Msg("Registering machine from API/CLI or auth callback")
|
||||
|
||||
if machineInterface, ok := cache.Get(util.NodePublicKeyStripPrefix(nodeKey)); ok {
|
||||
if registrationMachine, ok := machineInterface.(types.Machine); ok {
|
||||
user, err := hsdb.GetUser(userName)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf(
|
||||
"failed to find user in register machine from auth callback, %w",
|
||||
err,
|
||||
)
|
||||
}
|
||||
|
||||
// Registration of expired machine with different user
|
||||
if registrationMachine.ID != 0 &&
|
||||
registrationMachine.UserID != user.ID {
|
||||
return nil, ErrDifferentRegisteredUser
|
||||
}
|
||||
|
||||
registrationMachine.UserID = user.ID
|
||||
registrationMachine.RegisterMethod = registrationMethod
|
||||
|
||||
if machineExpiry != nil {
|
||||
registrationMachine.Expiry = machineExpiry
|
||||
}
|
||||
|
||||
machine, err := hsdb.RegisterMachine(
|
||||
registrationMachine,
|
||||
)
|
||||
|
||||
if err == nil {
|
||||
cache.Delete(nodeKeyStr)
|
||||
}
|
||||
|
||||
return machine, err
|
||||
} else {
|
||||
return nil, ErrCouldNotConvertMachineInterface
|
||||
}
|
||||
}
|
||||
|
||||
return nil, ErrMachineNotFoundRegistrationCache
|
||||
}
|
||||
|
||||
// RegisterMachine is executed from the CLI to register a new Machine using its MachineKey.
|
||||
func (hsdb *HSDatabase) RegisterMachine(machine types.Machine,
|
||||
) (*types.Machine, error) {
|
||||
log.Debug().
|
||||
Str("machine", machine.Hostname).
|
||||
Str("machine_key", machine.MachineKey).
|
||||
Str("node_key", machine.NodeKey).
|
||||
Str("user", machine.User.Name).
|
||||
Msg("Registering machine")
|
||||
|
||||
// If the machine exists and we had already IPs for it, we just save it
|
||||
// so we store the machine.Expire and machine.Nodekey that has been set when
|
||||
// adding it to the registrationCache
|
||||
if len(machine.IPAddresses) > 0 {
|
||||
if err := hsdb.db.Save(&machine).Error; err != nil {
|
||||
return nil, fmt.Errorf("failed register existing machine in the database: %w", err)
|
||||
}
|
||||
|
||||
log.Trace().
|
||||
Caller().
|
||||
Str("machine", machine.Hostname).
|
||||
Str("machine_key", machine.MachineKey).
|
||||
Str("node_key", machine.NodeKey).
|
||||
Str("user", machine.User.Name).
|
||||
Msg("Machine authorized again")
|
||||
|
||||
return &machine, nil
|
||||
}
|
||||
|
||||
hsdb.ipAllocationMutex.Lock()
|
||||
defer hsdb.ipAllocationMutex.Unlock()
|
||||
|
||||
ips, err := hsdb.getAvailableIPs()
|
||||
if err != nil {
|
||||
log.Error().
|
||||
Caller().
|
||||
Err(err).
|
||||
Str("machine", machine.Hostname).
|
||||
Msg("Could not find IP for the new machine")
|
||||
|
||||
return nil, err
|
||||
}
|
||||
|
||||
machine.IPAddresses = ips
|
||||
|
||||
if err := hsdb.db.Save(&machine).Error; err != nil {
|
||||
return nil, fmt.Errorf("failed register(save) machine in the database: %w", err)
|
||||
}
|
||||
|
||||
log.Trace().
|
||||
Caller().
|
||||
Str("machine", machine.Hostname).
|
||||
Str("ip", strings.Join(ips.ToStringSlice(), ",")).
|
||||
Msg("Machine registered with the database")
|
||||
|
||||
return &machine, nil
|
||||
}
|
||||
|
||||
// MachineSetNodeKey sets the node key of a machine and saves it to the database.
|
||||
func (hsdb *HSDatabase) MachineSetNodeKey(machine *types.Machine, nodeKey key.NodePublic) error {
|
||||
machine.NodeKey = util.NodePublicKeyStripPrefix(nodeKey)
|
||||
|
||||
if err := hsdb.db.Save(machine).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// MachineSetMachineKey sets the machine key of a machine and saves it to the database.
|
||||
func (hsdb *HSDatabase) MachineSetMachineKey(
|
||||
machine *types.Machine,
|
||||
nodeKey key.MachinePublic,
|
||||
) error {
|
||||
machine.MachineKey = util.MachinePublicKeyStripPrefix(nodeKey)
|
||||
|
||||
if err := hsdb.db.Save(machine).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// MachineSave saves a machine object to the database, prefer to use a specific save method rather
|
||||
// than this. It is intended to be used when we are changing or.
|
||||
func (hsdb *HSDatabase) MachineSave(machine *types.Machine) error {
|
||||
if err := hsdb.db.Save(machine).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetAdvertisedRoutes returns the routes that are be advertised by the given machine.
|
||||
func (hsdb *HSDatabase) GetAdvertisedRoutes(machine *types.Machine) ([]netip.Prefix, error) {
|
||||
routes := types.Routes{}
|
||||
|
||||
err := hsdb.db.
|
||||
Preload("Machine").
|
||||
Where("machine_id = ? AND advertised = ?", machine.ID, true).Find(&routes).Error
|
||||
if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
log.Error().
|
||||
Caller().
|
||||
Err(err).
|
||||
Str("machine", machine.Hostname).
|
||||
Msg("Could not get advertised routes for machine")
|
||||
|
||||
return nil, err
|
||||
}
|
||||
|
||||
prefixes := []netip.Prefix{}
|
||||
for _, route := range routes {
|
||||
prefixes = append(prefixes, netip.Prefix(route.Prefix))
|
||||
}
|
||||
|
||||
return prefixes, nil
|
||||
}
|
||||
|
||||
// GetEnabledRoutes returns the routes that are enabled for the machine.
|
||||
func (hsdb *HSDatabase) GetEnabledRoutes(machine *types.Machine) ([]netip.Prefix, error) {
|
||||
routes := types.Routes{}
|
||||
|
||||
err := hsdb.db.
|
||||
Preload("Machine").
|
||||
Where("machine_id = ? AND advertised = ? AND enabled = ?", machine.ID, true, true).
|
||||
Find(&routes).Error
|
||||
if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
log.Error().
|
||||
Caller().
|
||||
Err(err).
|
||||
Str("machine", machine.Hostname).
|
||||
Msg("Could not get enabled routes for machine")
|
||||
|
||||
return nil, err
|
||||
}
|
||||
|
||||
prefixes := []netip.Prefix{}
|
||||
for _, route := range routes {
|
||||
prefixes = append(prefixes, netip.Prefix(route.Prefix))
|
||||
}
|
||||
|
||||
return prefixes, nil
|
||||
}
|
||||
|
||||
func (hsdb *HSDatabase) IsRoutesEnabled(machine *types.Machine, routeStr string) bool {
|
||||
route, err := netip.ParsePrefix(routeStr)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
enabledRoutes, err := hsdb.GetEnabledRoutes(machine)
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("Could not get enabled routes")
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
for _, enabledRoute := range enabledRoutes {
|
||||
if route == enabledRoute {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// enableRoutes enables new routes based on a list of new routes.
|
||||
func (hsdb *HSDatabase) enableRoutes(machine *types.Machine, routeStrs ...string) error {
|
||||
newRoutes := make([]netip.Prefix, len(routeStrs))
|
||||
for index, routeStr := range routeStrs {
|
||||
route, err := netip.ParsePrefix(routeStr)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
newRoutes[index] = route
|
||||
}
|
||||
|
||||
advertisedRoutes, err := hsdb.GetAdvertisedRoutes(machine)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, newRoute := range newRoutes {
|
||||
if !util.StringOrPrefixListContains(advertisedRoutes, newRoute) {
|
||||
return fmt.Errorf(
|
||||
"route (%s) is not available on node %s: %w",
|
||||
machine.Hostname,
|
||||
newRoute, ErrMachineRouteIsNotAvailable,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
// Separate loop so we don't leave things in a half-updated state
|
||||
for _, prefix := range newRoutes {
|
||||
route := types.Route{}
|
||||
err := hsdb.db.Preload("Machine").
|
||||
Where("machine_id = ? AND prefix = ?", machine.ID, types.IPPrefix(prefix)).
|
||||
First(&route).Error
|
||||
if err == nil {
|
||||
route.Enabled = true
|
||||
|
||||
// Mark already as primary if there is only this node offering this subnet
|
||||
// (and is not an exit route)
|
||||
if !route.IsExitRoute() {
|
||||
route.IsPrimary = hsdb.isUniquePrefix(route)
|
||||
}
|
||||
|
||||
err = hsdb.db.Save(&route).Error
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to enable route: %w", err)
|
||||
}
|
||||
} else {
|
||||
return fmt.Errorf("failed to find route: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
hsdb.notifyStateChange()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (hsdb *HSDatabase) generateGivenName(suppliedName string, randomSuffix bool) (string, error) {
|
||||
normalizedHostname, err := util.NormalizeToFQDNRules(
|
||||
suppliedName,
|
||||
hsdb.stripEmailDomain,
|
||||
)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
if randomSuffix {
|
||||
// Trim if a hostname will be longer than 63 chars after adding the hash.
|
||||
trimmedHostnameLength := util.LabelHostnameLength - MachineGivenNameHashLength - MachineGivenNameTrimSize
|
||||
if len(normalizedHostname) > trimmedHostnameLength {
|
||||
normalizedHostname = normalizedHostname[:trimmedHostnameLength]
|
||||
}
|
||||
|
||||
suffix, err := util.GenerateRandomStringDNSSafe(MachineGivenNameHashLength)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
normalizedHostname += "-" + suffix
|
||||
}
|
||||
|
||||
return normalizedHostname, nil
|
||||
}
|
||||
|
||||
func (hsdb *HSDatabase) GenerateGivenName(machineKey string, suppliedName string) (string, error) {
|
||||
givenName, err := hsdb.generateGivenName(suppliedName, false)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
// Tailscale rules (may differ) https://tailscale.com/kb/1098/machine-names/
|
||||
machines, err := hsdb.ListMachinesByGivenName(givenName)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
for _, machine := range machines {
|
||||
if machine.MachineKey != machineKey && machine.GivenName == givenName {
|
||||
postfixedName, err := hsdb.generateGivenName(suppliedName, true)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
givenName = postfixedName
|
||||
}
|
||||
}
|
||||
|
||||
return givenName, nil
|
||||
}
|
||||
|
||||
func (hsdb *HSDatabase) ExpireEphemeralMachines(inactivityThreshhold time.Duration) {
|
||||
users, err := hsdb.ListUsers()
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("Error listing users")
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
for _, user := range users {
|
||||
machines, err := hsdb.ListMachinesByUser(user.Name)
|
||||
if err != nil {
|
||||
log.Error().
|
||||
Err(err).
|
||||
Str("user", user.Name).
|
||||
Msg("Error listing machines in user")
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
expiredFound := false
|
||||
for idx, machine := range machines {
|
||||
if machine.IsEphemeral() && machine.LastSeen != nil &&
|
||||
time.Now().
|
||||
After(machine.LastSeen.Add(inactivityThreshhold)) {
|
||||
expiredFound = true
|
||||
log.Info().
|
||||
Str("machine", machine.Hostname).
|
||||
Msg("Ephemeral client removed from database")
|
||||
|
||||
err = hsdb.HardDeleteMachine(&machines[idx])
|
||||
if err != nil {
|
||||
log.Error().
|
||||
Err(err).
|
||||
Str("machine", machine.Hostname).
|
||||
Msg("🤮 Cannot delete ephemeral machine from the database")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if expiredFound {
|
||||
hsdb.notifyStateChange()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (hsdb *HSDatabase) ExpireExpiredMachines(lastChange time.Time) {
|
||||
users, err := hsdb.ListUsers()
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("Error listing users")
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
for _, user := range users {
|
||||
machines, err := hsdb.ListMachinesByUser(user.Name)
|
||||
if err != nil {
|
||||
log.Error().
|
||||
Err(err).
|
||||
Str("user", user.Name).
|
||||
Msg("Error listing machines in user")
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
expiredFound := false
|
||||
for index, machine := range machines {
|
||||
if machine.IsExpired() &&
|
||||
machine.Expiry.After(lastChange) {
|
||||
expiredFound = true
|
||||
|
||||
err := hsdb.ExpireMachine(&machines[index])
|
||||
if err != nil {
|
||||
log.Error().
|
||||
Err(err).
|
||||
Str("machine", machine.Hostname).
|
||||
Str("name", machine.GivenName).
|
||||
Msg("🤮 Cannot expire machine")
|
||||
} else {
|
||||
log.Info().
|
||||
Str("machine", machine.Hostname).
|
||||
Str("name", machine.GivenName).
|
||||
Msg("Machine successfully expired")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if expiredFound {
|
||||
hsdb.notifyStateChange()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (hsdb *HSDatabase) TailNodes(
|
||||
machines types.Machines,
|
||||
pol *policy.ACLPolicy,
|
||||
dnsConfig *tailcfg.DNSConfig,
|
||||
) ([]*tailcfg.Node, error) {
|
||||
nodes := make([]*tailcfg.Node, len(machines))
|
||||
|
||||
for index, machine := range machines {
|
||||
node, err := hsdb.TailNode(machine, pol, dnsConfig)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
nodes[index] = node
|
||||
}
|
||||
|
||||
return nodes, nil
|
||||
}
|
||||
|
||||
// TailNode converts a Machine into a Tailscale Node. includeRoutes is false for shared nodes
|
||||
// as per the expected behaviour in the official SaaS.
|
||||
func (hsdb *HSDatabase) TailNode(
|
||||
machine types.Machine,
|
||||
pol *policy.ACLPolicy,
|
||||
dnsConfig *tailcfg.DNSConfig,
|
||||
) (*tailcfg.Node, error) {
|
||||
var nodeKey key.NodePublic
|
||||
err := nodeKey.UnmarshalText([]byte(util.NodePublicKeyEnsurePrefix(machine.NodeKey)))
|
||||
if err != nil {
|
||||
log.Trace().
|
||||
Caller().
|
||||
Str("node_key", machine.NodeKey).
|
||||
Msgf("Failed to parse node public key from hex")
|
||||
|
||||
return nil, fmt.Errorf("failed to parse node public key: %w", err)
|
||||
}
|
||||
|
||||
var machineKey key.MachinePublic
|
||||
// MachineKey is only used in the legacy protocol
|
||||
if machine.MachineKey != "" {
|
||||
err = machineKey.UnmarshalText(
|
||||
[]byte(util.MachinePublicKeyEnsurePrefix(machine.MachineKey)),
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to parse machine public key: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
var discoKey key.DiscoPublic
|
||||
if machine.DiscoKey != "" {
|
||||
err := discoKey.UnmarshalText(
|
||||
[]byte(util.DiscoPublicKeyEnsurePrefix(machine.DiscoKey)),
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to parse disco public key: %w", err)
|
||||
}
|
||||
} else {
|
||||
discoKey = key.DiscoPublic{}
|
||||
}
|
||||
|
||||
addrs := []netip.Prefix{}
|
||||
for _, machineAddress := range machine.IPAddresses {
|
||||
ip := netip.PrefixFrom(machineAddress, machineAddress.BitLen())
|
||||
addrs = append(addrs, ip)
|
||||
}
|
||||
|
||||
allowedIPs := append(
|
||||
[]netip.Prefix{},
|
||||
addrs...) // we append the node own IP, as it is required by the clients
|
||||
|
||||
primaryRoutes, err := hsdb.GetMachinePrimaryRoutes(&machine)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
primaryPrefixes := primaryRoutes.Prefixes()
|
||||
|
||||
machineRoutes, err := hsdb.GetMachineRoutes(&machine)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
for _, route := range machineRoutes {
|
||||
if route.Enabled && (route.IsPrimary || route.IsExitRoute()) {
|
||||
allowedIPs = append(allowedIPs, netip.Prefix(route.Prefix))
|
||||
}
|
||||
}
|
||||
|
||||
var derp string
|
||||
if machine.HostInfo.NetInfo != nil {
|
||||
derp = fmt.Sprintf("127.3.3.40:%d", machine.HostInfo.NetInfo.PreferredDERP)
|
||||
} else {
|
||||
derp = "127.3.3.40:0" // Zero means disconnected or unknown.
|
||||
}
|
||||
|
||||
var keyExpiry time.Time
|
||||
if machine.Expiry != nil {
|
||||
keyExpiry = *machine.Expiry
|
||||
} else {
|
||||
keyExpiry = time.Time{}
|
||||
}
|
||||
|
||||
var hostname string
|
||||
if dnsConfig != nil && dnsConfig.Proxied { // MagicDNS
|
||||
hostname = fmt.Sprintf(
|
||||
"%s.%s.%s",
|
||||
machine.GivenName,
|
||||
machine.User.Name,
|
||||
hsdb.baseDomain,
|
||||
)
|
||||
if len(hostname) > MaxHostnameLength {
|
||||
return nil, fmt.Errorf(
|
||||
"hostname %q is too long it cannot except 255 ASCII chars: %w",
|
||||
hostname,
|
||||
ErrHostnameTooLong,
|
||||
)
|
||||
}
|
||||
} else {
|
||||
hostname = machine.GivenName
|
||||
}
|
||||
|
||||
hostInfo := machine.GetHostInfo()
|
||||
|
||||
online := machine.IsOnline()
|
||||
|
||||
tags, _ := pol.GetTagsOfMachine(machine, hsdb.stripEmailDomain)
|
||||
tags = lo.Uniq(append(tags, machine.ForcedTags...))
|
||||
|
||||
node := tailcfg.Node{
|
||||
ID: tailcfg.NodeID(machine.ID), // this is the actual ID
|
||||
StableID: tailcfg.StableNodeID(
|
||||
strconv.FormatUint(machine.ID, util.Base10),
|
||||
), // in headscale, unlike tailcontrol server, IDs are permanent
|
||||
Name: hostname,
|
||||
|
||||
User: tailcfg.UserID(machine.UserID),
|
||||
|
||||
Key: nodeKey,
|
||||
KeyExpiry: keyExpiry,
|
||||
|
||||
Machine: machineKey,
|
||||
DiscoKey: discoKey,
|
||||
Addresses: addrs,
|
||||
AllowedIPs: allowedIPs,
|
||||
Endpoints: machine.Endpoints,
|
||||
DERP: derp,
|
||||
Hostinfo: hostInfo.View(),
|
||||
Created: machine.CreatedAt,
|
||||
|
||||
Tags: tags,
|
||||
|
||||
PrimaryRoutes: primaryPrefixes,
|
||||
|
||||
LastSeen: machine.LastSeen,
|
||||
Online: &online,
|
||||
KeepAlive: true,
|
||||
MachineAuthorized: !machine.IsExpired(),
|
||||
|
||||
Capabilities: []string{
|
||||
tailcfg.CapabilityFileSharing,
|
||||
tailcfg.CapabilityAdmin,
|
||||
tailcfg.CapabilitySSH,
|
||||
},
|
||||
}
|
||||
|
||||
return &node, nil
|
||||
}
|
797
hscontrol/db/machine_test.go
Normal file
797
hscontrol/db/machine_test.go
Normal file
@@ -0,0 +1,797 @@
|
||||
package db
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/juanfont/headscale/hscontrol/policy"
|
||||
"github.com/juanfont/headscale/hscontrol/types"
|
||||
"github.com/juanfont/headscale/hscontrol/util"
|
||||
"gopkg.in/check.v1"
|
||||
"tailscale.com/tailcfg"
|
||||
"tailscale.com/types/key"
|
||||
)
|
||||
|
||||
func (s *Suite) TestGetMachine(c *check.C) {
|
||||
user, err := db.CreateUser("test")
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
_, err = db.GetMachine("test", "testmachine")
|
||||
c.Assert(err, check.NotNil)
|
||||
|
||||
machine := &types.Machine{
|
||||
ID: 0,
|
||||
MachineKey: "foo",
|
||||
NodeKey: "bar",
|
||||
DiscoKey: "faa",
|
||||
Hostname: "testmachine",
|
||||
UserID: user.ID,
|
||||
RegisterMethod: util.RegisterMethodAuthKey,
|
||||
AuthKeyID: uint(pak.ID),
|
||||
}
|
||||
db.db.Save(machine)
|
||||
|
||||
_, err = db.GetMachine("test", "testmachine")
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
c.Assert(channelUpdates, check.Equals, int32(0))
|
||||
}
|
||||
|
||||
func (s *Suite) TestGetMachineByID(c *check.C) {
|
||||
user, err := db.CreateUser("test")
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
_, err = db.GetMachineByID(0)
|
||||
c.Assert(err, check.NotNil)
|
||||
|
||||
machine := types.Machine{
|
||||
ID: 0,
|
||||
MachineKey: "foo",
|
||||
NodeKey: "bar",
|
||||
DiscoKey: "faa",
|
||||
Hostname: "testmachine",
|
||||
UserID: user.ID,
|
||||
RegisterMethod: util.RegisterMethodAuthKey,
|
||||
AuthKeyID: uint(pak.ID),
|
||||
}
|
||||
db.db.Save(&machine)
|
||||
|
||||
_, err = db.GetMachineByID(0)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
c.Assert(channelUpdates, check.Equals, int32(0))
|
||||
}
|
||||
|
||||
func (s *Suite) TestGetMachineByNodeKey(c *check.C) {
|
||||
user, err := db.CreateUser("test")
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
_, err = db.GetMachineByID(0)
|
||||
c.Assert(err, check.NotNil)
|
||||
|
||||
nodeKey := key.NewNode()
|
||||
machineKey := key.NewMachine()
|
||||
|
||||
machine := types.Machine{
|
||||
ID: 0,
|
||||
MachineKey: util.MachinePublicKeyStripPrefix(machineKey.Public()),
|
||||
NodeKey: util.NodePublicKeyStripPrefix(nodeKey.Public()),
|
||||
DiscoKey: "faa",
|
||||
Hostname: "testmachine",
|
||||
UserID: user.ID,
|
||||
RegisterMethod: util.RegisterMethodAuthKey,
|
||||
AuthKeyID: uint(pak.ID),
|
||||
}
|
||||
db.db.Save(&machine)
|
||||
|
||||
_, err = db.GetMachineByNodeKey(nodeKey.Public())
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
c.Assert(channelUpdates, check.Equals, int32(0))
|
||||
}
|
||||
|
||||
func (s *Suite) TestGetMachineByAnyNodeKey(c *check.C) {
|
||||
user, err := db.CreateUser("test")
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
_, err = db.GetMachineByID(0)
|
||||
c.Assert(err, check.NotNil)
|
||||
|
||||
nodeKey := key.NewNode()
|
||||
oldNodeKey := key.NewNode()
|
||||
|
||||
machineKey := key.NewMachine()
|
||||
|
||||
machine := types.Machine{
|
||||
ID: 0,
|
||||
MachineKey: util.MachinePublicKeyStripPrefix(machineKey.Public()),
|
||||
NodeKey: util.NodePublicKeyStripPrefix(nodeKey.Public()),
|
||||
DiscoKey: "faa",
|
||||
Hostname: "testmachine",
|
||||
UserID: user.ID,
|
||||
RegisterMethod: util.RegisterMethodAuthKey,
|
||||
AuthKeyID: uint(pak.ID),
|
||||
}
|
||||
db.db.Save(&machine)
|
||||
|
||||
_, err = db.GetMachineByAnyKey(machineKey.Public(), nodeKey.Public(), oldNodeKey.Public())
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
c.Assert(channelUpdates, check.Equals, int32(0))
|
||||
}
|
||||
|
||||
func (s *Suite) TestDeleteMachine(c *check.C) {
|
||||
user, err := db.CreateUser("test")
|
||||
c.Assert(err, check.IsNil)
|
||||
machine := types.Machine{
|
||||
ID: 0,
|
||||
MachineKey: "foo",
|
||||
NodeKey: "bar",
|
||||
DiscoKey: "faa",
|
||||
Hostname: "testmachine",
|
||||
UserID: user.ID,
|
||||
RegisterMethod: util.RegisterMethodAuthKey,
|
||||
AuthKeyID: uint(1),
|
||||
}
|
||||
db.db.Save(&machine)
|
||||
|
||||
err = db.DeleteMachine(&machine)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
_, err = db.GetMachine(user.Name, "testmachine")
|
||||
c.Assert(err, check.NotNil)
|
||||
|
||||
c.Assert(channelUpdates, check.Equals, int32(0))
|
||||
}
|
||||
|
||||
func (s *Suite) TestHardDeleteMachine(c *check.C) {
|
||||
user, err := db.CreateUser("test")
|
||||
c.Assert(err, check.IsNil)
|
||||
machine := types.Machine{
|
||||
ID: 0,
|
||||
MachineKey: "foo",
|
||||
NodeKey: "bar",
|
||||
DiscoKey: "faa",
|
||||
Hostname: "testmachine3",
|
||||
UserID: user.ID,
|
||||
RegisterMethod: util.RegisterMethodAuthKey,
|
||||
AuthKeyID: uint(1),
|
||||
}
|
||||
db.db.Save(&machine)
|
||||
|
||||
err = db.HardDeleteMachine(&machine)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
_, err = db.GetMachine(user.Name, "testmachine3")
|
||||
c.Assert(err, check.NotNil)
|
||||
|
||||
c.Assert(channelUpdates, check.Equals, int32(0))
|
||||
}
|
||||
|
||||
func (s *Suite) TestListPeers(c *check.C) {
|
||||
user, err := db.CreateUser("test")
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
_, err = db.GetMachineByID(0)
|
||||
c.Assert(err, check.NotNil)
|
||||
|
||||
for index := 0; index <= 10; index++ {
|
||||
machine := types.Machine{
|
||||
ID: uint64(index),
|
||||
MachineKey: "foo" + strconv.Itoa(index),
|
||||
NodeKey: "bar" + strconv.Itoa(index),
|
||||
DiscoKey: "faa" + strconv.Itoa(index),
|
||||
Hostname: "testmachine" + strconv.Itoa(index),
|
||||
UserID: user.ID,
|
||||
RegisterMethod: util.RegisterMethodAuthKey,
|
||||
AuthKeyID: uint(pak.ID),
|
||||
}
|
||||
db.db.Save(&machine)
|
||||
}
|
||||
|
||||
machine0ByID, err := db.GetMachineByID(0)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
peersOfMachine0, err := db.ListPeers(machine0ByID)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
c.Assert(len(peersOfMachine0), check.Equals, 9)
|
||||
c.Assert(peersOfMachine0[0].Hostname, check.Equals, "testmachine2")
|
||||
c.Assert(peersOfMachine0[5].Hostname, check.Equals, "testmachine7")
|
||||
c.Assert(peersOfMachine0[8].Hostname, check.Equals, "testmachine10")
|
||||
|
||||
c.Assert(channelUpdates, check.Equals, int32(0))
|
||||
}
|
||||
|
||||
func (s *Suite) TestGetACLFilteredPeers(c *check.C) {
|
||||
type base struct {
|
||||
user *types.User
|
||||
key *types.PreAuthKey
|
||||
}
|
||||
|
||||
stor := make([]base, 0)
|
||||
|
||||
for _, name := range []string{"test", "admin"} {
|
||||
user, err := db.CreateUser(name)
|
||||
c.Assert(err, check.IsNil)
|
||||
pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil)
|
||||
c.Assert(err, check.IsNil)
|
||||
stor = append(stor, base{user, pak})
|
||||
}
|
||||
|
||||
_, err := db.GetMachineByID(0)
|
||||
c.Assert(err, check.NotNil)
|
||||
|
||||
for index := 0; index <= 10; index++ {
|
||||
machine := types.Machine{
|
||||
ID: uint64(index),
|
||||
MachineKey: "foo" + strconv.Itoa(index),
|
||||
NodeKey: "bar" + strconv.Itoa(index),
|
||||
DiscoKey: "faa" + strconv.Itoa(index),
|
||||
IPAddresses: types.MachineAddresses{
|
||||
netip.MustParseAddr(fmt.Sprintf("100.64.0.%v", strconv.Itoa(index+1))),
|
||||
},
|
||||
Hostname: "testmachine" + strconv.Itoa(index),
|
||||
UserID: stor[index%2].user.ID,
|
||||
RegisterMethod: util.RegisterMethodAuthKey,
|
||||
AuthKeyID: uint(stor[index%2].key.ID),
|
||||
}
|
||||
db.db.Save(&machine)
|
||||
}
|
||||
|
||||
aclPolicy := &policy.ACLPolicy{
|
||||
Groups: map[string][]string{
|
||||
"group:test": {"admin"},
|
||||
},
|
||||
Hosts: map[string]netip.Prefix{},
|
||||
TagOwners: map[string][]string{},
|
||||
ACLs: []policy.ACL{
|
||||
{
|
||||
Action: "accept",
|
||||
Sources: []string{"admin"},
|
||||
Destinations: []string{"*:*"},
|
||||
},
|
||||
{
|
||||
Action: "accept",
|
||||
Sources: []string{"test"},
|
||||
Destinations: []string{"test:*"},
|
||||
},
|
||||
},
|
||||
Tests: []policy.ACLTest{},
|
||||
}
|
||||
|
||||
adminMachine, err := db.GetMachineByID(1)
|
||||
c.Logf("Machine(%v), user: %v", adminMachine.Hostname, adminMachine.User)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
testMachine, err := db.GetMachineByID(2)
|
||||
c.Logf("Machine(%v), user: %v", testMachine.Hostname, testMachine.User)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
machines, err := db.ListMachines()
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
aclRules, _, err := policy.GenerateFilterRules(aclPolicy, machines, false)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
peersOfTestMachine := db.filterMachinesByACL(aclRules, testMachine, machines)
|
||||
peersOfAdminMachine := db.filterMachinesByACL(aclRules, adminMachine, machines)
|
||||
|
||||
c.Log(peersOfTestMachine)
|
||||
c.Assert(len(peersOfTestMachine), check.Equals, 9)
|
||||
c.Assert(peersOfTestMachine[0].Hostname, check.Equals, "testmachine1")
|
||||
c.Assert(peersOfTestMachine[1].Hostname, check.Equals, "testmachine3")
|
||||
c.Assert(peersOfTestMachine[3].Hostname, check.Equals, "testmachine5")
|
||||
|
||||
c.Log(peersOfAdminMachine)
|
||||
c.Assert(len(peersOfAdminMachine), check.Equals, 9)
|
||||
c.Assert(peersOfAdminMachine[0].Hostname, check.Equals, "testmachine2")
|
||||
c.Assert(peersOfAdminMachine[2].Hostname, check.Equals, "testmachine4")
|
||||
c.Assert(peersOfAdminMachine[5].Hostname, check.Equals, "testmachine7")
|
||||
|
||||
c.Assert(channelUpdates, check.Equals, int32(0))
|
||||
}
|
||||
|
||||
func (s *Suite) TestExpireMachine(c *check.C) {
|
||||
user, err := db.CreateUser("test")
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
_, err = db.GetMachine("test", "testmachine")
|
||||
c.Assert(err, check.NotNil)
|
||||
|
||||
machine := &types.Machine{
|
||||
ID: 0,
|
||||
MachineKey: "foo",
|
||||
NodeKey: "bar",
|
||||
DiscoKey: "faa",
|
||||
Hostname: "testmachine",
|
||||
UserID: user.ID,
|
||||
RegisterMethod: util.RegisterMethodAuthKey,
|
||||
AuthKeyID: uint(pak.ID),
|
||||
Expiry: &time.Time{},
|
||||
}
|
||||
db.db.Save(machine)
|
||||
|
||||
machineFromDB, err := db.GetMachine("test", "testmachine")
|
||||
c.Assert(err, check.IsNil)
|
||||
c.Assert(machineFromDB, check.NotNil)
|
||||
|
||||
c.Assert(machineFromDB.IsExpired(), check.Equals, false)
|
||||
|
||||
err = db.ExpireMachine(machineFromDB)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
c.Assert(machineFromDB.IsExpired(), check.Equals, true)
|
||||
|
||||
c.Assert(channelUpdates, check.Equals, int32(1))
|
||||
}
|
||||
|
||||
func (s *Suite) TestSerdeAddressStrignSlice(c *check.C) {
|
||||
input := types.MachineAddresses([]netip.Addr{
|
||||
netip.MustParseAddr("192.0.2.1"),
|
||||
netip.MustParseAddr("2001:db8::1"),
|
||||
})
|
||||
serialized, err := input.Value()
|
||||
c.Assert(err, check.IsNil)
|
||||
if serial, ok := serialized.(string); ok {
|
||||
c.Assert(serial, check.Equals, "192.0.2.1,2001:db8::1")
|
||||
}
|
||||
|
||||
var deserialized types.MachineAddresses
|
||||
err = deserialized.Scan(serialized)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
c.Assert(len(deserialized), check.Equals, len(input))
|
||||
for i := range deserialized {
|
||||
c.Assert(deserialized[i], check.Equals, input[i])
|
||||
}
|
||||
|
||||
c.Assert(channelUpdates, check.Equals, int32(0))
|
||||
}
|
||||
|
||||
func (s *Suite) TestGenerateGivenName(c *check.C) {
|
||||
user1, err := db.CreateUser("user-1")
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
pak, err := db.CreatePreAuthKey(user1.Name, false, false, nil, nil)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
_, err = db.GetMachine("user-1", "testmachine")
|
||||
c.Assert(err, check.NotNil)
|
||||
|
||||
machine := &types.Machine{
|
||||
ID: 0,
|
||||
MachineKey: "machine-key-1",
|
||||
NodeKey: "node-key-1",
|
||||
DiscoKey: "disco-key-1",
|
||||
Hostname: "hostname-1",
|
||||
GivenName: "hostname-1",
|
||||
UserID: user1.ID,
|
||||
RegisterMethod: util.RegisterMethodAuthKey,
|
||||
AuthKeyID: uint(pak.ID),
|
||||
}
|
||||
db.db.Save(machine)
|
||||
|
||||
givenName, err := db.GenerateGivenName("machine-key-2", "hostname-2")
|
||||
comment := check.Commentf("Same user, unique machines, unique hostnames, no conflict")
|
||||
c.Assert(err, check.IsNil, comment)
|
||||
c.Assert(givenName, check.Equals, "hostname-2", comment)
|
||||
|
||||
givenName, err = db.GenerateGivenName("machine-key-1", "hostname-1")
|
||||
comment = check.Commentf("Same user, same machine, same hostname, no conflict")
|
||||
c.Assert(err, check.IsNil, comment)
|
||||
c.Assert(givenName, check.Equals, "hostname-1", comment)
|
||||
|
||||
givenName, err = db.GenerateGivenName("machine-key-2", "hostname-1")
|
||||
comment = check.Commentf("Same user, unique machines, same hostname, conflict")
|
||||
c.Assert(err, check.IsNil, comment)
|
||||
c.Assert(givenName, check.Matches, fmt.Sprintf("^hostname-1-[a-z0-9]{%d}$", MachineGivenNameHashLength), comment)
|
||||
|
||||
givenName, err = db.GenerateGivenName("machine-key-2", "hostname-1")
|
||||
comment = check.Commentf("Unique users, unique machines, same hostname, conflict")
|
||||
c.Assert(err, check.IsNil, comment)
|
||||
c.Assert(givenName, check.Matches, fmt.Sprintf("^hostname-1-[a-z0-9]{%d}$", MachineGivenNameHashLength), comment)
|
||||
|
||||
c.Assert(channelUpdates, check.Equals, int32(0))
|
||||
}
|
||||
|
||||
func (s *Suite) TestSetTags(c *check.C) {
|
||||
user, err := db.CreateUser("test")
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
_, err = db.GetMachine("test", "testmachine")
|
||||
c.Assert(err, check.NotNil)
|
||||
|
||||
machine := &types.Machine{
|
||||
ID: 0,
|
||||
MachineKey: "foo",
|
||||
NodeKey: "bar",
|
||||
DiscoKey: "faa",
|
||||
Hostname: "testmachine",
|
||||
UserID: user.ID,
|
||||
RegisterMethod: util.RegisterMethodAuthKey,
|
||||
AuthKeyID: uint(pak.ID),
|
||||
}
|
||||
db.db.Save(machine)
|
||||
|
||||
// assign simple tags
|
||||
sTags := []string{"tag:test", "tag:foo"}
|
||||
err = db.SetTags(machine, sTags)
|
||||
c.Assert(err, check.IsNil)
|
||||
machine, err = db.GetMachine("test", "testmachine")
|
||||
c.Assert(err, check.IsNil)
|
||||
c.Assert(machine.ForcedTags, check.DeepEquals, types.StringList(sTags))
|
||||
|
||||
// assign duplicat tags, expect no errors but no doubles in DB
|
||||
eTags := []string{"tag:bar", "tag:test", "tag:unknown", "tag:test"}
|
||||
err = db.SetTags(machine, eTags)
|
||||
c.Assert(err, check.IsNil)
|
||||
machine, err = db.GetMachine("test", "testmachine")
|
||||
c.Assert(err, check.IsNil)
|
||||
c.Assert(
|
||||
machine.ForcedTags,
|
||||
check.DeepEquals,
|
||||
types.StringList([]string{"tag:bar", "tag:test", "tag:unknown"}),
|
||||
)
|
||||
|
||||
c.Assert(channelUpdates, check.Equals, int32(4))
|
||||
}
|
||||
|
||||
func TestHeadscale_generateGivenName(t *testing.T) {
|
||||
type args struct {
|
||||
suppliedName string
|
||||
randomSuffix bool
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
db *HSDatabase
|
||||
args args
|
||||
want *regexp.Regexp
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "simple machine name generation",
|
||||
db: &HSDatabase{
|
||||
stripEmailDomain: true,
|
||||
},
|
||||
args: args{
|
||||
suppliedName: "testmachine",
|
||||
randomSuffix: false,
|
||||
},
|
||||
want: regexp.MustCompile("^testmachine$"),
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "machine name with 53 chars",
|
||||
db: &HSDatabase{
|
||||
stripEmailDomain: true,
|
||||
},
|
||||
args: args{
|
||||
suppliedName: "testmaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaachine",
|
||||
randomSuffix: false,
|
||||
},
|
||||
want: regexp.MustCompile("^testmaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaachine$"),
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "machine name with 63 chars",
|
||||
db: &HSDatabase{
|
||||
stripEmailDomain: true,
|
||||
},
|
||||
args: args{
|
||||
suppliedName: "machineeee12345678901234567890123456789012345678901234567890123",
|
||||
randomSuffix: false,
|
||||
},
|
||||
want: regexp.MustCompile("^machineeee12345678901234567890123456789012345678901234567890123$"),
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "machine name with 64 chars",
|
||||
db: &HSDatabase{
|
||||
stripEmailDomain: true,
|
||||
},
|
||||
args: args{
|
||||
suppliedName: "machineeee123456789012345678901234567890123456789012345678901234",
|
||||
randomSuffix: false,
|
||||
},
|
||||
want: nil,
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "machine name with 73 chars",
|
||||
db: &HSDatabase{
|
||||
stripEmailDomain: true,
|
||||
},
|
||||
args: args{
|
||||
suppliedName: "machineeee123456789012345678901234567890123456789012345678901234567890123",
|
||||
randomSuffix: false,
|
||||
},
|
||||
want: nil,
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "machine name with random suffix",
|
||||
db: &HSDatabase{
|
||||
stripEmailDomain: true,
|
||||
},
|
||||
args: args{
|
||||
suppliedName: "test",
|
||||
randomSuffix: true,
|
||||
},
|
||||
want: regexp.MustCompile(fmt.Sprintf("^test-[a-z0-9]{%d}$", MachineGivenNameHashLength)),
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "machine name with 63 chars with random suffix",
|
||||
db: &HSDatabase{
|
||||
stripEmailDomain: true,
|
||||
},
|
||||
args: args{
|
||||
suppliedName: "machineeee12345678901234567890123456789012345678901234567890123",
|
||||
randomSuffix: true,
|
||||
},
|
||||
want: regexp.MustCompile(fmt.Sprintf("^machineeee1234567890123456789012345678901234567890123-[a-z0-9]{%d}$", MachineGivenNameHashLength)),
|
||||
wantErr: false,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, err := tt.db.generateGivenName(tt.args.suppliedName, tt.args.randomSuffix)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf(
|
||||
"Headscale.GenerateGivenName() error = %v, wantErr %v",
|
||||
err,
|
||||
tt.wantErr,
|
||||
)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
if tt.want != nil && !tt.want.MatchString(got) {
|
||||
t.Errorf(
|
||||
"Headscale.GenerateGivenName() = %v, does not match %v",
|
||||
tt.want,
|
||||
got,
|
||||
)
|
||||
}
|
||||
|
||||
if len(got) > util.LabelHostnameLength {
|
||||
t.Errorf(
|
||||
"Headscale.GenerateGivenName() = %v is larger than allowed DNS segment %d",
|
||||
got,
|
||||
util.LabelHostnameLength,
|
||||
)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Suite) TestAutoApproveRoutes(c *check.C) {
|
||||
acl := []byte(`
|
||||
{
|
||||
"tagOwners": {
|
||||
"tag:exit": ["test"],
|
||||
},
|
||||
|
||||
"groups": {
|
||||
"group:test": ["test"]
|
||||
},
|
||||
|
||||
"acls": [
|
||||
{"action": "accept", "users": ["*"], "ports": ["*:*"]},
|
||||
],
|
||||
|
||||
"autoApprovers": {
|
||||
"exitNode": ["tag:exit"],
|
||||
"routes": {
|
||||
"10.10.0.0/16": ["group:test"],
|
||||
"10.11.0.0/16": ["test"],
|
||||
}
|
||||
}
|
||||
}
|
||||
`)
|
||||
|
||||
pol, err := policy.LoadACLPolicyFromBytes(acl, "hujson")
|
||||
c.Assert(err, check.IsNil)
|
||||
c.Assert(pol, check.NotNil)
|
||||
|
||||
user, err := db.CreateUser("test")
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
nodeKey := key.NewNode()
|
||||
|
||||
defaultRouteV4 := netip.MustParsePrefix("0.0.0.0/0")
|
||||
defaultRouteV6 := netip.MustParsePrefix("::/0")
|
||||
route1 := netip.MustParsePrefix("10.10.0.0/16")
|
||||
// Check if a subprefix of an autoapproved route is approved
|
||||
route2 := netip.MustParsePrefix("10.11.0.0/24")
|
||||
|
||||
machine := types.Machine{
|
||||
ID: 0,
|
||||
MachineKey: "foo",
|
||||
NodeKey: util.NodePublicKeyStripPrefix(nodeKey.Public()),
|
||||
DiscoKey: "faa",
|
||||
Hostname: "test",
|
||||
UserID: user.ID,
|
||||
RegisterMethod: util.RegisterMethodAuthKey,
|
||||
AuthKeyID: uint(pak.ID),
|
||||
HostInfo: types.HostInfo{
|
||||
RequestTags: []string{"tag:exit"},
|
||||
RoutableIPs: []netip.Prefix{defaultRouteV4, defaultRouteV6, route1, route2},
|
||||
},
|
||||
IPAddresses: []netip.Addr{netip.MustParseAddr("100.64.0.1")},
|
||||
}
|
||||
|
||||
db.db.Save(&machine)
|
||||
|
||||
err = db.ProcessMachineRoutes(&machine)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
machine0ByID, err := db.GetMachineByID(0)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
err = db.EnableAutoApprovedRoutes(pol, machine0ByID)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
enabledRoutes, err := db.GetEnabledRoutes(machine0ByID)
|
||||
c.Assert(err, check.IsNil)
|
||||
c.Assert(enabledRoutes, check.HasLen, 4)
|
||||
|
||||
c.Assert(channelUpdates, check.Equals, int32(4))
|
||||
}
|
||||
|
||||
func TestMachine_canAccess(t *testing.T) {
|
||||
type args struct {
|
||||
filter []tailcfg.FilterRule
|
||||
machine2 *types.Machine
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
machine types.Machine
|
||||
args args
|
||||
want bool
|
||||
}{
|
||||
{
|
||||
name: "no-rules",
|
||||
machine: types.Machine{
|
||||
IPAddresses: types.MachineAddresses{
|
||||
netip.MustParseAddr("10.0.0.1"),
|
||||
},
|
||||
},
|
||||
args: args{
|
||||
filter: []tailcfg.FilterRule{},
|
||||
machine2: &types.Machine{
|
||||
IPAddresses: types.MachineAddresses{
|
||||
netip.MustParseAddr("10.0.0.2"),
|
||||
},
|
||||
},
|
||||
},
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "wildcard",
|
||||
machine: types.Machine{
|
||||
IPAddresses: types.MachineAddresses{
|
||||
netip.MustParseAddr("10.0.0.1"),
|
||||
},
|
||||
},
|
||||
args: args{
|
||||
filter: []tailcfg.FilterRule{
|
||||
{
|
||||
SrcIPs: []string{"*"},
|
||||
DstPorts: []tailcfg.NetPortRange{
|
||||
{
|
||||
IP: "*",
|
||||
Ports: tailcfg.PortRange{
|
||||
First: 0,
|
||||
Last: 65535,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
machine2: &types.Machine{
|
||||
IPAddresses: types.MachineAddresses{
|
||||
netip.MustParseAddr("10.0.0.2"),
|
||||
},
|
||||
},
|
||||
},
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "explicit-m1-to-m2",
|
||||
machine: types.Machine{
|
||||
IPAddresses: types.MachineAddresses{
|
||||
netip.MustParseAddr("10.0.0.1"),
|
||||
},
|
||||
},
|
||||
args: args{
|
||||
filter: []tailcfg.FilterRule{
|
||||
{
|
||||
SrcIPs: []string{"10.0.0.1"},
|
||||
DstPorts: []tailcfg.NetPortRange{
|
||||
{
|
||||
IP: "10.0.0.2",
|
||||
Ports: tailcfg.PortRange{
|
||||
First: 0,
|
||||
Last: 65535,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
machine2: &types.Machine{
|
||||
IPAddresses: types.MachineAddresses{
|
||||
netip.MustParseAddr("10.0.0.2"),
|
||||
},
|
||||
},
|
||||
},
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "explicit-m2-to-m1",
|
||||
machine: types.Machine{
|
||||
IPAddresses: types.MachineAddresses{
|
||||
netip.MustParseAddr("10.0.0.1"),
|
||||
},
|
||||
},
|
||||
args: args{
|
||||
filter: []tailcfg.FilterRule{
|
||||
{
|
||||
SrcIPs: []string{"10.0.0.2"},
|
||||
DstPorts: []tailcfg.NetPortRange{
|
||||
{
|
||||
IP: "10.0.0.1",
|
||||
Ports: tailcfg.PortRange{
|
||||
First: 0,
|
||||
Last: 65535,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
machine2: &types.Machine{
|
||||
IPAddresses: types.MachineAddresses{
|
||||
netip.MustParseAddr("10.0.0.2"),
|
||||
},
|
||||
},
|
||||
},
|
||||
want: false,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if got := tt.machine.CanAccess(tt.args.filter, tt.args.machine2); got != tt.want {
|
||||
t.Errorf("Machine.CanAccess() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
196
hscontrol/db/preauth_keys.go
Normal file
196
hscontrol/db/preauth_keys.go
Normal file
@@ -0,0 +1,196 @@
|
||||
package db
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"encoding/hex"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/juanfont/headscale/hscontrol/types"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrPreAuthKeyNotFound = errors.New("AuthKey not found")
|
||||
ErrPreAuthKeyExpired = errors.New("AuthKey expired")
|
||||
ErrSingleUseAuthKeyHasBeenUsed = errors.New("AuthKey has already been used")
|
||||
ErrUserMismatch = errors.New("user mismatch")
|
||||
ErrPreAuthKeyACLTagInvalid = errors.New("AuthKey tag is invalid")
|
||||
)
|
||||
|
||||
// CreatePreAuthKey creates a new PreAuthKey in a user, and returns it.
|
||||
func (hsdb *HSDatabase) CreatePreAuthKey(
|
||||
userName string,
|
||||
reusable bool,
|
||||
ephemeral bool,
|
||||
expiration *time.Time,
|
||||
aclTags []string,
|
||||
) (*types.PreAuthKey, error) {
|
||||
user, err := hsdb.GetUser(userName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for _, tag := range aclTags {
|
||||
if !strings.HasPrefix(tag, "tag:") {
|
||||
return nil, fmt.Errorf(
|
||||
"%w: '%s' did not begin with 'tag:'",
|
||||
ErrPreAuthKeyACLTagInvalid,
|
||||
tag,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
now := time.Now().UTC()
|
||||
kstr, err := hsdb.generateKey()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
key := types.PreAuthKey{
|
||||
Key: kstr,
|
||||
UserID: user.ID,
|
||||
User: *user,
|
||||
Reusable: reusable,
|
||||
Ephemeral: ephemeral,
|
||||
CreatedAt: &now,
|
||||
Expiration: expiration,
|
||||
}
|
||||
|
||||
err = hsdb.db.Transaction(func(db *gorm.DB) error {
|
||||
if err := db.Save(&key).Error; err != nil {
|
||||
return fmt.Errorf("failed to create key in the database: %w", err)
|
||||
}
|
||||
|
||||
if len(aclTags) > 0 {
|
||||
seenTags := map[string]bool{}
|
||||
|
||||
for _, tag := range aclTags {
|
||||
if !seenTags[tag] {
|
||||
if err := db.Save(&types.PreAuthKeyACLTag{PreAuthKeyID: key.ID, Tag: tag}).Error; err != nil {
|
||||
return fmt.Errorf(
|
||||
"failed to ceate key tag in the database: %w",
|
||||
err,
|
||||
)
|
||||
}
|
||||
seenTags[tag] = true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &key, nil
|
||||
}
|
||||
|
||||
// ListPreAuthKeys returns the list of PreAuthKeys for a user.
|
||||
func (hsdb *HSDatabase) ListPreAuthKeys(userName string) ([]types.PreAuthKey, error) {
|
||||
user, err := hsdb.GetUser(userName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
keys := []types.PreAuthKey{}
|
||||
if err := hsdb.db.Preload("User").Preload("ACLTags").Where(&types.PreAuthKey{UserID: user.ID}).Find(&keys).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return keys, nil
|
||||
}
|
||||
|
||||
// GetPreAuthKey returns a PreAuthKey for a given key.
|
||||
func (hsdb *HSDatabase) GetPreAuthKey(user string, key string) (*types.PreAuthKey, error) {
|
||||
pak, err := hsdb.ValidatePreAuthKey(key)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if pak.User.Name != user {
|
||||
return nil, ErrUserMismatch
|
||||
}
|
||||
|
||||
return pak, nil
|
||||
}
|
||||
|
||||
// DestroyPreAuthKey destroys a preauthkey. Returns error if the PreAuthKey
|
||||
// does not exist.
|
||||
func (hsdb *HSDatabase) DestroyPreAuthKey(pak types.PreAuthKey) error {
|
||||
return hsdb.db.Transaction(func(db *gorm.DB) error {
|
||||
if result := db.Unscoped().Where(types.PreAuthKeyACLTag{PreAuthKeyID: pak.ID}).Delete(&types.PreAuthKeyACLTag{}); result.Error != nil {
|
||||
return result.Error
|
||||
}
|
||||
|
||||
if result := db.Unscoped().Delete(pak); result.Error != nil {
|
||||
return result.Error
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
// MarkExpirePreAuthKey marks a PreAuthKey as expired.
|
||||
func (hsdb *HSDatabase) ExpirePreAuthKey(k *types.PreAuthKey) error {
|
||||
if err := hsdb.db.Model(&k).Update("Expiration", time.Now()).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// UsePreAuthKey marks a PreAuthKey as used.
|
||||
func (hsdb *HSDatabase) UsePreAuthKey(k *types.PreAuthKey) error {
|
||||
k.Used = true
|
||||
if err := hsdb.db.Save(k).Error; err != nil {
|
||||
return fmt.Errorf("failed to update key used status in the database: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ValidatePreAuthKey does the heavy lifting for validation of the PreAuthKey coming from a node
|
||||
// If returns no error and a PreAuthKey, it can be used.
|
||||
func (hsdb *HSDatabase) ValidatePreAuthKey(k string) (*types.PreAuthKey, error) {
|
||||
pak := types.PreAuthKey{}
|
||||
if result := hsdb.db.Preload("User").Preload("ACLTags").First(&pak, "key = ?", k); errors.Is(
|
||||
result.Error,
|
||||
gorm.ErrRecordNotFound,
|
||||
) {
|
||||
return nil, ErrPreAuthKeyNotFound
|
||||
}
|
||||
|
||||
if pak.Expiration != nil && pak.Expiration.Before(time.Now()) {
|
||||
return nil, ErrPreAuthKeyExpired
|
||||
}
|
||||
|
||||
if pak.Reusable || pak.Ephemeral { // we don't need to check if has been used before
|
||||
return &pak, nil
|
||||
}
|
||||
|
||||
machines := types.Machines{}
|
||||
if err := hsdb.db.Preload("AuthKey").Where(&types.Machine{AuthKeyID: uint(pak.ID)}).Find(&machines).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if len(machines) != 0 || pak.Used {
|
||||
return nil, ErrSingleUseAuthKeyHasBeenUsed
|
||||
}
|
||||
|
||||
return &pak, nil
|
||||
}
|
||||
|
||||
func (hsdb *HSDatabase) generateKey() (string, error) {
|
||||
size := 24
|
||||
bytes := make([]byte, size)
|
||||
if _, err := rand.Read(bytes); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return hex.EncodeToString(bytes), nil
|
||||
}
|
213
hscontrol/db/preauth_keys_test.go
Normal file
213
hscontrol/db/preauth_keys_test.go
Normal file
@@ -0,0 +1,213 @@
|
||||
package db
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/juanfont/headscale/hscontrol/types"
|
||||
"github.com/juanfont/headscale/hscontrol/util"
|
||||
"gopkg.in/check.v1"
|
||||
)
|
||||
|
||||
func (*Suite) TestCreatePreAuthKey(c *check.C) {
|
||||
_, err := db.CreatePreAuthKey("bogus", true, false, nil, nil)
|
||||
|
||||
c.Assert(err, check.NotNil)
|
||||
|
||||
user, err := db.CreateUser("test")
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
key, err := db.CreatePreAuthKey(user.Name, true, false, nil, nil)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
// Did we get a valid key?
|
||||
c.Assert(key.Key, check.NotNil)
|
||||
c.Assert(len(key.Key), check.Equals, 48)
|
||||
|
||||
// Make sure the User association is populated
|
||||
c.Assert(key.User.Name, check.Equals, user.Name)
|
||||
|
||||
_, err = db.ListPreAuthKeys("bogus")
|
||||
c.Assert(err, check.NotNil)
|
||||
|
||||
keys, err := db.ListPreAuthKeys(user.Name)
|
||||
c.Assert(err, check.IsNil)
|
||||
c.Assert(len(keys), check.Equals, 1)
|
||||
|
||||
// Make sure the User association is populated
|
||||
c.Assert((keys)[0].User.Name, check.Equals, user.Name)
|
||||
}
|
||||
|
||||
func (*Suite) TestExpiredPreAuthKey(c *check.C) {
|
||||
user, err := db.CreateUser("test2")
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
now := time.Now()
|
||||
pak, err := db.CreatePreAuthKey(user.Name, true, false, &now, nil)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
key, err := db.ValidatePreAuthKey(pak.Key)
|
||||
c.Assert(err, check.Equals, ErrPreAuthKeyExpired)
|
||||
c.Assert(key, check.IsNil)
|
||||
}
|
||||
|
||||
func (*Suite) TestPreAuthKeyDoesNotExist(c *check.C) {
|
||||
key, err := db.ValidatePreAuthKey("potatoKey")
|
||||
c.Assert(err, check.Equals, ErrPreAuthKeyNotFound)
|
||||
c.Assert(key, check.IsNil)
|
||||
}
|
||||
|
||||
func (*Suite) TestValidateKeyOk(c *check.C) {
|
||||
user, err := db.CreateUser("test3")
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
pak, err := db.CreatePreAuthKey(user.Name, true, false, nil, nil)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
key, err := db.ValidatePreAuthKey(pak.Key)
|
||||
c.Assert(err, check.IsNil)
|
||||
c.Assert(key.ID, check.Equals, pak.ID)
|
||||
}
|
||||
|
||||
func (*Suite) TestAlreadyUsedKey(c *check.C) {
|
||||
user, err := db.CreateUser("test4")
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
machine := types.Machine{
|
||||
ID: 0,
|
||||
MachineKey: "foo",
|
||||
NodeKey: "bar",
|
||||
DiscoKey: "faa",
|
||||
Hostname: "testest",
|
||||
UserID: user.ID,
|
||||
RegisterMethod: util.RegisterMethodAuthKey,
|
||||
AuthKeyID: uint(pak.ID),
|
||||
}
|
||||
db.db.Save(&machine)
|
||||
|
||||
key, err := db.ValidatePreAuthKey(pak.Key)
|
||||
c.Assert(err, check.Equals, ErrSingleUseAuthKeyHasBeenUsed)
|
||||
c.Assert(key, check.IsNil)
|
||||
}
|
||||
|
||||
func (*Suite) TestReusableBeingUsedKey(c *check.C) {
|
||||
user, err := db.CreateUser("test5")
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
pak, err := db.CreatePreAuthKey(user.Name, true, false, nil, nil)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
machine := types.Machine{
|
||||
ID: 1,
|
||||
MachineKey: "foo",
|
||||
NodeKey: "bar",
|
||||
DiscoKey: "faa",
|
||||
Hostname: "testest",
|
||||
UserID: user.ID,
|
||||
RegisterMethod: util.RegisterMethodAuthKey,
|
||||
AuthKeyID: uint(pak.ID),
|
||||
}
|
||||
db.db.Save(&machine)
|
||||
|
||||
key, err := db.ValidatePreAuthKey(pak.Key)
|
||||
c.Assert(err, check.IsNil)
|
||||
c.Assert(key.ID, check.Equals, pak.ID)
|
||||
}
|
||||
|
||||
func (*Suite) TestNotReusableNotBeingUsedKey(c *check.C) {
|
||||
user, err := db.CreateUser("test6")
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
key, err := db.ValidatePreAuthKey(pak.Key)
|
||||
c.Assert(err, check.IsNil)
|
||||
c.Assert(key.ID, check.Equals, pak.ID)
|
||||
}
|
||||
|
||||
func (*Suite) TestEphemeralKey(c *check.C) {
|
||||
user, err := db.CreateUser("test7")
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
pak, err := db.CreatePreAuthKey(user.Name, false, true, nil, nil)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
now := time.Now().Add(-time.Second * 30)
|
||||
machine := types.Machine{
|
||||
ID: 0,
|
||||
MachineKey: "foo",
|
||||
NodeKey: "bar",
|
||||
DiscoKey: "faa",
|
||||
Hostname: "testest",
|
||||
UserID: user.ID,
|
||||
RegisterMethod: util.RegisterMethodAuthKey,
|
||||
LastSeen: &now,
|
||||
AuthKeyID: uint(pak.ID),
|
||||
}
|
||||
db.db.Save(&machine)
|
||||
|
||||
_, err = db.ValidatePreAuthKey(pak.Key)
|
||||
// Ephemeral keys are by definition reusable
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
_, err = db.GetMachine("test7", "testest")
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
db.ExpireEphemeralMachines(time.Second * 20)
|
||||
|
||||
// The machine record should have been deleted
|
||||
_, err = db.GetMachine("test7", "testest")
|
||||
c.Assert(err, check.NotNil)
|
||||
|
||||
c.Assert(channelUpdates, check.Equals, int32(1))
|
||||
}
|
||||
|
||||
func (*Suite) TestExpirePreauthKey(c *check.C) {
|
||||
user, err := db.CreateUser("test3")
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
pak, err := db.CreatePreAuthKey(user.Name, true, false, nil, nil)
|
||||
c.Assert(err, check.IsNil)
|
||||
c.Assert(pak.Expiration, check.IsNil)
|
||||
|
||||
err = db.ExpirePreAuthKey(pak)
|
||||
c.Assert(err, check.IsNil)
|
||||
c.Assert(pak.Expiration, check.NotNil)
|
||||
|
||||
key, err := db.ValidatePreAuthKey(pak.Key)
|
||||
c.Assert(err, check.Equals, ErrPreAuthKeyExpired)
|
||||
c.Assert(key, check.IsNil)
|
||||
}
|
||||
|
||||
func (*Suite) TestNotReusableMarkedAsUsed(c *check.C) {
|
||||
user, err := db.CreateUser("test6")
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil)
|
||||
c.Assert(err, check.IsNil)
|
||||
pak.Used = true
|
||||
db.db.Save(&pak)
|
||||
|
||||
_, err = db.ValidatePreAuthKey(pak.Key)
|
||||
c.Assert(err, check.Equals, ErrSingleUseAuthKeyHasBeenUsed)
|
||||
}
|
||||
|
||||
func (*Suite) TestPreAuthKeyACLTags(c *check.C) {
|
||||
user, err := db.CreateUser("test8")
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
_, err = db.CreatePreAuthKey(user.Name, false, false, nil, []string{"badtag"})
|
||||
c.Assert(err, check.NotNil) // Confirm that malformed tags are rejected
|
||||
|
||||
tags := []string{"tag:test1", "tag:test2"}
|
||||
tagsWithDuplicate := []string{"tag:test1", "tag:test2", "tag:test2"}
|
||||
_, err = db.CreatePreAuthKey(user.Name, false, false, nil, tagsWithDuplicate)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
listedPaks, err := db.ListPreAuthKeys("test8")
|
||||
c.Assert(err, check.IsNil)
|
||||
c.Assert(listedPaks[0].Proto().AclTags, check.DeepEquals, tags)
|
||||
}
|
457
hscontrol/db/routes.go
Normal file
457
hscontrol/db/routes.go
Normal file
@@ -0,0 +1,457 @@
|
||||
package db
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net/netip"
|
||||
|
||||
"github.com/juanfont/headscale/hscontrol/policy"
|
||||
"github.com/juanfont/headscale/hscontrol/types"
|
||||
"github.com/rs/zerolog/log"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
var ErrRouteIsNotAvailable = errors.New("route is not available")
|
||||
|
||||
func (hsdb *HSDatabase) GetRoutes() (types.Routes, error) {
|
||||
var routes types.Routes
|
||||
err := hsdb.db.Preload("Machine").Find(&routes).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return routes, nil
|
||||
}
|
||||
|
||||
func (hsdb *HSDatabase) GetMachineAdvertisedRoutes(machine *types.Machine) (types.Routes, error) {
|
||||
var routes types.Routes
|
||||
err := hsdb.db.
|
||||
Preload("Machine").
|
||||
Where("machine_id = ? AND advertised = true", machine.ID).
|
||||
Find(&routes).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return routes, nil
|
||||
}
|
||||
|
||||
func (hsdb *HSDatabase) GetMachineRoutes(m *types.Machine) (types.Routes, error) {
|
||||
var routes types.Routes
|
||||
err := hsdb.db.
|
||||
Preload("Machine").
|
||||
Where("machine_id = ?", m.ID).
|
||||
Find(&routes).Error
|
||||
if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return routes, nil
|
||||
}
|
||||
|
||||
func (hsdb *HSDatabase) GetRoute(id uint64) (*types.Route, error) {
|
||||
var route types.Route
|
||||
err := hsdb.db.Preload("Machine").First(&route, id).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &route, nil
|
||||
}
|
||||
|
||||
func (hsdb *HSDatabase) EnableRoute(id uint64) error {
|
||||
route, err := hsdb.GetRoute(id)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Tailscale requires both IPv4 and IPv6 exit routes to
|
||||
// be enabled at the same time, as per
|
||||
// https://github.com/juanfont/headscale/issues/804#issuecomment-1399314002
|
||||
if route.IsExitRoute() {
|
||||
return hsdb.enableRoutes(
|
||||
&route.Machine,
|
||||
types.ExitRouteV4.String(),
|
||||
types.ExitRouteV6.String(),
|
||||
)
|
||||
}
|
||||
|
||||
return hsdb.enableRoutes(&route.Machine, netip.Prefix(route.Prefix).String())
|
||||
}
|
||||
|
||||
func (hsdb *HSDatabase) DisableRoute(id uint64) error {
|
||||
route, err := hsdb.GetRoute(id)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Tailscale requires both IPv4 and IPv6 exit routes to
|
||||
// be enabled at the same time, as per
|
||||
// https://github.com/juanfont/headscale/issues/804#issuecomment-1399314002
|
||||
if !route.IsExitRoute() {
|
||||
route.Enabled = false
|
||||
route.IsPrimary = false
|
||||
err = hsdb.db.Save(route).Error
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return hsdb.HandlePrimarySubnetFailover()
|
||||
}
|
||||
|
||||
routes, err := hsdb.GetMachineRoutes(&route.Machine)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for i := range routes {
|
||||
if routes[i].IsExitRoute() {
|
||||
routes[i].Enabled = false
|
||||
routes[i].IsPrimary = false
|
||||
err = hsdb.db.Save(&routes[i]).Error
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return hsdb.HandlePrimarySubnetFailover()
|
||||
}
|
||||
|
||||
func (hsdb *HSDatabase) DeleteRoute(id uint64) error {
|
||||
route, err := hsdb.GetRoute(id)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Tailscale requires both IPv4 and IPv6 exit routes to
|
||||
// be enabled at the same time, as per
|
||||
// https://github.com/juanfont/headscale/issues/804#issuecomment-1399314002
|
||||
if !route.IsExitRoute() {
|
||||
if err := hsdb.db.Unscoped().Delete(&route).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return hsdb.HandlePrimarySubnetFailover()
|
||||
}
|
||||
|
||||
routes, err := hsdb.GetMachineRoutes(&route.Machine)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
routesToDelete := types.Routes{}
|
||||
for _, r := range routes {
|
||||
if r.IsExitRoute() {
|
||||
routesToDelete = append(routesToDelete, r)
|
||||
}
|
||||
}
|
||||
|
||||
if err := hsdb.db.Unscoped().Delete(&routesToDelete).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return hsdb.HandlePrimarySubnetFailover()
|
||||
}
|
||||
|
||||
func (hsdb *HSDatabase) DeleteMachineRoutes(m *types.Machine) error {
|
||||
routes, err := hsdb.GetMachineRoutes(m)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for i := range routes {
|
||||
if err := hsdb.db.Unscoped().Delete(&routes[i]).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return hsdb.HandlePrimarySubnetFailover()
|
||||
}
|
||||
|
||||
// isUniquePrefix returns if there is another machine providing the same route already.
|
||||
func (hsdb *HSDatabase) isUniquePrefix(route types.Route) bool {
|
||||
var count int64
|
||||
hsdb.db.
|
||||
Model(&types.Route{}).
|
||||
Where("prefix = ? AND machine_id != ? AND advertised = ? AND enabled = ?",
|
||||
route.Prefix,
|
||||
route.MachineID,
|
||||
true, true).Count(&count)
|
||||
|
||||
return count == 0
|
||||
}
|
||||
|
||||
func (hsdb *HSDatabase) getPrimaryRoute(prefix netip.Prefix) (*types.Route, error) {
|
||||
var route types.Route
|
||||
err := hsdb.db.
|
||||
Preload("Machine").
|
||||
Where("prefix = ? AND advertised = ? AND enabled = ? AND is_primary = ?", types.IPPrefix(prefix), true, true, true).
|
||||
First(&route).Error
|
||||
if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, gorm.ErrRecordNotFound
|
||||
}
|
||||
|
||||
return &route, nil
|
||||
}
|
||||
|
||||
// getMachinePrimaryRoutes returns the routes that are enabled and marked as primary (for subnet failover)
|
||||
// Exit nodes are not considered for this, as they are never marked as Primary.
|
||||
func (hsdb *HSDatabase) GetMachinePrimaryRoutes(m *types.Machine) (types.Routes, error) {
|
||||
var routes types.Routes
|
||||
err := hsdb.db.
|
||||
Preload("Machine").
|
||||
Where("machine_id = ? AND advertised = ? AND enabled = ? AND is_primary = ?", m.ID, true, true, true).
|
||||
Find(&routes).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return routes, nil
|
||||
}
|
||||
|
||||
func (hsdb *HSDatabase) ProcessMachineRoutes(machine *types.Machine) error {
|
||||
currentRoutes := types.Routes{}
|
||||
err := hsdb.db.Where("machine_id = ?", machine.ID).Find(¤tRoutes).Error
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
advertisedRoutes := map[netip.Prefix]bool{}
|
||||
for _, prefix := range machine.HostInfo.RoutableIPs {
|
||||
advertisedRoutes[prefix] = false
|
||||
}
|
||||
|
||||
for pos, route := range currentRoutes {
|
||||
if _, ok := advertisedRoutes[netip.Prefix(route.Prefix)]; ok {
|
||||
if !route.Advertised {
|
||||
currentRoutes[pos].Advertised = true
|
||||
err := hsdb.db.Save(¤tRoutes[pos]).Error
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
advertisedRoutes[netip.Prefix(route.Prefix)] = true
|
||||
} else if route.Advertised {
|
||||
currentRoutes[pos].Advertised = false
|
||||
currentRoutes[pos].Enabled = false
|
||||
err := hsdb.db.Save(¤tRoutes[pos]).Error
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for prefix, exists := range advertisedRoutes {
|
||||
if !exists {
|
||||
route := types.Route{
|
||||
MachineID: machine.ID,
|
||||
Prefix: types.IPPrefix(prefix),
|
||||
Advertised: true,
|
||||
Enabled: false,
|
||||
}
|
||||
err := hsdb.db.Create(&route).Error
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (hsdb *HSDatabase) HandlePrimarySubnetFailover() error {
|
||||
// first, get all the enabled routes
|
||||
var routes types.Routes
|
||||
err := hsdb.db.
|
||||
Preload("Machine").
|
||||
Where("advertised = ? AND enabled = ?", true, true).
|
||||
Find(&routes).Error
|
||||
if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
log.Error().Err(err).Msg("error getting routes")
|
||||
}
|
||||
|
||||
routesChanged := false
|
||||
for pos, route := range routes {
|
||||
if route.IsExitRoute() {
|
||||
continue
|
||||
}
|
||||
|
||||
if !route.IsPrimary {
|
||||
_, err := hsdb.getPrimaryRoute(netip.Prefix(route.Prefix))
|
||||
if hsdb.isUniquePrefix(route) || errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
log.Info().
|
||||
Str("prefix", netip.Prefix(route.Prefix).String()).
|
||||
Str("machine", route.Machine.GivenName).
|
||||
Msg("Setting primary route")
|
||||
routes[pos].IsPrimary = true
|
||||
err := hsdb.db.Save(&routes[pos]).Error
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("error marking route as primary")
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
routesChanged = true
|
||||
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
if route.IsPrimary {
|
||||
if route.Machine.IsOnline() {
|
||||
continue
|
||||
}
|
||||
|
||||
// machine offline, find a new primary
|
||||
log.Info().
|
||||
Str("machine", route.Machine.Hostname).
|
||||
Str("prefix", netip.Prefix(route.Prefix).String()).
|
||||
Msgf("machine offline, finding a new primary subnet")
|
||||
|
||||
// find a new primary route
|
||||
var newPrimaryRoutes types.Routes
|
||||
err := hsdb.db.
|
||||
Preload("Machine").
|
||||
Where("prefix = ? AND machine_id != ? AND advertised = ? AND enabled = ?",
|
||||
route.Prefix,
|
||||
route.MachineID,
|
||||
true, true).
|
||||
Find(&newPrimaryRoutes).Error
|
||||
if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
log.Error().Err(err).Msg("error finding new primary route")
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
var newPrimaryRoute *types.Route
|
||||
for pos, r := range newPrimaryRoutes {
|
||||
if r.Machine.IsOnline() {
|
||||
newPrimaryRoute = &newPrimaryRoutes[pos]
|
||||
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if newPrimaryRoute == nil {
|
||||
log.Warn().
|
||||
Str("machine", route.Machine.Hostname).
|
||||
Str("prefix", netip.Prefix(route.Prefix).String()).
|
||||
Msgf("no alternative primary route found")
|
||||
|
||||
continue
|
||||
}
|
||||
|
||||
log.Info().
|
||||
Str("old_machine", route.Machine.Hostname).
|
||||
Str("prefix", netip.Prefix(route.Prefix).String()).
|
||||
Str("new_machine", newPrimaryRoute.Machine.Hostname).
|
||||
Msgf("found new primary route")
|
||||
|
||||
// disable the old primary route
|
||||
routes[pos].IsPrimary = false
|
||||
err = hsdb.db.Save(&routes[pos]).Error
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("error disabling old primary route")
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
// enable the new primary route
|
||||
newPrimaryRoute.IsPrimary = true
|
||||
err = hsdb.db.Save(&newPrimaryRoute).Error
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("error enabling new primary route")
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
routesChanged = true
|
||||
}
|
||||
}
|
||||
|
||||
if routesChanged {
|
||||
hsdb.notifyStateChange()
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// EnableAutoApprovedRoutes enables any routes advertised by a machine that match the ACL autoApprovers policy.
|
||||
func (hsdb *HSDatabase) EnableAutoApprovedRoutes(
|
||||
aclPolicy *policy.ACLPolicy,
|
||||
machine *types.Machine,
|
||||
) error {
|
||||
if len(machine.IPAddresses) == 0 {
|
||||
return nil // This machine has no IPAddresses, so can't possibly match any autoApprovers ACLs
|
||||
}
|
||||
|
||||
routes, err := hsdb.GetMachineAdvertisedRoutes(machine)
|
||||
if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
log.Error().
|
||||
Caller().
|
||||
Err(err).
|
||||
Str("machine", machine.Hostname).
|
||||
Msg("Could not get advertised routes for machine")
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
approvedRoutes := types.Routes{}
|
||||
|
||||
for _, advertisedRoute := range routes {
|
||||
if advertisedRoute.Enabled {
|
||||
continue
|
||||
}
|
||||
|
||||
routeApprovers, err := aclPolicy.AutoApprovers.GetRouteApprovers(
|
||||
netip.Prefix(advertisedRoute.Prefix),
|
||||
)
|
||||
if err != nil {
|
||||
log.Err(err).
|
||||
Str("advertisedRoute", advertisedRoute.String()).
|
||||
Uint64("machineId", machine.ID).
|
||||
Msg("Failed to resolve autoApprovers for advertised route")
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
for _, approvedAlias := range routeApprovers {
|
||||
if approvedAlias == machine.User.Name {
|
||||
approvedRoutes = append(approvedRoutes, advertisedRoute)
|
||||
} else {
|
||||
// TODO(kradalby): figure out how to get this to depend on less stuff
|
||||
approvedIps, err := aclPolicy.ExpandAlias(types.Machines{*machine}, approvedAlias, hsdb.stripEmailDomain)
|
||||
if err != nil {
|
||||
log.Err(err).
|
||||
Str("alias", approvedAlias).
|
||||
Msg("Failed to expand alias when processing autoApprovers policy")
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
// approvedIPs should contain all of machine's IPs if it matches the rule, so check for first
|
||||
if approvedIps.Contains(machine.IPAddresses[0]) {
|
||||
approvedRoutes = append(approvedRoutes, advertisedRoute)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for _, approvedRoute := range approvedRoutes {
|
||||
err := hsdb.EnableRoute(uint64(approvedRoute.ID))
|
||||
if err != nil {
|
||||
log.Err(err).
|
||||
Str("approvedRoute", approvedRoute.String()).
|
||||
Uint64("machineId", machine.ID).
|
||||
Msg("Failed to enable approved route")
|
||||
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
565
hscontrol/db/routes_test.go
Normal file
565
hscontrol/db/routes_test.go
Normal file
@@ -0,0 +1,565 @@
|
||||
package db
|
||||
|
||||
import (
|
||||
"net/netip"
|
||||
"time"
|
||||
|
||||
"github.com/juanfont/headscale/hscontrol/policy"
|
||||
"github.com/juanfont/headscale/hscontrol/types"
|
||||
"github.com/juanfont/headscale/hscontrol/util"
|
||||
"gopkg.in/check.v1"
|
||||
"tailscale.com/tailcfg"
|
||||
"tailscale.com/types/key"
|
||||
)
|
||||
|
||||
func (s *Suite) TestGetRoutes(c *check.C) {
|
||||
user, err := db.CreateUser("test")
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
_, err = db.GetMachine("test", "test_get_route_machine")
|
||||
c.Assert(err, check.NotNil)
|
||||
|
||||
route, err := netip.ParsePrefix("10.0.0.0/24")
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
hostInfo := tailcfg.Hostinfo{
|
||||
RoutableIPs: []netip.Prefix{route},
|
||||
}
|
||||
|
||||
machine := types.Machine{
|
||||
ID: 0,
|
||||
MachineKey: "foo",
|
||||
NodeKey: "bar",
|
||||
DiscoKey: "faa",
|
||||
Hostname: "test_get_route_machine",
|
||||
UserID: user.ID,
|
||||
RegisterMethod: util.RegisterMethodAuthKey,
|
||||
AuthKeyID: uint(pak.ID),
|
||||
HostInfo: types.HostInfo(hostInfo),
|
||||
}
|
||||
db.db.Save(&machine)
|
||||
|
||||
err = db.ProcessMachineRoutes(&machine)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
advertisedRoutes, err := db.GetAdvertisedRoutes(&machine)
|
||||
c.Assert(err, check.IsNil)
|
||||
c.Assert(len(advertisedRoutes), check.Equals, 1)
|
||||
|
||||
err = db.enableRoutes(&machine, "192.168.0.0/24")
|
||||
c.Assert(err, check.NotNil)
|
||||
|
||||
err = db.enableRoutes(&machine, "10.0.0.0/24")
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
c.Assert(channelUpdates, check.Equals, int32(0))
|
||||
}
|
||||
|
||||
func (s *Suite) TestGetEnableRoutes(c *check.C) {
|
||||
user, err := db.CreateUser("test")
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
_, err = db.GetMachine("test", "test_enable_route_machine")
|
||||
c.Assert(err, check.NotNil)
|
||||
|
||||
route, err := netip.ParsePrefix(
|
||||
"10.0.0.0/24",
|
||||
)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
route2, err := netip.ParsePrefix(
|
||||
"150.0.10.0/25",
|
||||
)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
hostInfo := tailcfg.Hostinfo{
|
||||
RoutableIPs: []netip.Prefix{route, route2},
|
||||
}
|
||||
|
||||
machine := types.Machine{
|
||||
ID: 0,
|
||||
MachineKey: "foo",
|
||||
NodeKey: "bar",
|
||||
DiscoKey: "faa",
|
||||
Hostname: "test_enable_route_machine",
|
||||
UserID: user.ID,
|
||||
RegisterMethod: util.RegisterMethodAuthKey,
|
||||
AuthKeyID: uint(pak.ID),
|
||||
HostInfo: types.HostInfo(hostInfo),
|
||||
}
|
||||
db.db.Save(&machine)
|
||||
|
||||
err = db.ProcessMachineRoutes(&machine)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
availableRoutes, err := db.GetAdvertisedRoutes(&machine)
|
||||
c.Assert(err, check.IsNil)
|
||||
c.Assert(err, check.IsNil)
|
||||
c.Assert(len(availableRoutes), check.Equals, 2)
|
||||
|
||||
noEnabledRoutes, err := db.GetEnabledRoutes(&machine)
|
||||
c.Assert(err, check.IsNil)
|
||||
c.Assert(len(noEnabledRoutes), check.Equals, 0)
|
||||
|
||||
err = db.enableRoutes(&machine, "192.168.0.0/24")
|
||||
c.Assert(err, check.NotNil)
|
||||
|
||||
err = db.enableRoutes(&machine, "10.0.0.0/24")
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
enabledRoutes, err := db.GetEnabledRoutes(&machine)
|
||||
c.Assert(err, check.IsNil)
|
||||
c.Assert(len(enabledRoutes), check.Equals, 1)
|
||||
|
||||
// Adding it twice will just let it pass through
|
||||
err = db.enableRoutes(&machine, "10.0.0.0/24")
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
enableRoutesAfterDoubleApply, err := db.GetEnabledRoutes(&machine)
|
||||
c.Assert(err, check.IsNil)
|
||||
c.Assert(len(enableRoutesAfterDoubleApply), check.Equals, 1)
|
||||
|
||||
err = db.enableRoutes(&machine, "150.0.10.0/25")
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
enabledRoutesWithAdditionalRoute, err := db.GetEnabledRoutes(&machine)
|
||||
c.Assert(err, check.IsNil)
|
||||
c.Assert(len(enabledRoutesWithAdditionalRoute), check.Equals, 2)
|
||||
|
||||
c.Assert(channelUpdates, check.Equals, int32(3))
|
||||
}
|
||||
|
||||
func (s *Suite) TestIsUniquePrefix(c *check.C) {
|
||||
user, err := db.CreateUser("test")
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
_, err = db.GetMachine("test", "test_enable_route_machine")
|
||||
c.Assert(err, check.NotNil)
|
||||
|
||||
route, err := netip.ParsePrefix(
|
||||
"10.0.0.0/24",
|
||||
)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
route2, err := netip.ParsePrefix(
|
||||
"150.0.10.0/25",
|
||||
)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
hostInfo1 := tailcfg.Hostinfo{
|
||||
RoutableIPs: []netip.Prefix{route, route2},
|
||||
}
|
||||
machine1 := types.Machine{
|
||||
ID: 1,
|
||||
MachineKey: "foo",
|
||||
NodeKey: "bar",
|
||||
DiscoKey: "faa",
|
||||
Hostname: "test_enable_route_machine",
|
||||
UserID: user.ID,
|
||||
RegisterMethod: util.RegisterMethodAuthKey,
|
||||
AuthKeyID: uint(pak.ID),
|
||||
HostInfo: types.HostInfo(hostInfo1),
|
||||
}
|
||||
db.db.Save(&machine1)
|
||||
|
||||
err = db.ProcessMachineRoutes(&machine1)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
err = db.enableRoutes(&machine1, route.String())
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
err = db.enableRoutes(&machine1, route2.String())
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
hostInfo2 := tailcfg.Hostinfo{
|
||||
RoutableIPs: []netip.Prefix{route2},
|
||||
}
|
||||
machine2 := types.Machine{
|
||||
ID: 2,
|
||||
MachineKey: "foo",
|
||||
NodeKey: "bar",
|
||||
DiscoKey: "faa",
|
||||
Hostname: "test_enable_route_machine",
|
||||
UserID: user.ID,
|
||||
RegisterMethod: util.RegisterMethodAuthKey,
|
||||
AuthKeyID: uint(pak.ID),
|
||||
HostInfo: types.HostInfo(hostInfo2),
|
||||
}
|
||||
db.db.Save(&machine2)
|
||||
|
||||
err = db.ProcessMachineRoutes(&machine2)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
err = db.enableRoutes(&machine2, route2.String())
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
enabledRoutes1, err := db.GetEnabledRoutes(&machine1)
|
||||
c.Assert(err, check.IsNil)
|
||||
c.Assert(len(enabledRoutes1), check.Equals, 2)
|
||||
|
||||
enabledRoutes2, err := db.GetEnabledRoutes(&machine2)
|
||||
c.Assert(err, check.IsNil)
|
||||
c.Assert(len(enabledRoutes2), check.Equals, 1)
|
||||
|
||||
routes, err := db.GetMachinePrimaryRoutes(&machine1)
|
||||
c.Assert(err, check.IsNil)
|
||||
c.Assert(len(routes), check.Equals, 2)
|
||||
|
||||
routes, err = db.GetMachinePrimaryRoutes(&machine2)
|
||||
c.Assert(err, check.IsNil)
|
||||
c.Assert(len(routes), check.Equals, 0)
|
||||
|
||||
c.Assert(channelUpdates, check.Equals, int32(3))
|
||||
}
|
||||
|
||||
func (s *Suite) TestSubnetFailover(c *check.C) {
|
||||
user, err := db.CreateUser("test")
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
_, err = db.GetMachine("test", "test_enable_route_machine")
|
||||
c.Assert(err, check.NotNil)
|
||||
|
||||
prefix, err := netip.ParsePrefix(
|
||||
"10.0.0.0/24",
|
||||
)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
prefix2, err := netip.ParsePrefix(
|
||||
"150.0.10.0/25",
|
||||
)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
hostInfo1 := tailcfg.Hostinfo{
|
||||
RoutableIPs: []netip.Prefix{prefix, prefix2},
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
machine1 := types.Machine{
|
||||
ID: 1,
|
||||
MachineKey: "foo",
|
||||
NodeKey: "bar",
|
||||
DiscoKey: "faa",
|
||||
Hostname: "test_enable_route_machine",
|
||||
UserID: user.ID,
|
||||
RegisterMethod: util.RegisterMethodAuthKey,
|
||||
AuthKeyID: uint(pak.ID),
|
||||
HostInfo: types.HostInfo(hostInfo1),
|
||||
LastSeen: &now,
|
||||
}
|
||||
db.db.Save(&machine1)
|
||||
|
||||
err = db.ProcessMachineRoutes(&machine1)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
err = db.enableRoutes(&machine1, prefix.String())
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
err = db.enableRoutes(&machine1, prefix2.String())
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
err = db.HandlePrimarySubnetFailover()
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
enabledRoutes1, err := db.GetEnabledRoutes(&machine1)
|
||||
c.Assert(err, check.IsNil)
|
||||
c.Assert(len(enabledRoutes1), check.Equals, 2)
|
||||
|
||||
route, err := db.getPrimaryRoute(prefix)
|
||||
c.Assert(err, check.IsNil)
|
||||
c.Assert(route.MachineID, check.Equals, machine1.ID)
|
||||
|
||||
hostInfo2 := tailcfg.Hostinfo{
|
||||
RoutableIPs: []netip.Prefix{prefix2},
|
||||
}
|
||||
machine2 := types.Machine{
|
||||
ID: 2,
|
||||
MachineKey: "foo",
|
||||
NodeKey: "bar",
|
||||
DiscoKey: "faa",
|
||||
Hostname: "test_enable_route_machine",
|
||||
UserID: user.ID,
|
||||
RegisterMethod: util.RegisterMethodAuthKey,
|
||||
AuthKeyID: uint(pak.ID),
|
||||
HostInfo: types.HostInfo(hostInfo2),
|
||||
LastSeen: &now,
|
||||
}
|
||||
db.db.Save(&machine2)
|
||||
|
||||
err = db.ProcessMachineRoutes(&machine2)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
err = db.enableRoutes(&machine2, prefix2.String())
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
err = db.HandlePrimarySubnetFailover()
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
enabledRoutes1, err = db.GetEnabledRoutes(&machine1)
|
||||
c.Assert(err, check.IsNil)
|
||||
c.Assert(len(enabledRoutes1), check.Equals, 2)
|
||||
|
||||
enabledRoutes2, err := db.GetEnabledRoutes(&machine2)
|
||||
c.Assert(err, check.IsNil)
|
||||
c.Assert(len(enabledRoutes2), check.Equals, 1)
|
||||
|
||||
routes, err := db.GetMachinePrimaryRoutes(&machine1)
|
||||
c.Assert(err, check.IsNil)
|
||||
c.Assert(len(routes), check.Equals, 2)
|
||||
|
||||
routes, err = db.GetMachinePrimaryRoutes(&machine2)
|
||||
c.Assert(err, check.IsNil)
|
||||
c.Assert(len(routes), check.Equals, 0)
|
||||
|
||||
// lets make machine1 lastseen 10 mins ago
|
||||
before := now.Add(-10 * time.Minute)
|
||||
machine1.LastSeen = &before
|
||||
err = db.db.Save(&machine1).Error
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
err = db.HandlePrimarySubnetFailover()
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
routes, err = db.GetMachinePrimaryRoutes(&machine1)
|
||||
c.Assert(err, check.IsNil)
|
||||
c.Assert(len(routes), check.Equals, 1)
|
||||
|
||||
routes, err = db.GetMachinePrimaryRoutes(&machine2)
|
||||
c.Assert(err, check.IsNil)
|
||||
c.Assert(len(routes), check.Equals, 1)
|
||||
|
||||
machine2.HostInfo = types.HostInfo(tailcfg.Hostinfo{
|
||||
RoutableIPs: []netip.Prefix{prefix, prefix2},
|
||||
})
|
||||
err = db.db.Save(&machine2).Error
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
err = db.ProcessMachineRoutes(&machine2)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
err = db.enableRoutes(&machine2, prefix.String())
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
err = db.HandlePrimarySubnetFailover()
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
routes, err = db.GetMachinePrimaryRoutes(&machine1)
|
||||
c.Assert(err, check.IsNil)
|
||||
c.Assert(len(routes), check.Equals, 0)
|
||||
|
||||
routes, err = db.GetMachinePrimaryRoutes(&machine2)
|
||||
c.Assert(err, check.IsNil)
|
||||
c.Assert(len(routes), check.Equals, 2)
|
||||
|
||||
c.Assert(channelUpdates, check.Equals, int32(6))
|
||||
}
|
||||
|
||||
// TestAllowedIPRoutes tests that the AllowedIPs are correctly set for a node,
|
||||
// including both the primary routes the node is responsible for, and the
|
||||
// exit node routes if enabled.
|
||||
func (s *Suite) TestAllowedIPRoutes(c *check.C) {
|
||||
user, err := db.CreateUser("test")
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
_, err = db.GetMachine("test", "test_enable_route_machine")
|
||||
c.Assert(err, check.NotNil)
|
||||
|
||||
prefix, err := netip.ParsePrefix(
|
||||
"10.0.0.0/24",
|
||||
)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
prefix2, err := netip.ParsePrefix(
|
||||
"150.0.10.0/25",
|
||||
)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
prefixExitNodeV4, err := netip.ParsePrefix(
|
||||
"0.0.0.0/0",
|
||||
)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
prefixExitNodeV6, err := netip.ParsePrefix(
|
||||
"::/0",
|
||||
)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
hostInfo1 := tailcfg.Hostinfo{
|
||||
RoutableIPs: []netip.Prefix{prefix, prefix2, prefixExitNodeV4, prefixExitNodeV6},
|
||||
}
|
||||
|
||||
nodeKey := key.NewNode()
|
||||
discoKey := key.NewDisco()
|
||||
machineKey := key.NewMachine()
|
||||
|
||||
now := time.Now()
|
||||
machine1 := types.Machine{
|
||||
ID: 1,
|
||||
MachineKey: util.MachinePublicKeyStripPrefix(machineKey.Public()),
|
||||
NodeKey: util.NodePublicKeyStripPrefix(nodeKey.Public()),
|
||||
DiscoKey: util.DiscoPublicKeyStripPrefix(discoKey.Public()),
|
||||
Hostname: "test_enable_route_machine",
|
||||
UserID: user.ID,
|
||||
RegisterMethod: util.RegisterMethodAuthKey,
|
||||
AuthKeyID: uint(pak.ID),
|
||||
HostInfo: types.HostInfo(hostInfo1),
|
||||
LastSeen: &now,
|
||||
}
|
||||
db.db.Save(&machine1)
|
||||
|
||||
err = db.ProcessMachineRoutes(&machine1)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
err = db.enableRoutes(&machine1, prefix.String())
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
// We do not enable this one on purpose to test that it is not enabled
|
||||
// err = db.enableRoutes(&machine1, prefix2.String())
|
||||
// c.Assert(err, check.IsNil)
|
||||
|
||||
routes, err := db.GetMachineRoutes(&machine1)
|
||||
c.Assert(err, check.IsNil)
|
||||
for _, route := range routes {
|
||||
if route.IsExitRoute() {
|
||||
err = db.EnableRoute(uint64(route.ID))
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
// We only enable one exit route, so we can test that both are enabled
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
err = db.HandlePrimarySubnetFailover()
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
enabledRoutes1, err := db.GetEnabledRoutes(&machine1)
|
||||
c.Assert(err, check.IsNil)
|
||||
c.Assert(len(enabledRoutes1), check.Equals, 3)
|
||||
|
||||
peer, err := db.TailNode(machine1, &policy.ACLPolicy{}, nil)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
c.Assert(len(peer.AllowedIPs), check.Equals, 3)
|
||||
|
||||
foundExitNodeV4 := false
|
||||
foundExitNodeV6 := false
|
||||
for _, allowedIP := range peer.AllowedIPs {
|
||||
if allowedIP == prefixExitNodeV4 {
|
||||
foundExitNodeV4 = true
|
||||
}
|
||||
if allowedIP == prefixExitNodeV6 {
|
||||
foundExitNodeV6 = true
|
||||
}
|
||||
}
|
||||
|
||||
c.Assert(foundExitNodeV4, check.Equals, true)
|
||||
c.Assert(foundExitNodeV6, check.Equals, true)
|
||||
|
||||
// Now we disable only one of the exit routes
|
||||
// and we see if both are disabled
|
||||
var exitRouteV4 types.Route
|
||||
for _, route := range routes {
|
||||
if route.IsExitRoute() && netip.Prefix(route.Prefix) == prefixExitNodeV4 {
|
||||
exitRouteV4 = route
|
||||
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
err = db.DisableRoute(uint64(exitRouteV4.ID))
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
enabledRoutes1, err = db.GetEnabledRoutes(&machine1)
|
||||
c.Assert(err, check.IsNil)
|
||||
c.Assert(len(enabledRoutes1), check.Equals, 1)
|
||||
|
||||
// and now we delete only one of the exit routes
|
||||
// and we check if both are deleted
|
||||
routes, err = db.GetMachineRoutes(&machine1)
|
||||
c.Assert(err, check.IsNil)
|
||||
c.Assert(len(routes), check.Equals, 4)
|
||||
|
||||
err = db.DeleteRoute(uint64(exitRouteV4.ID))
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
routes, err = db.GetMachineRoutes(&machine1)
|
||||
c.Assert(err, check.IsNil)
|
||||
c.Assert(len(routes), check.Equals, 2)
|
||||
|
||||
c.Assert(channelUpdates, check.Equals, int32(2))
|
||||
}
|
||||
|
||||
func (s *Suite) TestDeleteRoutes(c *check.C) {
|
||||
user, err := db.CreateUser("test")
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
_, err = db.GetMachine("test", "test_enable_route_machine")
|
||||
c.Assert(err, check.NotNil)
|
||||
|
||||
prefix, err := netip.ParsePrefix(
|
||||
"10.0.0.0/24",
|
||||
)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
prefix2, err := netip.ParsePrefix(
|
||||
"150.0.10.0/25",
|
||||
)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
hostInfo1 := tailcfg.Hostinfo{
|
||||
RoutableIPs: []netip.Prefix{prefix, prefix2},
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
machine1 := types.Machine{
|
||||
ID: 1,
|
||||
MachineKey: "foo",
|
||||
NodeKey: "bar",
|
||||
DiscoKey: "faa",
|
||||
Hostname: "test_enable_route_machine",
|
||||
UserID: user.ID,
|
||||
RegisterMethod: util.RegisterMethodAuthKey,
|
||||
AuthKeyID: uint(pak.ID),
|
||||
HostInfo: types.HostInfo(hostInfo1),
|
||||
LastSeen: &now,
|
||||
}
|
||||
db.db.Save(&machine1)
|
||||
|
||||
err = db.ProcessMachineRoutes(&machine1)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
err = db.enableRoutes(&machine1, prefix.String())
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
err = db.enableRoutes(&machine1, prefix2.String())
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
routes, err := db.GetMachineRoutes(&machine1)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
err = db.DeleteRoute(uint64(routes[0].ID))
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
enabledRoutes1, err := db.GetEnabledRoutes(&machine1)
|
||||
c.Assert(err, check.IsNil)
|
||||
c.Assert(len(enabledRoutes1), check.Equals, 1)
|
||||
|
||||
c.Assert(channelUpdates, check.Equals, int32(2))
|
||||
}
|
74
hscontrol/db/suite_test.go
Normal file
74
hscontrol/db/suite_test.go
Normal file
@@ -0,0 +1,74 @@
|
||||
package db
|
||||
|
||||
import (
|
||||
"net/netip"
|
||||
"os"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
|
||||
"gopkg.in/check.v1"
|
||||
)
|
||||
|
||||
func Test(t *testing.T) {
|
||||
check.TestingT(t)
|
||||
}
|
||||
|
||||
var _ = check.Suite(&Suite{})
|
||||
|
||||
type Suite struct{}
|
||||
|
||||
var (
|
||||
tmpDir string
|
||||
db *HSDatabase
|
||||
|
||||
// channelUpdates counts the number of times
|
||||
// either of the channels was notified.
|
||||
channelUpdates int32
|
||||
)
|
||||
|
||||
func (s *Suite) SetUpTest(c *check.C) {
|
||||
atomic.StoreInt32(&channelUpdates, 0)
|
||||
s.ResetDB(c)
|
||||
}
|
||||
|
||||
func (s *Suite) TearDownTest(c *check.C) {
|
||||
os.RemoveAll(tmpDir)
|
||||
}
|
||||
|
||||
func notificationSink(c <-chan struct{}) {
|
||||
for {
|
||||
<-c
|
||||
atomic.AddInt32(&channelUpdates, 1)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Suite) ResetDB(c *check.C) {
|
||||
if len(tmpDir) != 0 {
|
||||
os.RemoveAll(tmpDir)
|
||||
}
|
||||
var err error
|
||||
tmpDir, err = os.MkdirTemp("", "autoygg-client-test")
|
||||
if err != nil {
|
||||
c.Fatal(err)
|
||||
}
|
||||
|
||||
sink := make(chan struct{})
|
||||
|
||||
go notificationSink(sink)
|
||||
|
||||
db, err = NewHeadscaleDatabase(
|
||||
"sqlite3",
|
||||
tmpDir+"/headscale_test.db",
|
||||
false,
|
||||
false,
|
||||
sink,
|
||||
sink,
|
||||
[]netip.Prefix{
|
||||
netip.MustParsePrefix("10.27.0.0/23"),
|
||||
},
|
||||
"",
|
||||
)
|
||||
if err != nil {
|
||||
c.Fatal(err)
|
||||
}
|
||||
}
|
194
hscontrol/db/users.go
Normal file
194
hscontrol/db/users.go
Normal file
@@ -0,0 +1,194 @@
|
||||
package db
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
"github.com/juanfont/headscale/hscontrol/types"
|
||||
"github.com/juanfont/headscale/hscontrol/util"
|
||||
"github.com/rs/zerolog/log"
|
||||
"gorm.io/gorm"
|
||||
"tailscale.com/tailcfg"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrUserExists = errors.New("user already exists")
|
||||
ErrUserNotFound = errors.New("user not found")
|
||||
ErrUserStillHasNodes = errors.New("user not empty: node(s) found")
|
||||
)
|
||||
|
||||
// CreateUser creates a new User. Returns error if could not be created
|
||||
// or another user already exists.
|
||||
func (hsdb *HSDatabase) CreateUser(name string) (*types.User, error) {
|
||||
err := util.CheckForFQDNRules(name)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
user := types.User{}
|
||||
if err := hsdb.db.Where("name = ?", name).First(&user).Error; err == nil {
|
||||
return nil, ErrUserExists
|
||||
}
|
||||
user.Name = name
|
||||
if err := hsdb.db.Create(&user).Error; err != nil {
|
||||
log.Error().
|
||||
Str("func", "CreateUser").
|
||||
Err(err).
|
||||
Msg("Could not create row")
|
||||
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &user, nil
|
||||
}
|
||||
|
||||
// DestroyUser destroys a User. Returns error if the User does
|
||||
// not exist or if there are machines associated with it.
|
||||
func (hsdb *HSDatabase) DestroyUser(name string) error {
|
||||
user, err := hsdb.GetUser(name)
|
||||
if err != nil {
|
||||
return ErrUserNotFound
|
||||
}
|
||||
|
||||
machines, err := hsdb.ListMachinesByUser(name)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if len(machines) > 0 {
|
||||
return ErrUserStillHasNodes
|
||||
}
|
||||
|
||||
keys, err := hsdb.ListPreAuthKeys(name)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
for _, key := range keys {
|
||||
err = hsdb.DestroyPreAuthKey(key)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
if result := hsdb.db.Unscoped().Delete(&user); result.Error != nil {
|
||||
return result.Error
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// RenameUser renames a User. Returns error if the User does
|
||||
// not exist or if another User exists with the new name.
|
||||
func (hsdb *HSDatabase) RenameUser(oldName, newName string) error {
|
||||
var err error
|
||||
oldUser, err := hsdb.GetUser(oldName)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = util.CheckForFQDNRules(newName)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = hsdb.GetUser(newName)
|
||||
if err == nil {
|
||||
return ErrUserExists
|
||||
}
|
||||
if !errors.Is(err, ErrUserNotFound) {
|
||||
return err
|
||||
}
|
||||
|
||||
oldUser.Name = newName
|
||||
|
||||
if result := hsdb.db.Save(&oldUser); result.Error != nil {
|
||||
return result.Error
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetUser fetches a user by name.
|
||||
func (hsdb *HSDatabase) GetUser(name string) (*types.User, error) {
|
||||
user := types.User{}
|
||||
if result := hsdb.db.First(&user, "name = ?", name); errors.Is(
|
||||
result.Error,
|
||||
gorm.ErrRecordNotFound,
|
||||
) {
|
||||
return nil, ErrUserNotFound
|
||||
}
|
||||
|
||||
return &user, nil
|
||||
}
|
||||
|
||||
// ListUsers gets all the existing users.
|
||||
func (hsdb *HSDatabase) ListUsers() ([]types.User, error) {
|
||||
users := []types.User{}
|
||||
if err := hsdb.db.Find(&users).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return users, nil
|
||||
}
|
||||
|
||||
// ListMachinesByUser gets all the nodes in a given user.
|
||||
func (hsdb *HSDatabase) ListMachinesByUser(name string) (types.Machines, error) {
|
||||
err := util.CheckForFQDNRules(name)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
user, err := hsdb.GetUser(name)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
machines := types.Machines{}
|
||||
if err := hsdb.db.Preload("AuthKey").Preload("AuthKey.User").Preload("User").Where(&types.Machine{UserID: user.ID}).Find(&machines).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return machines, nil
|
||||
}
|
||||
|
||||
// SetMachineUser assigns a Machine to a user.
|
||||
func (hsdb *HSDatabase) SetMachineUser(machine *types.Machine, username string) error {
|
||||
err := util.CheckForFQDNRules(username)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
user, err := hsdb.GetUser(username)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
machine.User = *user
|
||||
if result := hsdb.db.Save(&machine); result.Error != nil {
|
||||
return result.Error
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (hsdb *HSDatabase) GetMapResponseUserProfiles(
|
||||
machine types.Machine,
|
||||
peers types.Machines,
|
||||
) []tailcfg.UserProfile {
|
||||
userMap := make(map[string]types.User)
|
||||
userMap[machine.User.Name] = machine.User
|
||||
for _, peer := range peers {
|
||||
userMap[peer.User.Name] = peer.User // not worth checking if already is there
|
||||
}
|
||||
|
||||
profiles := []tailcfg.UserProfile{}
|
||||
for _, user := range userMap {
|
||||
displayName := user.Name
|
||||
|
||||
if hsdb.baseDomain != "" {
|
||||
displayName = fmt.Sprintf("%s@%s", user.Name, hsdb.baseDomain)
|
||||
}
|
||||
|
||||
profiles = append(profiles,
|
||||
tailcfg.UserProfile{
|
||||
ID: tailcfg.UserID(user.ID),
|
||||
LoginName: user.Name,
|
||||
DisplayName: displayName,
|
||||
})
|
||||
}
|
||||
|
||||
return profiles
|
||||
}
|
277
hscontrol/db/users_test.go
Normal file
277
hscontrol/db/users_test.go
Normal file
@@ -0,0 +1,277 @@
|
||||
package db
|
||||
|
||||
import (
|
||||
"net/netip"
|
||||
|
||||
"github.com/juanfont/headscale/hscontrol/types"
|
||||
"github.com/juanfont/headscale/hscontrol/util"
|
||||
"gopkg.in/check.v1"
|
||||
"gorm.io/gorm"
|
||||
"tailscale.com/tailcfg"
|
||||
)
|
||||
|
||||
func (s *Suite) TestCreateAndDestroyUser(c *check.C) {
|
||||
user, err := db.CreateUser("test")
|
||||
c.Assert(err, check.IsNil)
|
||||
c.Assert(user.Name, check.Equals, "test")
|
||||
|
||||
users, err := db.ListUsers()
|
||||
c.Assert(err, check.IsNil)
|
||||
c.Assert(len(users), check.Equals, 1)
|
||||
|
||||
err = db.DestroyUser("test")
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
_, err = db.GetUser("test")
|
||||
c.Assert(err, check.NotNil)
|
||||
}
|
||||
|
||||
func (s *Suite) TestDestroyUserErrors(c *check.C) {
|
||||
err := db.DestroyUser("test")
|
||||
c.Assert(err, check.Equals, ErrUserNotFound)
|
||||
|
||||
user, err := db.CreateUser("test")
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
err = db.DestroyUser("test")
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
result := db.db.Preload("User").First(&pak, "key = ?", pak.Key)
|
||||
// destroying a user also deletes all associated preauthkeys
|
||||
c.Assert(result.Error, check.Equals, gorm.ErrRecordNotFound)
|
||||
|
||||
user, err = db.CreateUser("test")
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
pak, err = db.CreatePreAuthKey(user.Name, false, false, nil, nil)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
machine := types.Machine{
|
||||
ID: 0,
|
||||
MachineKey: "foo",
|
||||
NodeKey: "bar",
|
||||
DiscoKey: "faa",
|
||||
Hostname: "testmachine",
|
||||
UserID: user.ID,
|
||||
RegisterMethod: util.RegisterMethodAuthKey,
|
||||
AuthKeyID: uint(pak.ID),
|
||||
}
|
||||
db.db.Save(&machine)
|
||||
|
||||
err = db.DestroyUser("test")
|
||||
c.Assert(err, check.Equals, ErrUserStillHasNodes)
|
||||
}
|
||||
|
||||
func (s *Suite) TestRenameUser(c *check.C) {
|
||||
userTest, err := db.CreateUser("test")
|
||||
c.Assert(err, check.IsNil)
|
||||
c.Assert(userTest.Name, check.Equals, "test")
|
||||
|
||||
users, err := db.ListUsers()
|
||||
c.Assert(err, check.IsNil)
|
||||
c.Assert(len(users), check.Equals, 1)
|
||||
|
||||
err = db.RenameUser("test", "test-renamed")
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
_, err = db.GetUser("test")
|
||||
c.Assert(err, check.Equals, ErrUserNotFound)
|
||||
|
||||
_, err = db.GetUser("test-renamed")
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
err = db.RenameUser("test-does-not-exit", "test")
|
||||
c.Assert(err, check.Equals, ErrUserNotFound)
|
||||
|
||||
userTest2, err := db.CreateUser("test2")
|
||||
c.Assert(err, check.IsNil)
|
||||
c.Assert(userTest2.Name, check.Equals, "test2")
|
||||
|
||||
err = db.RenameUser("test2", "test-renamed")
|
||||
c.Assert(err, check.Equals, ErrUserExists)
|
||||
}
|
||||
|
||||
func (s *Suite) TestGetMapResponseUserProfiles(c *check.C) {
|
||||
userShared1, err := db.CreateUser("shared1")
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
userShared2, err := db.CreateUser("shared2")
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
userShared3, err := db.CreateUser("shared3")
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
preAuthKeyShared1, err := db.CreatePreAuthKey(
|
||||
userShared1.Name,
|
||||
false,
|
||||
false,
|
||||
nil,
|
||||
nil,
|
||||
)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
preAuthKeyShared2, err := db.CreatePreAuthKey(
|
||||
userShared2.Name,
|
||||
false,
|
||||
false,
|
||||
nil,
|
||||
nil,
|
||||
)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
preAuthKeyShared3, err := db.CreatePreAuthKey(
|
||||
userShared3.Name,
|
||||
false,
|
||||
false,
|
||||
nil,
|
||||
nil,
|
||||
)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
preAuthKey2Shared1, err := db.CreatePreAuthKey(
|
||||
userShared1.Name,
|
||||
false,
|
||||
false,
|
||||
nil,
|
||||
nil,
|
||||
)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
_, err = db.GetMachine(userShared1.Name, "test_get_shared_nodes_1")
|
||||
c.Assert(err, check.NotNil)
|
||||
|
||||
machineInShared1 := &types.Machine{
|
||||
ID: 1,
|
||||
MachineKey: "686824e749f3b7f2a5927ee6c1e422aee5292592d9179a271ed7b3e659b44a66",
|
||||
NodeKey: "686824e749f3b7f2a5927ee6c1e422aee5292592d9179a271ed7b3e659b44a66",
|
||||
DiscoKey: "686824e749f3b7f2a5927ee6c1e422aee5292592d9179a271ed7b3e659b44a66",
|
||||
Hostname: "test_get_shared_nodes_1",
|
||||
UserID: userShared1.ID,
|
||||
User: *userShared1,
|
||||
RegisterMethod: util.RegisterMethodAuthKey,
|
||||
IPAddresses: []netip.Addr{netip.MustParseAddr("100.64.0.1")},
|
||||
AuthKeyID: uint(preAuthKeyShared1.ID),
|
||||
}
|
||||
db.db.Save(machineInShared1)
|
||||
|
||||
_, err = db.GetMachine(userShared1.Name, machineInShared1.Hostname)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
machineInShared2 := &types.Machine{
|
||||
ID: 2,
|
||||
MachineKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
|
||||
NodeKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
|
||||
DiscoKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
|
||||
Hostname: "test_get_shared_nodes_2",
|
||||
UserID: userShared2.ID,
|
||||
User: *userShared2,
|
||||
RegisterMethod: util.RegisterMethodAuthKey,
|
||||
IPAddresses: []netip.Addr{netip.MustParseAddr("100.64.0.2")},
|
||||
AuthKeyID: uint(preAuthKeyShared2.ID),
|
||||
}
|
||||
db.db.Save(machineInShared2)
|
||||
|
||||
_, err = db.GetMachine(userShared2.Name, machineInShared2.Hostname)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
machineInShared3 := &types.Machine{
|
||||
ID: 3,
|
||||
MachineKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
|
||||
NodeKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
|
||||
DiscoKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
|
||||
Hostname: "test_get_shared_nodes_3",
|
||||
UserID: userShared3.ID,
|
||||
User: *userShared3,
|
||||
RegisterMethod: util.RegisterMethodAuthKey,
|
||||
IPAddresses: []netip.Addr{netip.MustParseAddr("100.64.0.3")},
|
||||
AuthKeyID: uint(preAuthKeyShared3.ID),
|
||||
}
|
||||
db.db.Save(machineInShared3)
|
||||
|
||||
_, err = db.GetMachine(userShared3.Name, machineInShared3.Hostname)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
machine2InShared1 := &types.Machine{
|
||||
ID: 4,
|
||||
MachineKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
|
||||
NodeKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
|
||||
DiscoKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
|
||||
Hostname: "test_get_shared_nodes_4",
|
||||
UserID: userShared1.ID,
|
||||
User: *userShared1,
|
||||
RegisterMethod: util.RegisterMethodAuthKey,
|
||||
IPAddresses: []netip.Addr{netip.MustParseAddr("100.64.0.4")},
|
||||
AuthKeyID: uint(preAuthKey2Shared1.ID),
|
||||
}
|
||||
db.db.Save(machine2InShared1)
|
||||
|
||||
peersOfMachine1InShared1, err := db.getPeers([]tailcfg.FilterRule{}, machineInShared1)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
userProfiles := db.GetMapResponseUserProfiles(
|
||||
*machineInShared1,
|
||||
peersOfMachine1InShared1,
|
||||
)
|
||||
|
||||
c.Assert(len(userProfiles), check.Equals, 3)
|
||||
|
||||
found := false
|
||||
for _, userProfiles := range userProfiles {
|
||||
if userProfiles.DisplayName == userShared1.Name {
|
||||
found = true
|
||||
|
||||
break
|
||||
}
|
||||
}
|
||||
c.Assert(found, check.Equals, true)
|
||||
|
||||
found = false
|
||||
for _, userProfile := range userProfiles {
|
||||
if userProfile.DisplayName == userShared2.Name {
|
||||
found = true
|
||||
|
||||
break
|
||||
}
|
||||
}
|
||||
c.Assert(found, check.Equals, true)
|
||||
}
|
||||
|
||||
func (s *Suite) TestSetMachineUser(c *check.C) {
|
||||
oldUser, err := db.CreateUser("old")
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
newUser, err := db.CreateUser("new")
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
pak, err := db.CreatePreAuthKey(oldUser.Name, false, false, nil, nil)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
machine := types.Machine{
|
||||
ID: 0,
|
||||
MachineKey: "foo",
|
||||
NodeKey: "bar",
|
||||
DiscoKey: "faa",
|
||||
Hostname: "testmachine",
|
||||
UserID: oldUser.ID,
|
||||
RegisterMethod: util.RegisterMethodAuthKey,
|
||||
AuthKeyID: uint(pak.ID),
|
||||
}
|
||||
db.db.Save(&machine)
|
||||
c.Assert(machine.UserID, check.Equals, oldUser.ID)
|
||||
|
||||
err = db.SetMachineUser(&machine, newUser.Name)
|
||||
c.Assert(err, check.IsNil)
|
||||
c.Assert(machine.UserID, check.Equals, newUser.ID)
|
||||
c.Assert(machine.User.Name, check.Equals, newUser.Name)
|
||||
|
||||
err = db.SetMachineUser(&machine, "non-existing-user")
|
||||
c.Assert(err, check.Equals, ErrUserNotFound)
|
||||
|
||||
err = db.SetMachineUser(&machine, newUser.Name)
|
||||
c.Assert(err, check.IsNil)
|
||||
c.Assert(machine.UserID, check.Equals, newUser.ID)
|
||||
c.Assert(machine.User.Name, check.Equals, newUser.Name)
|
||||
}
|
Reference in New Issue
Block a user