mirror of
https://github.com/juanfont/headscale.git
synced 2024-12-22 16:07:34 +00:00
Split code into modules
This is a massive commit that restructures the code into modules: db/ All functions related to modifying the Database types/ All type definitions and methods that can be exclusivly used on these types without dependencies policy/ All Policy related code, now without dependencies on the Database. policy/matcher/ Dedicated code to match machines in a list of FilterRules Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>
This commit is contained in:
parent
14e29a7bee
commit
feb15365b5
@ -7,7 +7,7 @@ import (
|
||||
"strconv"
|
||||
|
||||
v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
|
||||
"github.com/juanfont/headscale/hscontrol"
|
||||
"github.com/juanfont/headscale/hscontrol/types"
|
||||
"github.com/pterm/pterm"
|
||||
"github.com/spf13/cobra"
|
||||
"google.golang.org/grpc/status"
|
||||
@ -277,7 +277,7 @@ func routesToPtables(routes []*v1.Route) pterm.TableData {
|
||||
|
||||
continue
|
||||
}
|
||||
if prefix == hscontrol.ExitRouteV4 || prefix == hscontrol.ExitRouteV6 {
|
||||
if prefix == types.ExitRouteV4 || prefix == types.ExitRouteV6 {
|
||||
isPrimaryStr = "-"
|
||||
} else {
|
||||
isPrimaryStr = strconv.FormatBool(route.IsPrimary)
|
||||
|
@ -10,6 +10,7 @@ import (
|
||||
|
||||
v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
|
||||
"github.com/juanfont/headscale/hscontrol"
|
||||
"github.com/juanfont/headscale/hscontrol/policy"
|
||||
"github.com/juanfont/headscale/hscontrol/util"
|
||||
"github.com/rs/zerolog/log"
|
||||
"google.golang.org/grpc"
|
||||
@ -41,13 +42,15 @@ func getHeadscaleApp() (*hscontrol.Headscale, error) {
|
||||
|
||||
if cfg.ACL.PolicyPath != "" {
|
||||
aclPath := util.AbsolutePathFromConfigPath(cfg.ACL.PolicyPath)
|
||||
err = app.LoadACLPolicyFromPath(aclPath)
|
||||
pol, err := policy.LoadACLPolicyFromPath(aclPath)
|
||||
if err != nil {
|
||||
log.Fatal().
|
||||
Str("path", aclPath).
|
||||
Err(err).
|
||||
Msg("Could not load the ACL policy")
|
||||
}
|
||||
|
||||
app.ACLPolicy = pol
|
||||
}
|
||||
|
||||
return app, nil
|
||||
|
@ -18,9 +18,6 @@ const (
|
||||
// TODO(juan): remove this once https://github.com/juanfont/headscale/issues/727 is fixed.
|
||||
registrationHoldoff = time.Second * 5
|
||||
reservedResponseHeaderSize = 4
|
||||
RegisterMethodAuthKey = "authkey"
|
||||
RegisterMethodOIDC = "oidc"
|
||||
RegisterMethodCLI = "cli"
|
||||
)
|
||||
|
||||
var ErrRegisterMethodCLIDoesNotSupportExpire = errors.New(
|
||||
@ -56,7 +53,7 @@ func (h *Headscale) HealthHandler(
|
||||
}
|
||||
}
|
||||
|
||||
if err := h.db.pingDB(req.Context()); err != nil {
|
||||
if err := h.db.PingDB(req.Context()); err != nil {
|
||||
respond(err)
|
||||
|
||||
return
|
||||
|
@ -3,6 +3,7 @@ package hscontrol
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/juanfont/headscale/hscontrol/types"
|
||||
"github.com/juanfont/headscale/hscontrol/util"
|
||||
"github.com/rs/zerolog/log"
|
||||
"tailscale.com/tailcfg"
|
||||
@ -10,13 +11,13 @@ import (
|
||||
|
||||
func (h *Headscale) generateMapResponse(
|
||||
mapRequest tailcfg.MapRequest,
|
||||
machine *Machine,
|
||||
machine *types.Machine,
|
||||
) (*tailcfg.MapResponse, error) {
|
||||
log.Trace().
|
||||
Str("func", "generateMapResponse").
|
||||
Str("machine", mapRequest.Hostinfo.Hostname).
|
||||
Msg("Creating Map response")
|
||||
node, err := h.db.toNode(*machine, h.aclPolicy, h.cfg.BaseDomain, h.cfg.DNSConfig)
|
||||
node, err := h.db.TailNode(*machine, h.ACLPolicy, h.cfg.DNSConfig)
|
||||
if err != nil {
|
||||
log.Error().
|
||||
Caller().
|
||||
@ -27,7 +28,7 @@ func (h *Headscale) generateMapResponse(
|
||||
return nil, err
|
||||
}
|
||||
|
||||
peers, err := h.db.getValidPeers(h.aclPolicy, h.aclRules, machine)
|
||||
peers, err := h.db.GetValidPeers(h.aclRules, machine)
|
||||
if err != nil {
|
||||
log.Error().
|
||||
Caller().
|
||||
@ -38,9 +39,9 @@ func (h *Headscale) generateMapResponse(
|
||||
return nil, err
|
||||
}
|
||||
|
||||
profiles := h.db.getMapResponseUserProfiles(*machine, peers)
|
||||
profiles := h.db.GetMapResponseUserProfiles(*machine, peers)
|
||||
|
||||
nodePeers, err := h.db.toNodes(peers, h.aclPolicy, h.cfg.BaseDomain, h.cfg.DNSConfig)
|
||||
nodePeers, err := h.db.TailNodes(peers, h.ACLPolicy, h.cfg.DNSConfig)
|
||||
if err != nil {
|
||||
log.Error().
|
||||
Caller().
|
||||
|
166
hscontrol/app.go
166
hscontrol/app.go
@ -23,6 +23,9 @@ import (
|
||||
"github.com/grpc-ecosystem/grpc-gateway/v2/runtime"
|
||||
"github.com/juanfont/headscale"
|
||||
v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
|
||||
"github.com/juanfont/headscale/hscontrol/db"
|
||||
"github.com/juanfont/headscale/hscontrol/policy"
|
||||
"github.com/juanfont/headscale/hscontrol/types"
|
||||
"github.com/juanfont/headscale/hscontrol/util"
|
||||
"github.com/patrickmn/go-cache"
|
||||
zerolog "github.com/philip-bui/grpc-zerolog"
|
||||
@ -73,7 +76,7 @@ const (
|
||||
// Headscale represents the base app of the service.
|
||||
type Headscale struct {
|
||||
cfg *Config
|
||||
db *HSDatabase
|
||||
db *db.HSDatabase
|
||||
dbString string
|
||||
dbType string
|
||||
dbDebug bool
|
||||
@ -83,7 +86,7 @@ type Headscale struct {
|
||||
DERPMap *tailcfg.DERPMap
|
||||
DERPServer *DERPServer
|
||||
|
||||
aclPolicy *ACLPolicy
|
||||
ACLPolicy *policy.ACLPolicy
|
||||
aclRules []tailcfg.FilterRule
|
||||
sshPolicy *tailcfg.SSHPolicy
|
||||
|
||||
@ -99,6 +102,12 @@ type Headscale struct {
|
||||
|
||||
stateUpdateChan chan struct{}
|
||||
cancelStateUpdateChan chan struct{}
|
||||
|
||||
// TODO(kradalby): Temporary measure to make sure we can update policy
|
||||
// across modules, will be removed when aclRules are no longer stored
|
||||
// globally but generated per node basis.
|
||||
policyUpdateChan chan struct{}
|
||||
cancelPolicyUpdateChan chan struct{}
|
||||
}
|
||||
|
||||
func NewHeadscale(cfg *Config) (*Headscale, error) {
|
||||
@ -119,7 +128,7 @@ func NewHeadscale(cfg *Config) (*Headscale, error) {
|
||||
|
||||
var dbString string
|
||||
switch cfg.DBtype {
|
||||
case Postgres:
|
||||
case db.Postgres:
|
||||
dbString = fmt.Sprintf(
|
||||
"host=%s dbname=%s user=%s",
|
||||
cfg.DBhost,
|
||||
@ -142,7 +151,7 @@ func NewHeadscale(cfg *Config) (*Headscale, error) {
|
||||
if cfg.DBpass != "" {
|
||||
dbString += fmt.Sprintf(" password=%s", cfg.DBpass)
|
||||
}
|
||||
case Sqlite:
|
||||
case db.Sqlite:
|
||||
dbString = cfg.DBpath
|
||||
default:
|
||||
return nil, errUnsupportedDatabase
|
||||
@ -166,23 +175,28 @@ func NewHeadscale(cfg *Config) (*Headscale, error) {
|
||||
|
||||
stateUpdateChan: make(chan struct{}),
|
||||
cancelStateUpdateChan: make(chan struct{}),
|
||||
|
||||
policyUpdateChan: make(chan struct{}),
|
||||
cancelPolicyUpdateChan: make(chan struct{}),
|
||||
}
|
||||
|
||||
go app.watchStateChannel()
|
||||
go app.watchPolicyChannel()
|
||||
|
||||
db, err := NewHeadscaleDatabase(
|
||||
database, err := db.NewHeadscaleDatabase(
|
||||
cfg.DBtype,
|
||||
dbString,
|
||||
cfg.OIDC.StripEmaildomain,
|
||||
app.dbDebug,
|
||||
app.stateUpdateChan,
|
||||
app.policyUpdateChan,
|
||||
cfg.IPPrefixes,
|
||||
cfg.BaseDomain)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
app.db = db
|
||||
app.db = database
|
||||
|
||||
if cfg.OIDC.Issuer != "" {
|
||||
err = app.initOIDC()
|
||||
@ -228,7 +242,7 @@ func (h *Headscale) redirect(w http.ResponseWriter, req *http.Request) {
|
||||
func (h *Headscale) expireEphemeralNodes(milliSeconds int64) {
|
||||
ticker := time.NewTicker(time.Duration(milliSeconds) * time.Millisecond)
|
||||
for range ticker.C {
|
||||
h.expireEphemeralNodesWorker()
|
||||
h.db.ExpireEphemeralMachines(h.cfg.EphemeralNodeInactivityTimeout)
|
||||
}
|
||||
}
|
||||
|
||||
@ -237,112 +251,20 @@ func (h *Headscale) expireEphemeralNodes(milliSeconds int64) {
|
||||
func (h *Headscale) expireExpiredMachines(milliSeconds int64) {
|
||||
ticker := time.NewTicker(time.Duration(milliSeconds) * time.Millisecond)
|
||||
for range ticker.C {
|
||||
h.expireExpiredMachinesWorker()
|
||||
h.db.ExpireExpiredMachines(h.getLastStateChange())
|
||||
}
|
||||
}
|
||||
|
||||
func (h *Headscale) failoverSubnetRoutes(milliSeconds int64) {
|
||||
ticker := time.NewTicker(time.Duration(milliSeconds) * time.Millisecond)
|
||||
for range ticker.C {
|
||||
err := h.db.handlePrimarySubnetFailover()
|
||||
err := h.db.HandlePrimarySubnetFailover()
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("failed to handle primary subnet failover")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (h *Headscale) expireEphemeralNodesWorker() {
|
||||
users, err := h.db.ListUsers()
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("Error listing users")
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
for _, user := range users {
|
||||
machines, err := h.db.ListMachinesByUser(user.Name)
|
||||
if err != nil {
|
||||
log.Error().
|
||||
Err(err).
|
||||
Str("user", user.Name).
|
||||
Msg("Error listing machines in user")
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
expiredFound := false
|
||||
for _, machine := range machines {
|
||||
if machine.isEphemeral() && machine.LastSeen != nil &&
|
||||
time.Now().
|
||||
After(machine.LastSeen.Add(h.cfg.EphemeralNodeInactivityTimeout)) {
|
||||
expiredFound = true
|
||||
log.Info().
|
||||
Str("machine", machine.Hostname).
|
||||
Msg("Ephemeral client removed from database")
|
||||
|
||||
err = h.db.db.Unscoped().Delete(machine).Error
|
||||
if err != nil {
|
||||
log.Error().
|
||||
Err(err).
|
||||
Str("machine", machine.Hostname).
|
||||
Msg("🤮 Cannot delete ephemeral machine from the database")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if expiredFound {
|
||||
h.setLastStateChangeToNow()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (h *Headscale) expireExpiredMachinesWorker() {
|
||||
users, err := h.db.ListUsers()
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("Error listing users")
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
for _, user := range users {
|
||||
machines, err := h.db.ListMachinesByUser(user.Name)
|
||||
if err != nil {
|
||||
log.Error().
|
||||
Err(err).
|
||||
Str("user", user.Name).
|
||||
Msg("Error listing machines in user")
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
expiredFound := false
|
||||
for index, machine := range machines {
|
||||
if machine.isExpired() &&
|
||||
machine.Expiry.After(h.getLastStateChange(user)) {
|
||||
expiredFound = true
|
||||
|
||||
err := h.db.ExpireMachine(&machines[index])
|
||||
if err != nil {
|
||||
log.Error().
|
||||
Err(err).
|
||||
Str("machine", machine.Hostname).
|
||||
Str("name", machine.GivenName).
|
||||
Msg("🤮 Cannot expire machine")
|
||||
} else {
|
||||
log.Info().
|
||||
Str("machine", machine.Hostname).
|
||||
Str("name", machine.GivenName).
|
||||
Msg("Machine successfully expired")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if expiredFound {
|
||||
h.setLastStateChangeToNow()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (h *Headscale) grpcAuthenticationInterceptor(ctx context.Context,
|
||||
req interface{},
|
||||
info *grpc.UnaryServerInfo,
|
||||
@ -565,6 +487,8 @@ func (h *Headscale) Serve() error {
|
||||
go h.scheduledDERPMapUpdateWorker(derpMapCancelChannel)
|
||||
}
|
||||
|
||||
// TODO(kradalby): These should have cancel channels and be cleaned
|
||||
// up on shutdown.
|
||||
go h.expireEphemeralNodes(updateInterval)
|
||||
go h.expireExpiredMachines(updateInterval)
|
||||
|
||||
@ -774,10 +698,12 @@ func (h *Headscale) Serve() error {
|
||||
|
||||
if h.cfg.ACL.PolicyPath != "" {
|
||||
aclPath := util.AbsolutePathFromConfigPath(h.cfg.ACL.PolicyPath)
|
||||
err := h.LoadACLPolicyFromPath(aclPath)
|
||||
pol, err := policy.LoadACLPolicyFromPath(aclPath)
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("Failed to reload ACL policy")
|
||||
}
|
||||
|
||||
h.ACLPolicy = pol
|
||||
log.Info().
|
||||
Str("path", aclPath).
|
||||
Msg("ACL policy successfully reloaded, notifying nodes of change")
|
||||
@ -824,12 +750,12 @@ func (h *Headscale) Serve() error {
|
||||
close(h.stateUpdateChan)
|
||||
close(h.cancelStateUpdateChan)
|
||||
|
||||
<-h.cancelPolicyUpdateChan
|
||||
close(h.policyUpdateChan)
|
||||
close(h.cancelPolicyUpdateChan)
|
||||
|
||||
// Close db connections
|
||||
db, err := h.db.db.DB()
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("Failed to get db handle")
|
||||
}
|
||||
err = db.Close()
|
||||
err = h.db.Close()
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("Failed to close db")
|
||||
}
|
||||
@ -936,6 +862,30 @@ func (h *Headscale) watchStateChannel() {
|
||||
}
|
||||
}
|
||||
|
||||
// TODO(kradalby): baby steps, make this more robust.
|
||||
func (h *Headscale) watchPolicyChannel() {
|
||||
for {
|
||||
select {
|
||||
case <-h.policyUpdateChan:
|
||||
machines, err := h.db.ListMachines()
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("failed to fetch machines during policy update")
|
||||
}
|
||||
|
||||
rules, sshPolicy, err := policy.GenerateFilterRules(h.ACLPolicy, machines, h.cfg.OIDC.StripEmaildomain)
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("failed to update ACL rules")
|
||||
}
|
||||
|
||||
h.aclRules = rules
|
||||
h.sshPolicy = sshPolicy
|
||||
|
||||
case <-h.cancelPolicyUpdateChan:
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (h *Headscale) setLastStateChangeToNow() {
|
||||
var err error
|
||||
|
||||
@ -958,7 +908,7 @@ func (h *Headscale) setLastStateChangeToNow() {
|
||||
}
|
||||
}
|
||||
|
||||
func (h *Headscale) getLastStateChange(users ...User) time.Time {
|
||||
func (h *Headscale) getLastStateChange(users ...types.User) time.Time {
|
||||
times := []time.Time{}
|
||||
|
||||
// getLastStateChange takes a list of users as a "filter", if no users
|
||||
|
480
hscontrol/db/acls_test.go
Normal file
480
hscontrol/db/acls_test.go
Normal file
@ -0,0 +1,480 @@
|
||||
package db
|
||||
|
||||
import (
|
||||
"net/netip"
|
||||
|
||||
"github.com/juanfont/headscale/hscontrol/policy"
|
||||
"github.com/juanfont/headscale/hscontrol/types"
|
||||
"github.com/juanfont/headscale/hscontrol/util"
|
||||
"gopkg.in/check.v1"
|
||||
"tailscale.com/envknob"
|
||||
"tailscale.com/tailcfg"
|
||||
)
|
||||
|
||||
// TODO(kradalby):
|
||||
// Convert these tests to being non-database dependent and table driven. They are
|
||||
// very verbose, and dont really need the database.
|
||||
|
||||
func (s *Suite) TestSshRules(c *check.C) {
|
||||
envknob.Setenv("HEADSCALE_EXPERIMENTAL_FEATURE_SSH", "1")
|
||||
|
||||
user, err := db.CreateUser("user1")
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
_, err = db.GetMachine("user1", "testmachine")
|
||||
c.Assert(err, check.NotNil)
|
||||
hostInfo := tailcfg.Hostinfo{
|
||||
OS: "centos",
|
||||
Hostname: "testmachine",
|
||||
RequestTags: []string{"tag:test"},
|
||||
}
|
||||
|
||||
machine := types.Machine{
|
||||
ID: 0,
|
||||
MachineKey: "foo",
|
||||
NodeKey: "bar",
|
||||
DiscoKey: "faa",
|
||||
Hostname: "testmachine",
|
||||
IPAddresses: types.MachineAddresses{netip.MustParseAddr("100.64.0.1")},
|
||||
UserID: user.ID,
|
||||
RegisterMethod: util.RegisterMethodAuthKey,
|
||||
AuthKeyID: uint(pak.ID),
|
||||
HostInfo: types.HostInfo(hostInfo),
|
||||
}
|
||||
err = db.MachineSave(&machine)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
aclPolicy := &policy.ACLPolicy{
|
||||
Groups: policy.Groups{
|
||||
"group:test": []string{"user1"},
|
||||
},
|
||||
Hosts: policy.Hosts{
|
||||
"client": netip.PrefixFrom(netip.MustParseAddr("100.64.99.42"), 32),
|
||||
},
|
||||
ACLs: []policy.ACL{
|
||||
{
|
||||
Action: "accept",
|
||||
Sources: []string{"*"},
|
||||
Destinations: []string{"*:*"},
|
||||
},
|
||||
},
|
||||
SSHs: []policy.SSH{
|
||||
{
|
||||
Action: "accept",
|
||||
Sources: []string{"group:test"},
|
||||
Destinations: []string{"client"},
|
||||
Users: []string{"autogroup:nonroot"},
|
||||
},
|
||||
{
|
||||
Action: "accept",
|
||||
Sources: []string{"*"},
|
||||
Destinations: []string{"client"},
|
||||
Users: []string{"autogroup:nonroot"},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
_, sshPolicy, err := policy.GenerateFilterRules(aclPolicy, types.Machines{}, false)
|
||||
|
||||
c.Assert(err, check.IsNil)
|
||||
c.Assert(sshPolicy, check.NotNil)
|
||||
c.Assert(sshPolicy.Rules, check.HasLen, 2)
|
||||
c.Assert(sshPolicy.Rules[0].SSHUsers, check.HasLen, 1)
|
||||
c.Assert(sshPolicy.Rules[0].Principals, check.HasLen, 1)
|
||||
c.Assert(sshPolicy.Rules[0].Principals[0].UserLogin, check.Matches, "user1")
|
||||
|
||||
c.Assert(sshPolicy.Rules[1].SSHUsers, check.HasLen, 1)
|
||||
c.Assert(sshPolicy.Rules[1].Principals, check.HasLen, 1)
|
||||
c.Assert(sshPolicy.Rules[1].Principals[0].NodeIP, check.Matches, "*")
|
||||
}
|
||||
|
||||
// this test should validate that we can expand a group in a TagOWner section and
|
||||
// match properly the IP's of the related hosts. The owner is valid and the tag is also valid.
|
||||
// the tag is matched in the Sources section.
|
||||
func (s *Suite) TestValidExpandTagOwnersInSources(c *check.C) {
|
||||
user, err := db.CreateUser("user1")
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
_, err = db.GetMachine("user1", "testmachine")
|
||||
c.Assert(err, check.NotNil)
|
||||
hostInfo := tailcfg.Hostinfo{
|
||||
OS: "centos",
|
||||
Hostname: "testmachine",
|
||||
RequestTags: []string{"tag:test"},
|
||||
}
|
||||
|
||||
machine := types.Machine{
|
||||
ID: 0,
|
||||
MachineKey: "foo",
|
||||
NodeKey: "bar",
|
||||
DiscoKey: "faa",
|
||||
Hostname: "testmachine",
|
||||
IPAddresses: types.MachineAddresses{netip.MustParseAddr("100.64.0.1")},
|
||||
UserID: user.ID,
|
||||
RegisterMethod: util.RegisterMethodAuthKey,
|
||||
AuthKeyID: uint(pak.ID),
|
||||
HostInfo: types.HostInfo(hostInfo),
|
||||
}
|
||||
err = db.MachineSave(&machine)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
pol := &policy.ACLPolicy{
|
||||
Groups: policy.Groups{"group:test": []string{"user1", "user2"}},
|
||||
TagOwners: policy.TagOwners{"tag:test": []string{"user3", "group:test"}},
|
||||
ACLs: []policy.ACL{
|
||||
{
|
||||
Action: "accept",
|
||||
Sources: []string{"tag:test"},
|
||||
Destinations: []string{"*:*"},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
machines, err := db.ListMachines()
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
rules, _, err := policy.GenerateFilterRules(pol, machines, false)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
c.Assert(rules, check.HasLen, 1)
|
||||
c.Assert(rules[0].SrcIPs, check.HasLen, 1)
|
||||
c.Assert(rules[0].SrcIPs[0], check.Equals, "100.64.0.1/32")
|
||||
}
|
||||
|
||||
// this test should validate that we can expand a group in a TagOWner section and
|
||||
// match properly the IP's of the related hosts. The owner is valid and the tag is also valid.
|
||||
// the tag is matched in the Destinations section.
|
||||
func (s *Suite) TestValidExpandTagOwnersInDestinations(c *check.C) {
|
||||
user, err := db.CreateUser("user1")
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
_, err = db.GetMachine("user1", "testmachine")
|
||||
c.Assert(err, check.NotNil)
|
||||
hostInfo := tailcfg.Hostinfo{
|
||||
OS: "centos",
|
||||
Hostname: "testmachine",
|
||||
RequestTags: []string{"tag:test"},
|
||||
}
|
||||
|
||||
machine := types.Machine{
|
||||
ID: 1,
|
||||
MachineKey: "12345",
|
||||
NodeKey: "bar",
|
||||
DiscoKey: "faa",
|
||||
Hostname: "testmachine",
|
||||
IPAddresses: types.MachineAddresses{netip.MustParseAddr("100.64.0.1")},
|
||||
UserID: user.ID,
|
||||
RegisterMethod: util.RegisterMethodAuthKey,
|
||||
AuthKeyID: uint(pak.ID),
|
||||
HostInfo: types.HostInfo(hostInfo),
|
||||
}
|
||||
err = db.MachineSave(&machine)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
pol := &policy.ACLPolicy{
|
||||
Groups: policy.Groups{"group:test": []string{"user1", "user2"}},
|
||||
TagOwners: policy.TagOwners{"tag:test": []string{"user3", "group:test"}},
|
||||
ACLs: []policy.ACL{
|
||||
{
|
||||
Action: "accept",
|
||||
Sources: []string{"*"},
|
||||
Destinations: []string{"tag:test:*"},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
machines, err := db.ListMachines()
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
rules, _, err := policy.GenerateFilterRules(pol, machines, false)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
c.Assert(rules, check.HasLen, 1)
|
||||
c.Assert(rules[0].DstPorts, check.HasLen, 1)
|
||||
c.Assert(rules[0].DstPorts[0].IP, check.Equals, "100.64.0.1/32")
|
||||
}
|
||||
|
||||
// need a test with:
|
||||
// tag on a host that isn't owned by a tag owners. So the user
|
||||
// of the host should be valid.
|
||||
func (s *Suite) TestInvalidTagValidUser(c *check.C) {
|
||||
user, err := db.CreateUser("user1")
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
_, err = db.GetMachine("user1", "testmachine")
|
||||
c.Assert(err, check.NotNil)
|
||||
hostInfo := tailcfg.Hostinfo{
|
||||
OS: "centos",
|
||||
Hostname: "testmachine",
|
||||
RequestTags: []string{"tag:foo"},
|
||||
}
|
||||
|
||||
machine := types.Machine{
|
||||
ID: 1,
|
||||
MachineKey: "12345",
|
||||
NodeKey: "bar",
|
||||
DiscoKey: "faa",
|
||||
Hostname: "testmachine",
|
||||
IPAddresses: types.MachineAddresses{netip.MustParseAddr("100.64.0.1")},
|
||||
UserID: user.ID,
|
||||
RegisterMethod: util.RegisterMethodAuthKey,
|
||||
AuthKeyID: uint(pak.ID),
|
||||
HostInfo: types.HostInfo(hostInfo),
|
||||
}
|
||||
err = db.MachineSave(&machine)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
pol := &policy.ACLPolicy{
|
||||
TagOwners: policy.TagOwners{"tag:test": []string{"user1"}},
|
||||
ACLs: []policy.ACL{
|
||||
{
|
||||
Action: "accept",
|
||||
Sources: []string{"user1"},
|
||||
Destinations: []string{"*:*"},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
machines, err := db.ListMachines()
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
rules, _, err := policy.GenerateFilterRules(pol, machines, false)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
c.Assert(rules, check.HasLen, 1)
|
||||
c.Assert(rules[0].SrcIPs, check.HasLen, 1)
|
||||
c.Assert(rules[0].SrcIPs[0], check.Equals, "100.64.0.1/32")
|
||||
}
|
||||
|
||||
// tag on a host is owned by a tag owner, the tag is valid.
|
||||
// an ACL rule is matching the tag to a user. It should not be valid since the
|
||||
// host should be tied to the tag now.
|
||||
func (s *Suite) TestValidTagInvalidUser(c *check.C) {
|
||||
user, err := db.CreateUser("user1")
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
_, err = db.GetMachine("user1", "webserver")
|
||||
c.Assert(err, check.NotNil)
|
||||
hostInfo := tailcfg.Hostinfo{
|
||||
OS: "centos",
|
||||
Hostname: "webserver",
|
||||
RequestTags: []string{"tag:webapp"},
|
||||
}
|
||||
|
||||
machine := types.Machine{
|
||||
ID: 1,
|
||||
MachineKey: "12345",
|
||||
NodeKey: "bar",
|
||||
DiscoKey: "faa",
|
||||
Hostname: "webserver",
|
||||
IPAddresses: types.MachineAddresses{netip.MustParseAddr("100.64.0.1")},
|
||||
UserID: user.ID,
|
||||
RegisterMethod: util.RegisterMethodAuthKey,
|
||||
AuthKeyID: uint(pak.ID),
|
||||
HostInfo: types.HostInfo(hostInfo),
|
||||
}
|
||||
err = db.MachineSave(&machine)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
_, err = db.GetMachine("user1", "user")
|
||||
hostInfo2 := tailcfg.Hostinfo{
|
||||
OS: "debian",
|
||||
Hostname: "Hostname",
|
||||
}
|
||||
c.Assert(err, check.NotNil)
|
||||
machine = types.Machine{
|
||||
ID: 2,
|
||||
MachineKey: "56789",
|
||||
NodeKey: "bar2",
|
||||
DiscoKey: "faab",
|
||||
Hostname: "user",
|
||||
IPAddresses: types.MachineAddresses{netip.MustParseAddr("100.64.0.2")},
|
||||
UserID: user.ID,
|
||||
RegisterMethod: util.RegisterMethodAuthKey,
|
||||
AuthKeyID: uint(pak.ID),
|
||||
HostInfo: types.HostInfo(hostInfo2),
|
||||
}
|
||||
err = db.MachineSave(&machine)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
pol := &policy.ACLPolicy{
|
||||
TagOwners: policy.TagOwners{"tag:webapp": []string{"user1"}},
|
||||
ACLs: []policy.ACL{
|
||||
{
|
||||
Action: "accept",
|
||||
Sources: []string{"user1"},
|
||||
Destinations: []string{"tag:webapp:80,443"},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
machines, err := db.ListMachines()
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
rules, _, err := policy.GenerateFilterRules(pol, machines, false)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
c.Assert(rules, check.HasLen, 1)
|
||||
c.Assert(rules[0].SrcIPs, check.HasLen, 1)
|
||||
c.Assert(rules[0].SrcIPs[0], check.Equals, "100.64.0.2/32")
|
||||
c.Assert(rules[0].DstPorts, check.HasLen, 2)
|
||||
c.Assert(rules[0].DstPorts[0].Ports.First, check.Equals, uint16(80))
|
||||
c.Assert(rules[0].DstPorts[0].Ports.Last, check.Equals, uint16(80))
|
||||
c.Assert(rules[0].DstPorts[0].IP, check.Equals, "100.64.0.1/32")
|
||||
c.Assert(rules[0].DstPorts[1].Ports.First, check.Equals, uint16(443))
|
||||
c.Assert(rules[0].DstPorts[1].Ports.Last, check.Equals, uint16(443))
|
||||
c.Assert(rules[0].DstPorts[1].IP, check.Equals, "100.64.0.1/32")
|
||||
}
|
||||
|
||||
func (s *Suite) TestPortUser(c *check.C) {
|
||||
user, err := db.CreateUser("testuser")
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
_, err = db.GetMachine("testuser", "testmachine")
|
||||
c.Assert(err, check.NotNil)
|
||||
ips, _ := db.getAvailableIPs()
|
||||
machine := types.Machine{
|
||||
ID: 0,
|
||||
MachineKey: "12345",
|
||||
NodeKey: "bar",
|
||||
DiscoKey: "faa",
|
||||
Hostname: "testmachine",
|
||||
UserID: user.ID,
|
||||
RegisterMethod: util.RegisterMethodAuthKey,
|
||||
IPAddresses: ips,
|
||||
AuthKeyID: uint(pak.ID),
|
||||
}
|
||||
err = db.MachineSave(&machine)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
acl := []byte(`
|
||||
{
|
||||
"hosts": {
|
||||
"host-1": "100.100.100.100",
|
||||
"subnet-1": "100.100.101.100/24",
|
||||
},
|
||||
|
||||
"acls": [
|
||||
{
|
||||
"action": "accept",
|
||||
"src": [
|
||||
"testuser",
|
||||
],
|
||||
"dst": [
|
||||
"host-1:*",
|
||||
],
|
||||
},
|
||||
],
|
||||
}
|
||||
`)
|
||||
pol, err := policy.LoadACLPolicyFromBytes(acl, "hujson")
|
||||
c.Assert(err, check.IsNil)
|
||||
c.Assert(pol, check.NotNil)
|
||||
|
||||
machines, err := db.ListMachines()
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
rules, _, err := policy.GenerateFilterRules(pol, machines, false)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
c.Assert(err, check.IsNil)
|
||||
c.Assert(rules, check.NotNil)
|
||||
|
||||
c.Assert(rules, check.HasLen, 1)
|
||||
c.Assert(rules[0].DstPorts, check.HasLen, 1)
|
||||
c.Assert(rules[0].DstPorts[0].Ports.First, check.Equals, uint16(0))
|
||||
c.Assert(rules[0].DstPorts[0].Ports.Last, check.Equals, uint16(65535))
|
||||
c.Assert(rules[0].SrcIPs, check.HasLen, 1)
|
||||
c.Assert(rules[0].SrcIPs[0], check.Not(check.Equals), "not an ip")
|
||||
c.Assert(len(ips), check.Equals, 1)
|
||||
c.Assert(rules[0].SrcIPs[0], check.Equals, ips[0].String()+"/32")
|
||||
}
|
||||
|
||||
func (s *Suite) TestPortGroup(c *check.C) {
|
||||
user, err := db.CreateUser("testuser")
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
_, err = db.GetMachine("testuser", "testmachine")
|
||||
c.Assert(err, check.NotNil)
|
||||
ips, _ := db.getAvailableIPs()
|
||||
machine := types.Machine{
|
||||
ID: 0,
|
||||
MachineKey: "foo",
|
||||
NodeKey: "bar",
|
||||
DiscoKey: "faa",
|
||||
Hostname: "testmachine",
|
||||
UserID: user.ID,
|
||||
RegisterMethod: util.RegisterMethodAuthKey,
|
||||
IPAddresses: ips,
|
||||
AuthKeyID: uint(pak.ID),
|
||||
}
|
||||
err = db.MachineSave(&machine)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
acl := []byte(`
|
||||
{
|
||||
"groups": {
|
||||
"group:example": [
|
||||
"testuser",
|
||||
],
|
||||
},
|
||||
|
||||
"hosts": {
|
||||
"host-1": "100.100.100.100",
|
||||
"subnet-1": "100.100.101.100/24",
|
||||
},
|
||||
|
||||
"acls": [
|
||||
{
|
||||
"action": "accept",
|
||||
"src": [
|
||||
"group:example",
|
||||
],
|
||||
"dst": [
|
||||
"host-1:*",
|
||||
],
|
||||
},
|
||||
],
|
||||
}
|
||||
`)
|
||||
pol, err := policy.LoadACLPolicyFromBytes(acl, "hujson")
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
machines, err := db.ListMachines()
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
rules, _, err := policy.GenerateFilterRules(pol, machines, false)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
c.Assert(rules, check.NotNil)
|
||||
|
||||
c.Assert(rules, check.HasLen, 1)
|
||||
c.Assert(rules[0].DstPorts, check.HasLen, 1)
|
||||
c.Assert(rules[0].DstPorts[0].Ports.First, check.Equals, uint16(0))
|
||||
c.Assert(rules[0].DstPorts[0].Ports.Last, check.Equals, uint16(65535))
|
||||
c.Assert(rules[0].SrcIPs, check.HasLen, 1)
|
||||
c.Assert(rules[0].SrcIPs[0], check.Not(check.Equals), "not an ip")
|
||||
c.Assert(len(ips), check.Equals, 1)
|
||||
c.Assert(rules[0].SrcIPs[0], check.Equals, ips[0].String()+"/32")
|
||||
}
|
@ -3,21 +3,22 @@
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package hscontrol
|
||||
package db
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/netip"
|
||||
|
||||
"github.com/juanfont/headscale/hscontrol/types"
|
||||
"github.com/juanfont/headscale/hscontrol/util"
|
||||
"go4.org/netipx"
|
||||
)
|
||||
|
||||
var ErrCouldNotAllocateIP = errors.New("could not find any suitable IP")
|
||||
|
||||
func (hsdb *HSDatabase) getAvailableIPs() (MachineAddresses, error) {
|
||||
var ips MachineAddresses
|
||||
func (hsdb *HSDatabase) getAvailableIPs() (types.MachineAddresses, error) {
|
||||
var ips types.MachineAddresses
|
||||
var err error
|
||||
for _, ipPrefix := range hsdb.ipPrefixes {
|
||||
var ip *netip.Addr
|
||||
@ -68,11 +69,11 @@ func (hsdb *HSDatabase) getUsedIPs() (*netipx.IPSet, error) {
|
||||
// but this was quick to get running and it should be enough
|
||||
// to begin experimenting with a dual stack tailnet.
|
||||
var addressesSlices []string
|
||||
hsdb.db.Model(&Machine{}).Pluck("ip_addresses", &addressesSlices)
|
||||
hsdb.db.Model(&types.Machine{}).Pluck("ip_addresses", &addressesSlices)
|
||||
|
||||
var ips netipx.IPSetBuilder
|
||||
for _, slice := range addressesSlices {
|
||||
var machineAddresses MachineAddresses
|
||||
var machineAddresses types.MachineAddresses
|
||||
err := machineAddresses.Scan(slice)
|
||||
if err != nil {
|
||||
return &netipx.IPSet{}, fmt.Errorf(
|
@ -1,14 +1,16 @@
|
||||
package hscontrol
|
||||
package db
|
||||
|
||||
import (
|
||||
"net/netip"
|
||||
|
||||
"github.com/juanfont/headscale/hscontrol/types"
|
||||
"github.com/juanfont/headscale/hscontrol/util"
|
||||
"go4.org/netipx"
|
||||
"gopkg.in/check.v1"
|
||||
)
|
||||
|
||||
func (s *Suite) TestGetAvailableIp(c *check.C) {
|
||||
ips, err := app.db.getAvailableIPs()
|
||||
ips, err := db.getAvailableIPs()
|
||||
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
@ -19,32 +21,32 @@ func (s *Suite) TestGetAvailableIp(c *check.C) {
|
||||
}
|
||||
|
||||
func (s *Suite) TestGetUsedIps(c *check.C) {
|
||||
ips, err := app.db.getAvailableIPs()
|
||||
ips, err := db.getAvailableIPs()
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
user, err := app.db.CreateUser("test-ip")
|
||||
user, err := db.CreateUser("test-ip")
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
pak, err := app.db.CreatePreAuthKey(user.Name, false, false, nil, nil)
|
||||
pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
_, err = app.db.GetMachine("test", "testmachine")
|
||||
_, err = db.GetMachine("test", "testmachine")
|
||||
c.Assert(err, check.NotNil)
|
||||
|
||||
machine := Machine{
|
||||
machine := types.Machine{
|
||||
ID: 0,
|
||||
MachineKey: "foo",
|
||||
NodeKey: "bar",
|
||||
DiscoKey: "faa",
|
||||
Hostname: "testmachine",
|
||||
UserID: user.ID,
|
||||
RegisterMethod: RegisterMethodAuthKey,
|
||||
RegisterMethod: util.RegisterMethodAuthKey,
|
||||
AuthKeyID: uint(pak.ID),
|
||||
IPAddresses: ips,
|
||||
}
|
||||
app.db.db.Save(&machine)
|
||||
db.db.Save(&machine)
|
||||
|
||||
usedIps, err := app.db.getUsedIPs()
|
||||
usedIps, err := db.getUsedIPs()
|
||||
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
@ -56,46 +58,48 @@ func (s *Suite) TestGetUsedIps(c *check.C) {
|
||||
c.Assert(usedIps.Equal(expectedIPSet), check.Equals, true)
|
||||
c.Assert(usedIps.Contains(expected), check.Equals, true)
|
||||
|
||||
machine1, err := app.db.GetMachineByID(0)
|
||||
machine1, err := db.GetMachineByID(0)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
c.Assert(len(machine1.IPAddresses), check.Equals, 1)
|
||||
c.Assert(machine1.IPAddresses[0], check.Equals, expected)
|
||||
|
||||
c.Assert(channelUpdates, check.Equals, int32(0))
|
||||
}
|
||||
|
||||
func (s *Suite) TestGetMultiIp(c *check.C) {
|
||||
user, err := app.db.CreateUser("test-ip-multi")
|
||||
user, err := db.CreateUser("test-ip-multi")
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
for index := 1; index <= 350; index++ {
|
||||
app.db.ipAllocationMutex.Lock()
|
||||
db.ipAllocationMutex.Lock()
|
||||
|
||||
ips, err := app.db.getAvailableIPs()
|
||||
ips, err := db.getAvailableIPs()
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
pak, err := app.db.CreatePreAuthKey(user.Name, false, false, nil, nil)
|
||||
pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
_, err = app.db.GetMachine("test", "testmachine")
|
||||
_, err = db.GetMachine("test", "testmachine")
|
||||
c.Assert(err, check.NotNil)
|
||||
|
||||
machine := Machine{
|
||||
machine := types.Machine{
|
||||
ID: uint64(index),
|
||||
MachineKey: "foo",
|
||||
NodeKey: "bar",
|
||||
DiscoKey: "faa",
|
||||
Hostname: "testmachine",
|
||||
UserID: user.ID,
|
||||
RegisterMethod: RegisterMethodAuthKey,
|
||||
RegisterMethod: util.RegisterMethodAuthKey,
|
||||
AuthKeyID: uint(pak.ID),
|
||||
IPAddresses: ips,
|
||||
}
|
||||
app.db.db.Save(&machine)
|
||||
db.db.Save(&machine)
|
||||
|
||||
app.db.ipAllocationMutex.Unlock()
|
||||
db.ipAllocationMutex.Unlock()
|
||||
}
|
||||
|
||||
usedIps, err := app.db.getUsedIPs()
|
||||
usedIps, err := db.getUsedIPs()
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
expected0 := netip.MustParseAddr("10.27.0.1")
|
||||
@ -117,7 +121,7 @@ func (s *Suite) TestGetMultiIp(c *check.C) {
|
||||
c.Assert(usedIps.Contains(expected300), check.Equals, true)
|
||||
|
||||
// Check that we can read back the IPs
|
||||
machine1, err := app.db.GetMachineByID(1)
|
||||
machine1, err := db.GetMachineByID(1)
|
||||
c.Assert(err, check.IsNil)
|
||||
c.Assert(len(machine1.IPAddresses), check.Equals, 1)
|
||||
c.Assert(
|
||||
@ -126,7 +130,7 @@ func (s *Suite) TestGetMultiIp(c *check.C) {
|
||||
netip.MustParseAddr("10.27.0.1"),
|
||||
)
|
||||
|
||||
machine50, err := app.db.GetMachineByID(50)
|
||||
machine50, err := db.GetMachineByID(50)
|
||||
c.Assert(err, check.IsNil)
|
||||
c.Assert(len(machine50.IPAddresses), check.Equals, 1)
|
||||
c.Assert(
|
||||
@ -136,7 +140,7 @@ func (s *Suite) TestGetMultiIp(c *check.C) {
|
||||
)
|
||||
|
||||
expectedNextIP := netip.MustParseAddr("10.27.1.95")
|
||||
nextIP, err := app.db.getAvailableIPs()
|
||||
nextIP, err := db.getAvailableIPs()
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
c.Assert(len(nextIP), check.Equals, 1)
|
||||
@ -144,15 +148,17 @@ func (s *Suite) TestGetMultiIp(c *check.C) {
|
||||
|
||||
// If we call get Available again, we should receive
|
||||
// the same IP, as it has not been reserved.
|
||||
nextIP2, err := app.db.getAvailableIPs()
|
||||
nextIP2, err := db.getAvailableIPs()
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
c.Assert(len(nextIP2), check.Equals, 1)
|
||||
c.Assert(nextIP2[0].String(), check.Equals, expectedNextIP.String())
|
||||
|
||||
c.Assert(channelUpdates, check.Equals, int32(0))
|
||||
}
|
||||
|
||||
func (s *Suite) TestGetAvailableIpMachineWithoutIP(c *check.C) {
|
||||
ips, err := app.db.getAvailableIPs()
|
||||
ips, err := db.getAvailableIPs()
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
expected := netip.MustParseAddr("10.27.0.1")
|
||||
@ -160,30 +166,32 @@ func (s *Suite) TestGetAvailableIpMachineWithoutIP(c *check.C) {
|
||||
c.Assert(len(ips), check.Equals, 1)
|
||||
c.Assert(ips[0].String(), check.Equals, expected.String())
|
||||
|
||||
user, err := app.db.CreateUser("test-ip")
|
||||
user, err := db.CreateUser("test-ip")
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
pak, err := app.db.CreatePreAuthKey(user.Name, false, false, nil, nil)
|
||||
pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
_, err = app.db.GetMachine("test", "testmachine")
|
||||
_, err = db.GetMachine("test", "testmachine")
|
||||
c.Assert(err, check.NotNil)
|
||||
|
||||
machine := Machine{
|
||||
machine := types.Machine{
|
||||
ID: 0,
|
||||
MachineKey: "foo",
|
||||
NodeKey: "bar",
|
||||
DiscoKey: "faa",
|
||||
Hostname: "testmachine",
|
||||
UserID: user.ID,
|
||||
RegisterMethod: RegisterMethodAuthKey,
|
||||
RegisterMethod: util.RegisterMethodAuthKey,
|
||||
AuthKeyID: uint(pak.ID),
|
||||
}
|
||||
app.db.db.Save(&machine)
|
||||
db.db.Save(&machine)
|
||||
|
||||
ips2, err := app.db.getAvailableIPs()
|
||||
ips2, err := db.getAvailableIPs()
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
c.Assert(len(ips2), check.Equals, 1)
|
||||
c.Assert(ips2[0].String(), check.Equals, expected.String())
|
||||
|
||||
c.Assert(channelUpdates, check.Equals, int32(0))
|
||||
}
|
@ -1,4 +1,4 @@
|
||||
package hscontrol
|
||||
package db
|
||||
|
||||
import (
|
||||
"errors"
|
||||
@ -6,10 +6,9 @@ import (
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
|
||||
"github.com/juanfont/headscale/hscontrol/types"
|
||||
"github.com/juanfont/headscale/hscontrol/util"
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
"google.golang.org/protobuf/types/known/timestamppb"
|
||||
)
|
||||
|
||||
const (
|
||||
@ -19,22 +18,10 @@ const (
|
||||
|
||||
var ErrAPIKeyFailedToParse = errors.New("failed to parse ApiKey")
|
||||
|
||||
// APIKey describes the datamodel for API keys used to remotely authenticate with
|
||||
// headscale.
|
||||
type APIKey struct {
|
||||
ID uint64 `gorm:"primary_key"`
|
||||
Prefix string `gorm:"uniqueIndex"`
|
||||
Hash []byte
|
||||
|
||||
CreatedAt *time.Time
|
||||
Expiration *time.Time
|
||||
LastSeen *time.Time
|
||||
}
|
||||
|
||||
// CreateAPIKey creates a new ApiKey in a user, and returns it.
|
||||
func (hsdb *HSDatabase) CreateAPIKey(
|
||||
expiration *time.Time,
|
||||
) (string, *APIKey, error) {
|
||||
) (string, *types.APIKey, error) {
|
||||
prefix, err := util.GenerateRandomStringURLSafe(apiPrefixLength)
|
||||
if err != nil {
|
||||
return "", nil, err
|
||||
@ -53,7 +40,7 @@ func (hsdb *HSDatabase) CreateAPIKey(
|
||||
return "", nil, err
|
||||
}
|
||||
|
||||
key := APIKey{
|
||||
key := types.APIKey{
|
||||
Prefix: prefix,
|
||||
Hash: hash,
|
||||
Expiration: expiration,
|
||||
@ -67,8 +54,8 @@ func (hsdb *HSDatabase) CreateAPIKey(
|
||||
}
|
||||
|
||||
// ListAPIKeys returns the list of ApiKeys for a user.
|
||||
func (hsdb *HSDatabase) ListAPIKeys() ([]APIKey, error) {
|
||||
keys := []APIKey{}
|
||||
func (hsdb *HSDatabase) ListAPIKeys() ([]types.APIKey, error) {
|
||||
keys := []types.APIKey{}
|
||||
if err := hsdb.db.Find(&keys).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -77,8 +64,8 @@ func (hsdb *HSDatabase) ListAPIKeys() ([]APIKey, error) {
|
||||
}
|
||||
|
||||
// GetAPIKey returns a ApiKey for a given key.
|
||||
func (hsdb *HSDatabase) GetAPIKey(prefix string) (*APIKey, error) {
|
||||
key := APIKey{}
|
||||
func (hsdb *HSDatabase) GetAPIKey(prefix string) (*types.APIKey, error) {
|
||||
key := types.APIKey{}
|
||||
if result := hsdb.db.First(&key, "prefix = ?", prefix); result.Error != nil {
|
||||
return nil, result.Error
|
||||
}
|
||||
@ -87,9 +74,9 @@ func (hsdb *HSDatabase) GetAPIKey(prefix string) (*APIKey, error) {
|
||||
}
|
||||
|
||||
// GetAPIKeyByID returns a ApiKey for a given id.
|
||||
func (hsdb *HSDatabase) GetAPIKeyByID(id uint64) (*APIKey, error) {
|
||||
key := APIKey{}
|
||||
if result := hsdb.db.Find(&APIKey{ID: id}).First(&key); result.Error != nil {
|
||||
func (hsdb *HSDatabase) GetAPIKeyByID(id uint64) (*types.APIKey, error) {
|
||||
key := types.APIKey{}
|
||||
if result := hsdb.db.Find(&types.APIKey{ID: id}).First(&key); result.Error != nil {
|
||||
return nil, result.Error
|
||||
}
|
||||
|
||||
@ -98,7 +85,7 @@ func (hsdb *HSDatabase) GetAPIKeyByID(id uint64) (*APIKey, error) {
|
||||
|
||||
// DestroyAPIKey destroys a ApiKey. Returns error if the ApiKey
|
||||
// does not exist.
|
||||
func (hsdb *HSDatabase) DestroyAPIKey(key APIKey) error {
|
||||
func (hsdb *HSDatabase) DestroyAPIKey(key types.APIKey) error {
|
||||
if result := hsdb.db.Unscoped().Delete(key); result.Error != nil {
|
||||
return result.Error
|
||||
}
|
||||
@ -107,7 +94,7 @@ func (hsdb *HSDatabase) DestroyAPIKey(key APIKey) error {
|
||||
}
|
||||
|
||||
// ExpireAPIKey marks a ApiKey as expired.
|
||||
func (hsdb *HSDatabase) ExpireAPIKey(key *APIKey) error {
|
||||
func (hsdb *HSDatabase) ExpireAPIKey(key *types.APIKey) error {
|
||||
if err := hsdb.db.Model(&key).Update("Expiration", time.Now()).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
@ -136,24 +123,3 @@ func (hsdb *HSDatabase) ValidateAPIKey(keyStr string) (bool, error) {
|
||||
|
||||
return true, nil
|
||||
}
|
||||
|
||||
func (key *APIKey) toProto() *v1.ApiKey {
|
||||
protoKey := v1.ApiKey{
|
||||
Id: key.ID,
|
||||
Prefix: key.Prefix,
|
||||
}
|
||||
|
||||
if key.Expiration != nil {
|
||||
protoKey.Expiration = timestamppb.New(*key.Expiration)
|
||||
}
|
||||
|
||||
if key.CreatedAt != nil {
|
||||
protoKey.CreatedAt = timestamppb.New(*key.CreatedAt)
|
||||
}
|
||||
|
||||
if key.LastSeen != nil {
|
||||
protoKey.LastSeen = timestamppb.New(*key.LastSeen)
|
||||
}
|
||||
|
||||
return &protoKey
|
||||
}
|
@ -1,4 +1,4 @@
|
||||
package hscontrol
|
||||
package db
|
||||
|
||||
import (
|
||||
"time"
|
||||
@ -7,7 +7,7 @@ import (
|
||||
)
|
||||
|
||||
func (*Suite) TestCreateAPIKey(c *check.C) {
|
||||
apiKeyStr, apiKey, err := app.db.CreateAPIKey(nil)
|
||||
apiKeyStr, apiKey, err := db.CreateAPIKey(nil)
|
||||
c.Assert(err, check.IsNil)
|
||||
c.Assert(apiKey, check.NotNil)
|
||||
|
||||
@ -16,74 +16,82 @@ func (*Suite) TestCreateAPIKey(c *check.C) {
|
||||
c.Assert(apiKey.Hash, check.NotNil)
|
||||
c.Assert(apiKeyStr, check.Not(check.Equals), "")
|
||||
|
||||
_, err = app.db.ListAPIKeys()
|
||||
_, err = db.ListAPIKeys()
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
keys, err := app.db.ListAPIKeys()
|
||||
keys, err := db.ListAPIKeys()
|
||||
c.Assert(err, check.IsNil)
|
||||
c.Assert(len(keys), check.Equals, 1)
|
||||
|
||||
c.Assert(channelUpdates, check.Equals, int32(0))
|
||||
}
|
||||
|
||||
func (*Suite) TestAPIKeyDoesNotExist(c *check.C) {
|
||||
key, err := app.db.GetAPIKey("does-not-exist")
|
||||
key, err := db.GetAPIKey("does-not-exist")
|
||||
c.Assert(err, check.NotNil)
|
||||
c.Assert(key, check.IsNil)
|
||||
}
|
||||
|
||||
func (*Suite) TestValidateAPIKeyOk(c *check.C) {
|
||||
nowPlus2 := time.Now().Add(2 * time.Hour)
|
||||
apiKeyStr, apiKey, err := app.db.CreateAPIKey(&nowPlus2)
|
||||
apiKeyStr, apiKey, err := db.CreateAPIKey(&nowPlus2)
|
||||
c.Assert(err, check.IsNil)
|
||||
c.Assert(apiKey, check.NotNil)
|
||||
|
||||
valid, err := app.db.ValidateAPIKey(apiKeyStr)
|
||||
valid, err := db.ValidateAPIKey(apiKeyStr)
|
||||
c.Assert(err, check.IsNil)
|
||||
c.Assert(valid, check.Equals, true)
|
||||
|
||||
c.Assert(channelUpdates, check.Equals, int32(0))
|
||||
}
|
||||
|
||||
func (*Suite) TestValidateAPIKeyNotOk(c *check.C) {
|
||||
nowMinus2 := time.Now().Add(time.Duration(-2) * time.Hour)
|
||||
apiKeyStr, apiKey, err := app.db.CreateAPIKey(&nowMinus2)
|
||||
apiKeyStr, apiKey, err := db.CreateAPIKey(&nowMinus2)
|
||||
c.Assert(err, check.IsNil)
|
||||
c.Assert(apiKey, check.NotNil)
|
||||
|
||||
valid, err := app.db.ValidateAPIKey(apiKeyStr)
|
||||
valid, err := db.ValidateAPIKey(apiKeyStr)
|
||||
c.Assert(err, check.IsNil)
|
||||
c.Assert(valid, check.Equals, false)
|
||||
|
||||
now := time.Now()
|
||||
apiKeyStrNow, apiKey, err := app.db.CreateAPIKey(&now)
|
||||
apiKeyStrNow, apiKey, err := db.CreateAPIKey(&now)
|
||||
c.Assert(err, check.IsNil)
|
||||
c.Assert(apiKey, check.NotNil)
|
||||
|
||||
validNow, err := app.db.ValidateAPIKey(apiKeyStrNow)
|
||||
validNow, err := db.ValidateAPIKey(apiKeyStrNow)
|
||||
c.Assert(err, check.IsNil)
|
||||
c.Assert(validNow, check.Equals, false)
|
||||
|
||||
validSilly, err := app.db.ValidateAPIKey("nota.validkey")
|
||||
validSilly, err := db.ValidateAPIKey("nota.validkey")
|
||||
c.Assert(err, check.NotNil)
|
||||
c.Assert(validSilly, check.Equals, false)
|
||||
|
||||
validWithErr, err := app.db.ValidateAPIKey("produceerrorkey")
|
||||
validWithErr, err := db.ValidateAPIKey("produceerrorkey")
|
||||
c.Assert(err, check.NotNil)
|
||||
c.Assert(validWithErr, check.Equals, false)
|
||||
|
||||
c.Assert(channelUpdates, check.Equals, int32(0))
|
||||
}
|
||||
|
||||
func (*Suite) TestExpireAPIKey(c *check.C) {
|
||||
nowPlus2 := time.Now().Add(2 * time.Hour)
|
||||
apiKeyStr, apiKey, err := app.db.CreateAPIKey(&nowPlus2)
|
||||
apiKeyStr, apiKey, err := db.CreateAPIKey(&nowPlus2)
|
||||
c.Assert(err, check.IsNil)
|
||||
c.Assert(apiKey, check.NotNil)
|
||||
|
||||
valid, err := app.db.ValidateAPIKey(apiKeyStr)
|
||||
valid, err := db.ValidateAPIKey(apiKeyStr)
|
||||
c.Assert(err, check.IsNil)
|
||||
c.Assert(valid, check.Equals, true)
|
||||
|
||||
err = app.db.ExpireAPIKey(apiKey)
|
||||
err = db.ExpireAPIKey(apiKey)
|
||||
c.Assert(err, check.IsNil)
|
||||
c.Assert(apiKey.Expiration, check.NotNil)
|
||||
|
||||
notValid, err := app.db.ValidateAPIKey(apiKeyStr)
|
||||
notValid, err := db.ValidateAPIKey(apiKeyStr)
|
||||
c.Assert(err, check.IsNil)
|
||||
c.Assert(notValid, check.Equals, false)
|
||||
|
||||
c.Assert(channelUpdates, check.Equals, int32(0))
|
||||
}
|
@ -1,9 +1,7 @@
|
||||
package hscontrol
|
||||
package db
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql/driver"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/netip"
|
||||
@ -11,11 +9,12 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/glebarez/sqlite"
|
||||
"github.com/juanfont/headscale/hscontrol/types"
|
||||
"github.com/juanfont/headscale/hscontrol/util"
|
||||
"github.com/rs/zerolog/log"
|
||||
"gorm.io/driver/postgres"
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/logger"
|
||||
"tailscale.com/tailcfg"
|
||||
)
|
||||
|
||||
const (
|
||||
@ -26,7 +25,6 @@ const (
|
||||
|
||||
var (
|
||||
errValueNotFound = errors.New("not found")
|
||||
ErrCannotParsePrefix = errors.New("cannot parse prefix")
|
||||
errDatabaseNotSupported = errors.New("database type not supported")
|
||||
)
|
||||
|
||||
@ -38,8 +36,9 @@ type KV struct {
|
||||
}
|
||||
|
||||
type HSDatabase struct {
|
||||
db *gorm.DB
|
||||
notifyStateChan chan<- struct{}
|
||||
db *gorm.DB
|
||||
notifyStateChan chan<- struct{}
|
||||
notifyPolicyChan chan<- struct{}
|
||||
|
||||
ipAllocationMutex sync.Mutex
|
||||
|
||||
@ -54,6 +53,7 @@ func NewHeadscaleDatabase(
|
||||
dbType, connectionAddr string,
|
||||
stripEmailDomain, debug bool,
|
||||
notifyStateChan chan<- struct{},
|
||||
notifyPolicyChan chan<- struct{},
|
||||
ipPrefixes []netip.Prefix,
|
||||
baseDomain string,
|
||||
) (*HSDatabase, error) {
|
||||
@ -63,8 +63,9 @@ func NewHeadscaleDatabase(
|
||||
}
|
||||
|
||||
db := HSDatabase{
|
||||
db: dbConn,
|
||||
notifyStateChan: notifyStateChan,
|
||||
db: dbConn,
|
||||
notifyStateChan: notifyStateChan,
|
||||
notifyPolicyChan: notifyPolicyChan,
|
||||
|
||||
ipPrefixes: ipPrefixes,
|
||||
baseDomain: baseDomain,
|
||||
@ -79,30 +80,30 @@ func NewHeadscaleDatabase(
|
||||
|
||||
_ = dbConn.Migrator().RenameTable("namespaces", "users")
|
||||
|
||||
err = dbConn.AutoMigrate(User{})
|
||||
err = dbConn.AutoMigrate(types.User{})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
_ = dbConn.Migrator().RenameColumn(&Machine{}, "namespace_id", "user_id")
|
||||
_ = dbConn.Migrator().RenameColumn(&PreAuthKey{}, "namespace_id", "user_id")
|
||||
_ = dbConn.Migrator().RenameColumn(&types.Machine{}, "namespace_id", "user_id")
|
||||
_ = dbConn.Migrator().RenameColumn(&types.PreAuthKey{}, "namespace_id", "user_id")
|
||||
|
||||
_ = dbConn.Migrator().RenameColumn(&Machine{}, "ip_address", "ip_addresses")
|
||||
_ = dbConn.Migrator().RenameColumn(&Machine{}, "name", "hostname")
|
||||
_ = dbConn.Migrator().RenameColumn(&types.Machine{}, "ip_address", "ip_addresses")
|
||||
_ = dbConn.Migrator().RenameColumn(&types.Machine{}, "name", "hostname")
|
||||
|
||||
// GivenName is used as the primary source of DNS names, make sure
|
||||
// the field is populated and normalized if it was not when the
|
||||
// machine was registered.
|
||||
_ = dbConn.Migrator().RenameColumn(&Machine{}, "nickname", "given_name")
|
||||
_ = dbConn.Migrator().RenameColumn(&types.Machine{}, "nickname", "given_name")
|
||||
|
||||
// If the Machine table has a column for registered,
|
||||
// find all occourences of "false" and drop them. Then
|
||||
// remove the column.
|
||||
if dbConn.Migrator().HasColumn(&Machine{}, "registered") {
|
||||
if dbConn.Migrator().HasColumn(&types.Machine{}, "registered") {
|
||||
log.Info().
|
||||
Msg(`Database has legacy "registered" column in machine, removing...`)
|
||||
|
||||
machines := Machines{}
|
||||
machines := types.Machines{}
|
||||
if err := dbConn.Not("registered").Find(&machines).Error; err != nil {
|
||||
log.Error().Err(err).Msg("Error accessing db")
|
||||
}
|
||||
@ -112,7 +113,7 @@ func NewHeadscaleDatabase(
|
||||
Str("machine", machine.Hostname).
|
||||
Str("machine_key", machine.MachineKey).
|
||||
Msg("Deleting unregistered machine")
|
||||
if err := dbConn.Delete(&Machine{}, machine.ID).Error; err != nil {
|
||||
if err := dbConn.Delete(&types.Machine{}, machine.ID).Error; err != nil {
|
||||
log.Error().
|
||||
Err(err).
|
||||
Str("machine", machine.Hostname).
|
||||
@ -121,23 +122,23 @@ func NewHeadscaleDatabase(
|
||||
}
|
||||
}
|
||||
|
||||
err := dbConn.Migrator().DropColumn(&Machine{}, "registered")
|
||||
err := dbConn.Migrator().DropColumn(&types.Machine{}, "registered")
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("Error dropping registered column")
|
||||
}
|
||||
}
|
||||
|
||||
err = dbConn.AutoMigrate(&Route{})
|
||||
err = dbConn.AutoMigrate(&types.Route{})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if dbConn.Migrator().HasColumn(&Machine{}, "enabled_routes") {
|
||||
if dbConn.Migrator().HasColumn(&types.Machine{}, "enabled_routes") {
|
||||
log.Info().Msgf("Database has legacy enabled_routes column in machine, migrating...")
|
||||
|
||||
type MachineAux struct {
|
||||
ID uint64
|
||||
EnabledRoutes IPPrefixes
|
||||
EnabledRoutes types.IPPrefixes
|
||||
}
|
||||
|
||||
machinesAux := []MachineAux{}
|
||||
@ -157,8 +158,8 @@ func NewHeadscaleDatabase(
|
||||
}
|
||||
|
||||
err = dbConn.Preload("Machine").
|
||||
Where("machine_id = ? AND prefix = ?", machine.ID, IPPrefix(prefix)).
|
||||
First(&Route{}).
|
||||
Where("machine_id = ? AND prefix = ?", machine.ID, types.IPPrefix(prefix)).
|
||||
First(&types.Route{}).
|
||||
Error
|
||||
if err == nil {
|
||||
log.Info().
|
||||
@ -168,11 +169,11 @@ func NewHeadscaleDatabase(
|
||||
continue
|
||||
}
|
||||
|
||||
route := Route{
|
||||
route := types.Route{
|
||||
MachineID: machine.ID,
|
||||
Advertised: true,
|
||||
Enabled: true,
|
||||
Prefix: IPPrefix(prefix),
|
||||
Prefix: types.IPPrefix(prefix),
|
||||
}
|
||||
if err := dbConn.Create(&route).Error; err != nil {
|
||||
log.Error().Err(err).Msg("Error creating route")
|
||||
@ -185,26 +186,26 @@ func NewHeadscaleDatabase(
|
||||
}
|
||||
}
|
||||
|
||||
err = dbConn.Migrator().DropColumn(&Machine{}, "enabled_routes")
|
||||
err = dbConn.Migrator().DropColumn(&types.Machine{}, "enabled_routes")
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("Error dropping enabled_routes column")
|
||||
}
|
||||
}
|
||||
|
||||
err = dbConn.AutoMigrate(&Machine{})
|
||||
err = dbConn.AutoMigrate(&types.Machine{})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if dbConn.Migrator().HasColumn(&Machine{}, "given_name") {
|
||||
machines := Machines{}
|
||||
if dbConn.Migrator().HasColumn(&types.Machine{}, "given_name") {
|
||||
machines := types.Machines{}
|
||||
if err := dbConn.Find(&machines).Error; err != nil {
|
||||
log.Error().Err(err).Msg("Error accessing db")
|
||||
}
|
||||
|
||||
for item, machine := range machines {
|
||||
if machine.GivenName == "" {
|
||||
normalizedHostname, err := NormalizeToFQDNRules(
|
||||
normalizedHostname, err := util.NormalizeToFQDNRules(
|
||||
machine.Hostname,
|
||||
stripEmailDomain,
|
||||
)
|
||||
@ -233,19 +234,19 @@ func NewHeadscaleDatabase(
|
||||
return nil, err
|
||||
}
|
||||
|
||||
err = dbConn.AutoMigrate(&PreAuthKey{})
|
||||
err = dbConn.AutoMigrate(&types.PreAuthKey{})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
err = dbConn.AutoMigrate(&PreAuthKeyACLTag{})
|
||||
err = dbConn.AutoMigrate(&types.PreAuthKeyACLTag{})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
_ = dbConn.Migrator().DropTable("shared_machines")
|
||||
|
||||
err = dbConn.AutoMigrate(&APIKey{})
|
||||
err = dbConn.AutoMigrate(&types.APIKey{})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -339,7 +340,7 @@ func (hsdb *HSDatabase) setValue(key string, value string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (hsdb *HSDatabase) pingDB(ctx context.Context) error {
|
||||
func (hsdb *HSDatabase) PingDB(ctx context.Context) error {
|
||||
ctx, cancel := context.WithTimeout(ctx, time.Second)
|
||||
defer cancel()
|
||||
sqlDB, err := hsdb.db.DB()
|
||||
@ -350,97 +351,11 @@ func (hsdb *HSDatabase) pingDB(ctx context.Context) error {
|
||||
return sqlDB.PingContext(ctx)
|
||||
}
|
||||
|
||||
// This is a "wrapper" type around tailscales
|
||||
// Hostinfo to allow us to add database "serialization"
|
||||
// methods. This allows us to use a typed values throughout
|
||||
// the code and not have to marshal/unmarshal and error
|
||||
// check all over the code.
|
||||
type HostInfo tailcfg.Hostinfo
|
||||
|
||||
func (hi *HostInfo) Scan(destination interface{}) error {
|
||||
switch value := destination.(type) {
|
||||
case []byte:
|
||||
return json.Unmarshal(value, hi)
|
||||
|
||||
case string:
|
||||
return json.Unmarshal([]byte(value), hi)
|
||||
|
||||
default:
|
||||
return fmt.Errorf("%w: unexpected data type %T", ErrMachineAddressesInvalid, destination)
|
||||
func (hsdb *HSDatabase) Close() error {
|
||||
db, err := hsdb.db.DB()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// Value return json value, implement driver.Valuer interface.
|
||||
func (hi HostInfo) Value() (driver.Value, error) {
|
||||
bytes, err := json.Marshal(hi)
|
||||
|
||||
return string(bytes), err
|
||||
}
|
||||
|
||||
type IPPrefix netip.Prefix
|
||||
|
||||
func (i *IPPrefix) Scan(destination interface{}) error {
|
||||
switch value := destination.(type) {
|
||||
case string:
|
||||
prefix, err := netip.ParsePrefix(value)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
*i = IPPrefix(prefix)
|
||||
|
||||
return nil
|
||||
default:
|
||||
return fmt.Errorf("%w: unexpected data type %T", ErrCannotParsePrefix, destination)
|
||||
}
|
||||
}
|
||||
|
||||
// Value return json value, implement driver.Valuer interface.
|
||||
func (i IPPrefix) Value() (driver.Value, error) {
|
||||
prefixStr := netip.Prefix(i).String()
|
||||
|
||||
return prefixStr, nil
|
||||
}
|
||||
|
||||
type IPPrefixes []netip.Prefix
|
||||
|
||||
func (i *IPPrefixes) Scan(destination interface{}) error {
|
||||
switch value := destination.(type) {
|
||||
case []byte:
|
||||
return json.Unmarshal(value, i)
|
||||
|
||||
case string:
|
||||
return json.Unmarshal([]byte(value), i)
|
||||
|
||||
default:
|
||||
return fmt.Errorf("%w: unexpected data type %T", ErrMachineAddressesInvalid, destination)
|
||||
}
|
||||
}
|
||||
|
||||
// Value return json value, implement driver.Valuer interface.
|
||||
func (i IPPrefixes) Value() (driver.Value, error) {
|
||||
bytes, err := json.Marshal(i)
|
||||
|
||||
return string(bytes), err
|
||||
}
|
||||
|
||||
type StringList []string
|
||||
|
||||
func (i *StringList) Scan(destination interface{}) error {
|
||||
switch value := destination.(type) {
|
||||
case []byte:
|
||||
return json.Unmarshal(value, i)
|
||||
|
||||
case string:
|
||||
return json.Unmarshal([]byte(value), i)
|
||||
|
||||
default:
|
||||
return fmt.Errorf("%w: unexpected data type %T", ErrMachineAddressesInvalid, destination)
|
||||
}
|
||||
}
|
||||
|
||||
// Value return json value, implement driver.Valuer interface.
|
||||
func (i StringList) Value() (driver.Value, error) {
|
||||
bytes, err := json.Marshal(i)
|
||||
|
||||
return string(bytes), err
|
||||
|
||||
return db.Close()
|
||||
}
|
File diff suppressed because it is too large
Load Diff
797
hscontrol/db/machine_test.go
Normal file
797
hscontrol/db/machine_test.go
Normal file
@ -0,0 +1,797 @@
|
||||
package db
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/juanfont/headscale/hscontrol/policy"
|
||||
"github.com/juanfont/headscale/hscontrol/types"
|
||||
"github.com/juanfont/headscale/hscontrol/util"
|
||||
"gopkg.in/check.v1"
|
||||
"tailscale.com/tailcfg"
|
||||
"tailscale.com/types/key"
|
||||
)
|
||||
|
||||
func (s *Suite) TestGetMachine(c *check.C) {
|
||||
user, err := db.CreateUser("test")
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
_, err = db.GetMachine("test", "testmachine")
|
||||
c.Assert(err, check.NotNil)
|
||||
|
||||
machine := &types.Machine{
|
||||
ID: 0,
|
||||
MachineKey: "foo",
|
||||
NodeKey: "bar",
|
||||
DiscoKey: "faa",
|
||||
Hostname: "testmachine",
|
||||
UserID: user.ID,
|
||||
RegisterMethod: util.RegisterMethodAuthKey,
|
||||
AuthKeyID: uint(pak.ID),
|
||||
}
|
||||
db.db.Save(machine)
|
||||
|
||||
_, err = db.GetMachine("test", "testmachine")
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
c.Assert(channelUpdates, check.Equals, int32(0))
|
||||
}
|
||||
|
||||
func (s *Suite) TestGetMachineByID(c *check.C) {
|
||||
user, err := db.CreateUser("test")
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
_, err = db.GetMachineByID(0)
|
||||
c.Assert(err, check.NotNil)
|
||||
|
||||
machine := types.Machine{
|
||||
ID: 0,
|
||||
MachineKey: "foo",
|
||||
NodeKey: "bar",
|
||||
DiscoKey: "faa",
|
||||
Hostname: "testmachine",
|
||||
UserID: user.ID,
|
||||
RegisterMethod: util.RegisterMethodAuthKey,
|
||||
AuthKeyID: uint(pak.ID),
|
||||
}
|
||||
db.db.Save(&machine)
|
||||
|
||||
_, err = db.GetMachineByID(0)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
c.Assert(channelUpdates, check.Equals, int32(0))
|
||||
}
|
||||
|
||||
func (s *Suite) TestGetMachineByNodeKey(c *check.C) {
|
||||
user, err := db.CreateUser("test")
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
_, err = db.GetMachineByID(0)
|
||||
c.Assert(err, check.NotNil)
|
||||
|
||||
nodeKey := key.NewNode()
|
||||
machineKey := key.NewMachine()
|
||||
|
||||
machine := types.Machine{
|
||||
ID: 0,
|
||||
MachineKey: util.MachinePublicKeyStripPrefix(machineKey.Public()),
|
||||
NodeKey: util.NodePublicKeyStripPrefix(nodeKey.Public()),
|
||||
DiscoKey: "faa",
|
||||
Hostname: "testmachine",
|
||||
UserID: user.ID,
|
||||
RegisterMethod: util.RegisterMethodAuthKey,
|
||||
AuthKeyID: uint(pak.ID),
|
||||
}
|
||||
db.db.Save(&machine)
|
||||
|
||||
_, err = db.GetMachineByNodeKey(nodeKey.Public())
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
c.Assert(channelUpdates, check.Equals, int32(0))
|
||||
}
|
||||
|
||||
func (s *Suite) TestGetMachineByAnyNodeKey(c *check.C) {
|
||||
user, err := db.CreateUser("test")
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
_, err = db.GetMachineByID(0)
|
||||
c.Assert(err, check.NotNil)
|
||||
|
||||
nodeKey := key.NewNode()
|
||||
oldNodeKey := key.NewNode()
|
||||
|
||||
machineKey := key.NewMachine()
|
||||
|
||||
machine := types.Machine{
|
||||
ID: 0,
|
||||
MachineKey: util.MachinePublicKeyStripPrefix(machineKey.Public()),
|
||||
NodeKey: util.NodePublicKeyStripPrefix(nodeKey.Public()),
|
||||
DiscoKey: "faa",
|
||||
Hostname: "testmachine",
|
||||
UserID: user.ID,
|
||||
RegisterMethod: util.RegisterMethodAuthKey,
|
||||
AuthKeyID: uint(pak.ID),
|
||||
}
|
||||
db.db.Save(&machine)
|
||||
|
||||
_, err = db.GetMachineByAnyKey(machineKey.Public(), nodeKey.Public(), oldNodeKey.Public())
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
c.Assert(channelUpdates, check.Equals, int32(0))
|
||||
}
|
||||
|
||||
func (s *Suite) TestDeleteMachine(c *check.C) {
|
||||
user, err := db.CreateUser("test")
|
||||
c.Assert(err, check.IsNil)
|
||||
machine := types.Machine{
|
||||
ID: 0,
|
||||
MachineKey: "foo",
|
||||
NodeKey: "bar",
|
||||
DiscoKey: "faa",
|
||||
Hostname: "testmachine",
|
||||
UserID: user.ID,
|
||||
RegisterMethod: util.RegisterMethodAuthKey,
|
||||
AuthKeyID: uint(1),
|
||||
}
|
||||
db.db.Save(&machine)
|
||||
|
||||
err = db.DeleteMachine(&machine)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
_, err = db.GetMachine(user.Name, "testmachine")
|
||||
c.Assert(err, check.NotNil)
|
||||
|
||||
c.Assert(channelUpdates, check.Equals, int32(0))
|
||||
}
|
||||
|
||||
func (s *Suite) TestHardDeleteMachine(c *check.C) {
|
||||
user, err := db.CreateUser("test")
|
||||
c.Assert(err, check.IsNil)
|
||||
machine := types.Machine{
|
||||
ID: 0,
|
||||
MachineKey: "foo",
|
||||
NodeKey: "bar",
|
||||
DiscoKey: "faa",
|
||||
Hostname: "testmachine3",
|
||||
UserID: user.ID,
|
||||
RegisterMethod: util.RegisterMethodAuthKey,
|
||||
AuthKeyID: uint(1),
|
||||
}
|
||||
db.db.Save(&machine)
|
||||
|
||||
err = db.HardDeleteMachine(&machine)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
_, err = db.GetMachine(user.Name, "testmachine3")
|
||||
c.Assert(err, check.NotNil)
|
||||
|
||||
c.Assert(channelUpdates, check.Equals, int32(0))
|
||||
}
|
||||
|
||||
func (s *Suite) TestListPeers(c *check.C) {
|
||||
user, err := db.CreateUser("test")
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
_, err = db.GetMachineByID(0)
|
||||
c.Assert(err, check.NotNil)
|
||||
|
||||
for index := 0; index <= 10; index++ {
|
||||
machine := types.Machine{
|
||||
ID: uint64(index),
|
||||
MachineKey: "foo" + strconv.Itoa(index),
|
||||
NodeKey: "bar" + strconv.Itoa(index),
|
||||
DiscoKey: "faa" + strconv.Itoa(index),
|
||||
Hostname: "testmachine" + strconv.Itoa(index),
|
||||
UserID: user.ID,
|
||||
RegisterMethod: util.RegisterMethodAuthKey,
|
||||
AuthKeyID: uint(pak.ID),
|
||||
}
|
||||
db.db.Save(&machine)
|
||||
}
|
||||
|
||||
machine0ByID, err := db.GetMachineByID(0)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
peersOfMachine0, err := db.ListPeers(machine0ByID)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
c.Assert(len(peersOfMachine0), check.Equals, 9)
|
||||
c.Assert(peersOfMachine0[0].Hostname, check.Equals, "testmachine2")
|
||||
c.Assert(peersOfMachine0[5].Hostname, check.Equals, "testmachine7")
|
||||
c.Assert(peersOfMachine0[8].Hostname, check.Equals, "testmachine10")
|
||||
|
||||
c.Assert(channelUpdates, check.Equals, int32(0))
|
||||
}
|
||||
|
||||
func (s *Suite) TestGetACLFilteredPeers(c *check.C) {
|
||||
type base struct {
|
||||
user *types.User
|
||||
key *types.PreAuthKey
|
||||
}
|
||||
|
||||
stor := make([]base, 0)
|
||||
|
||||
for _, name := range []string{"test", "admin"} {
|
||||
user, err := db.CreateUser(name)
|
||||
c.Assert(err, check.IsNil)
|
||||
pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil)
|
||||
c.Assert(err, check.IsNil)
|
||||
stor = append(stor, base{user, pak})
|
||||
}
|
||||
|
||||
_, err := db.GetMachineByID(0)
|
||||
c.Assert(err, check.NotNil)
|
||||
|
||||
for index := 0; index <= 10; index++ {
|
||||
machine := types.Machine{
|
||||
ID: uint64(index),
|
||||
MachineKey: "foo" + strconv.Itoa(index),
|
||||
NodeKey: "bar" + strconv.Itoa(index),
|
||||
DiscoKey: "faa" + strconv.Itoa(index),
|
||||
IPAddresses: types.MachineAddresses{
|
||||
netip.MustParseAddr(fmt.Sprintf("100.64.0.%v", strconv.Itoa(index+1))),
|
||||
},
|
||||
Hostname: "testmachine" + strconv.Itoa(index),
|
||||
UserID: stor[index%2].user.ID,
|
||||
RegisterMethod: util.RegisterMethodAuthKey,
|
||||
AuthKeyID: uint(stor[index%2].key.ID),
|
||||
}
|
||||
db.db.Save(&machine)
|
||||
}
|
||||
|
||||
aclPolicy := &policy.ACLPolicy{
|
||||
Groups: map[string][]string{
|
||||
"group:test": {"admin"},
|
||||
},
|
||||
Hosts: map[string]netip.Prefix{},
|
||||
TagOwners: map[string][]string{},
|
||||
ACLs: []policy.ACL{
|
||||
{
|
||||
Action: "accept",
|
||||
Sources: []string{"admin"},
|
||||
Destinations: []string{"*:*"},
|
||||
},
|
||||
{
|
||||
Action: "accept",
|
||||
Sources: []string{"test"},
|
||||
Destinations: []string{"test:*"},
|
||||
},
|
||||
},
|
||||
Tests: []policy.ACLTest{},
|
||||
}
|
||||
|
||||
adminMachine, err := db.GetMachineByID(1)
|
||||
c.Logf("Machine(%v), user: %v", adminMachine.Hostname, adminMachine.User)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
testMachine, err := db.GetMachineByID(2)
|
||||
c.Logf("Machine(%v), user: %v", testMachine.Hostname, testMachine.User)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
machines, err := db.ListMachines()
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
aclRules, _, err := policy.GenerateFilterRules(aclPolicy, machines, false)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
peersOfTestMachine := db.filterMachinesByACL(aclRules, testMachine, machines)
|
||||
peersOfAdminMachine := db.filterMachinesByACL(aclRules, adminMachine, machines)
|
||||
|
||||
c.Log(peersOfTestMachine)
|
||||
c.Assert(len(peersOfTestMachine), check.Equals, 9)
|
||||
c.Assert(peersOfTestMachine[0].Hostname, check.Equals, "testmachine1")
|
||||
c.Assert(peersOfTestMachine[1].Hostname, check.Equals, "testmachine3")
|
||||
c.Assert(peersOfTestMachine[3].Hostname, check.Equals, "testmachine5")
|
||||
|
||||
c.Log(peersOfAdminMachine)
|
||||
c.Assert(len(peersOfAdminMachine), check.Equals, 9)
|
||||
c.Assert(peersOfAdminMachine[0].Hostname, check.Equals, "testmachine2")
|
||||
c.Assert(peersOfAdminMachine[2].Hostname, check.Equals, "testmachine4")
|
||||
c.Assert(peersOfAdminMachine[5].Hostname, check.Equals, "testmachine7")
|
||||
|
||||
c.Assert(channelUpdates, check.Equals, int32(0))
|
||||
}
|
||||
|
||||
func (s *Suite) TestExpireMachine(c *check.C) {
|
||||
user, err := db.CreateUser("test")
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
_, err = db.GetMachine("test", "testmachine")
|
||||
c.Assert(err, check.NotNil)
|
||||
|
||||
machine := &types.Machine{
|
||||
ID: 0,
|
||||
MachineKey: "foo",
|
||||
NodeKey: "bar",
|
||||
DiscoKey: "faa",
|
||||
Hostname: "testmachine",
|
||||
UserID: user.ID,
|
||||
RegisterMethod: util.RegisterMethodAuthKey,
|
||||
AuthKeyID: uint(pak.ID),
|
||||
Expiry: &time.Time{},
|
||||
}
|
||||
db.db.Save(machine)
|
||||
|
||||
machineFromDB, err := db.GetMachine("test", "testmachine")
|
||||
c.Assert(err, check.IsNil)
|
||||
c.Assert(machineFromDB, check.NotNil)
|
||||
|
||||
c.Assert(machineFromDB.IsExpired(), check.Equals, false)
|
||||
|
||||
err = db.ExpireMachine(machineFromDB)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
c.Assert(machineFromDB.IsExpired(), check.Equals, true)
|
||||
|
||||
c.Assert(channelUpdates, check.Equals, int32(1))
|
||||
}
|
||||
|
||||
func (s *Suite) TestSerdeAddressStrignSlice(c *check.C) {
|
||||
input := types.MachineAddresses([]netip.Addr{
|
||||
netip.MustParseAddr("192.0.2.1"),
|
||||
netip.MustParseAddr("2001:db8::1"),
|
||||
})
|
||||
serialized, err := input.Value()
|
||||
c.Assert(err, check.IsNil)
|
||||
if serial, ok := serialized.(string); ok {
|
||||
c.Assert(serial, check.Equals, "192.0.2.1,2001:db8::1")
|
||||
}
|
||||
|
||||
var deserialized types.MachineAddresses
|
||||
err = deserialized.Scan(serialized)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
c.Assert(len(deserialized), check.Equals, len(input))
|
||||
for i := range deserialized {
|
||||
c.Assert(deserialized[i], check.Equals, input[i])
|
||||
}
|
||||
|
||||
c.Assert(channelUpdates, check.Equals, int32(0))
|
||||
}
|
||||
|
||||
func (s *Suite) TestGenerateGivenName(c *check.C) {
|
||||
user1, err := db.CreateUser("user-1")
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
pak, err := db.CreatePreAuthKey(user1.Name, false, false, nil, nil)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
_, err = db.GetMachine("user-1", "testmachine")
|
||||
c.Assert(err, check.NotNil)
|
||||
|
||||
machine := &types.Machine{
|
||||
ID: 0,
|
||||
MachineKey: "machine-key-1",
|
||||
NodeKey: "node-key-1",
|
||||
DiscoKey: "disco-key-1",
|
||||
Hostname: "hostname-1",
|
||||
GivenName: "hostname-1",
|
||||
UserID: user1.ID,
|
||||
RegisterMethod: util.RegisterMethodAuthKey,
|
||||
AuthKeyID: uint(pak.ID),
|
||||
}
|
||||
db.db.Save(machine)
|
||||
|
||||
givenName, err := db.GenerateGivenName("machine-key-2", "hostname-2")
|
||||
comment := check.Commentf("Same user, unique machines, unique hostnames, no conflict")
|
||||
c.Assert(err, check.IsNil, comment)
|
||||
c.Assert(givenName, check.Equals, "hostname-2", comment)
|
||||
|
||||
givenName, err = db.GenerateGivenName("machine-key-1", "hostname-1")
|
||||
comment = check.Commentf("Same user, same machine, same hostname, no conflict")
|
||||
c.Assert(err, check.IsNil, comment)
|
||||
c.Assert(givenName, check.Equals, "hostname-1", comment)
|
||||
|
||||
givenName, err = db.GenerateGivenName("machine-key-2", "hostname-1")
|
||||
comment = check.Commentf("Same user, unique machines, same hostname, conflict")
|
||||
c.Assert(err, check.IsNil, comment)
|
||||
c.Assert(givenName, check.Matches, fmt.Sprintf("^hostname-1-[a-z0-9]{%d}$", MachineGivenNameHashLength), comment)
|
||||
|
||||
givenName, err = db.GenerateGivenName("machine-key-2", "hostname-1")
|
||||
comment = check.Commentf("Unique users, unique machines, same hostname, conflict")
|
||||
c.Assert(err, check.IsNil, comment)
|
||||
c.Assert(givenName, check.Matches, fmt.Sprintf("^hostname-1-[a-z0-9]{%d}$", MachineGivenNameHashLength), comment)
|
||||
|
||||
c.Assert(channelUpdates, check.Equals, int32(0))
|
||||
}
|
||||
|
||||
func (s *Suite) TestSetTags(c *check.C) {
|
||||
user, err := db.CreateUser("test")
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
_, err = db.GetMachine("test", "testmachine")
|
||||
c.Assert(err, check.NotNil)
|
||||
|
||||
machine := &types.Machine{
|
||||
ID: 0,
|
||||
MachineKey: "foo",
|
||||
NodeKey: "bar",
|
||||
DiscoKey: "faa",
|
||||
Hostname: "testmachine",
|
||||
UserID: user.ID,
|
||||
RegisterMethod: util.RegisterMethodAuthKey,
|
||||
AuthKeyID: uint(pak.ID),
|
||||
}
|
||||
db.db.Save(machine)
|
||||
|
||||
// assign simple tags
|
||||
sTags := []string{"tag:test", "tag:foo"}
|
||||
err = db.SetTags(machine, sTags)
|
||||
c.Assert(err, check.IsNil)
|
||||
machine, err = db.GetMachine("test", "testmachine")
|
||||
c.Assert(err, check.IsNil)
|
||||
c.Assert(machine.ForcedTags, check.DeepEquals, types.StringList(sTags))
|
||||
|
||||
// assign duplicat tags, expect no errors but no doubles in DB
|
||||
eTags := []string{"tag:bar", "tag:test", "tag:unknown", "tag:test"}
|
||||
err = db.SetTags(machine, eTags)
|
||||
c.Assert(err, check.IsNil)
|
||||
machine, err = db.GetMachine("test", "testmachine")
|
||||
c.Assert(err, check.IsNil)
|
||||
c.Assert(
|
||||
machine.ForcedTags,
|
||||
check.DeepEquals,
|
||||
types.StringList([]string{"tag:bar", "tag:test", "tag:unknown"}),
|
||||
)
|
||||
|
||||
c.Assert(channelUpdates, check.Equals, int32(4))
|
||||
}
|
||||
|
||||
func TestHeadscale_generateGivenName(t *testing.T) {
|
||||
type args struct {
|
||||
suppliedName string
|
||||
randomSuffix bool
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
db *HSDatabase
|
||||
args args
|
||||
want *regexp.Regexp
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "simple machine name generation",
|
||||
db: &HSDatabase{
|
||||
stripEmailDomain: true,
|
||||
},
|
||||
args: args{
|
||||
suppliedName: "testmachine",
|
||||
randomSuffix: false,
|
||||
},
|
||||
want: regexp.MustCompile("^testmachine$"),
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "machine name with 53 chars",
|
||||
db: &HSDatabase{
|
||||
stripEmailDomain: true,
|
||||
},
|
||||
args: args{
|
||||
suppliedName: "testmaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaachine",
|
||||
randomSuffix: false,
|
||||
},
|
||||
want: regexp.MustCompile("^testmaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaachine$"),
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "machine name with 63 chars",
|
||||
db: &HSDatabase{
|
||||
stripEmailDomain: true,
|
||||
},
|
||||
args: args{
|
||||
suppliedName: "machineeee12345678901234567890123456789012345678901234567890123",
|
||||
randomSuffix: false,
|
||||
},
|
||||
want: regexp.MustCompile("^machineeee12345678901234567890123456789012345678901234567890123$"),
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "machine name with 64 chars",
|
||||
db: &HSDatabase{
|
||||
stripEmailDomain: true,
|
||||
},
|
||||
args: args{
|
||||
suppliedName: "machineeee123456789012345678901234567890123456789012345678901234",
|
||||
randomSuffix: false,
|
||||
},
|
||||
want: nil,
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "machine name with 73 chars",
|
||||
db: &HSDatabase{
|
||||
stripEmailDomain: true,
|
||||
},
|
||||
args: args{
|
||||
suppliedName: "machineeee123456789012345678901234567890123456789012345678901234567890123",
|
||||
randomSuffix: false,
|
||||
},
|
||||
want: nil,
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "machine name with random suffix",
|
||||
db: &HSDatabase{
|
||||
stripEmailDomain: true,
|
||||
},
|
||||
args: args{
|
||||
suppliedName: "test",
|
||||
randomSuffix: true,
|
||||
},
|
||||
want: regexp.MustCompile(fmt.Sprintf("^test-[a-z0-9]{%d}$", MachineGivenNameHashLength)),
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "machine name with 63 chars with random suffix",
|
||||
db: &HSDatabase{
|
||||
stripEmailDomain: true,
|
||||
},
|
||||
args: args{
|
||||
suppliedName: "machineeee12345678901234567890123456789012345678901234567890123",
|
||||
randomSuffix: true,
|
||||
},
|
||||
want: regexp.MustCompile(fmt.Sprintf("^machineeee1234567890123456789012345678901234567890123-[a-z0-9]{%d}$", MachineGivenNameHashLength)),
|
||||
wantErr: false,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, err := tt.db.generateGivenName(tt.args.suppliedName, tt.args.randomSuffix)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf(
|
||||
"Headscale.GenerateGivenName() error = %v, wantErr %v",
|
||||
err,
|
||||
tt.wantErr,
|
||||
)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
if tt.want != nil && !tt.want.MatchString(got) {
|
||||
t.Errorf(
|
||||
"Headscale.GenerateGivenName() = %v, does not match %v",
|
||||
tt.want,
|
||||
got,
|
||||
)
|
||||
}
|
||||
|
||||
if len(got) > util.LabelHostnameLength {
|
||||
t.Errorf(
|
||||
"Headscale.GenerateGivenName() = %v is larger than allowed DNS segment %d",
|
||||
got,
|
||||
util.LabelHostnameLength,
|
||||
)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Suite) TestAutoApproveRoutes(c *check.C) {
|
||||
acl := []byte(`
|
||||
{
|
||||
"tagOwners": {
|
||||
"tag:exit": ["test"],
|
||||
},
|
||||
|
||||
"groups": {
|
||||
"group:test": ["test"]
|
||||
},
|
||||
|
||||
"acls": [
|
||||
{"action": "accept", "users": ["*"], "ports": ["*:*"]},
|
||||
],
|
||||
|
||||
"autoApprovers": {
|
||||
"exitNode": ["tag:exit"],
|
||||
"routes": {
|
||||
"10.10.0.0/16": ["group:test"],
|
||||
"10.11.0.0/16": ["test"],
|
||||
}
|
||||
}
|
||||
}
|
||||
`)
|
||||
|
||||
pol, err := policy.LoadACLPolicyFromBytes(acl, "hujson")
|
||||
c.Assert(err, check.IsNil)
|
||||
c.Assert(pol, check.NotNil)
|
||||
|
||||
user, err := db.CreateUser("test")
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
nodeKey := key.NewNode()
|
||||
|
||||
defaultRouteV4 := netip.MustParsePrefix("0.0.0.0/0")
|
||||
defaultRouteV6 := netip.MustParsePrefix("::/0")
|
||||
route1 := netip.MustParsePrefix("10.10.0.0/16")
|
||||
// Check if a subprefix of an autoapproved route is approved
|
||||
route2 := netip.MustParsePrefix("10.11.0.0/24")
|
||||
|
||||
machine := types.Machine{
|
||||
ID: 0,
|
||||
MachineKey: "foo",
|
||||
NodeKey: util.NodePublicKeyStripPrefix(nodeKey.Public()),
|
||||
DiscoKey: "faa",
|
||||
Hostname: "test",
|
||||
UserID: user.ID,
|
||||
RegisterMethod: util.RegisterMethodAuthKey,
|
||||
AuthKeyID: uint(pak.ID),
|
||||
HostInfo: types.HostInfo{
|
||||
RequestTags: []string{"tag:exit"},
|
||||
RoutableIPs: []netip.Prefix{defaultRouteV4, defaultRouteV6, route1, route2},
|
||||
},
|
||||
IPAddresses: []netip.Addr{netip.MustParseAddr("100.64.0.1")},
|
||||
}
|
||||
|
||||
db.db.Save(&machine)
|
||||
|
||||
err = db.ProcessMachineRoutes(&machine)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
machine0ByID, err := db.GetMachineByID(0)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
err = db.EnableAutoApprovedRoutes(pol, machine0ByID)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
enabledRoutes, err := db.GetEnabledRoutes(machine0ByID)
|
||||
c.Assert(err, check.IsNil)
|
||||
c.Assert(enabledRoutes, check.HasLen, 4)
|
||||
|
||||
c.Assert(channelUpdates, check.Equals, int32(4))
|
||||
}
|
||||
|
||||
func TestMachine_canAccess(t *testing.T) {
|
||||
type args struct {
|
||||
filter []tailcfg.FilterRule
|
||||
machine2 *types.Machine
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
machine types.Machine
|
||||
args args
|
||||
want bool
|
||||
}{
|
||||
{
|
||||
name: "no-rules",
|
||||
machine: types.Machine{
|
||||
IPAddresses: types.MachineAddresses{
|
||||
netip.MustParseAddr("10.0.0.1"),
|
||||
},
|
||||
},
|
||||
args: args{
|
||||
filter: []tailcfg.FilterRule{},
|
||||
machine2: &types.Machine{
|
||||
IPAddresses: types.MachineAddresses{
|
||||
netip.MustParseAddr("10.0.0.2"),
|
||||
},
|
||||
},
|
||||
},
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "wildcard",
|
||||
machine: types.Machine{
|
||||
IPAddresses: types.MachineAddresses{
|
||||
netip.MustParseAddr("10.0.0.1"),
|
||||
},
|
||||
},
|
||||
args: args{
|
||||
filter: []tailcfg.FilterRule{
|
||||
{
|
||||
SrcIPs: []string{"*"},
|
||||
DstPorts: []tailcfg.NetPortRange{
|
||||
{
|
||||
IP: "*",
|
||||
Ports: tailcfg.PortRange{
|
||||
First: 0,
|
||||
Last: 65535,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
machine2: &types.Machine{
|
||||
IPAddresses: types.MachineAddresses{
|
||||
netip.MustParseAddr("10.0.0.2"),
|
||||
},
|
||||
},
|
||||
},
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "explicit-m1-to-m2",
|
||||
machine: types.Machine{
|
||||
IPAddresses: types.MachineAddresses{
|
||||
netip.MustParseAddr("10.0.0.1"),
|
||||
},
|
||||
},
|
||||
args: args{
|
||||
filter: []tailcfg.FilterRule{
|
||||
{
|
||||
SrcIPs: []string{"10.0.0.1"},
|
||||
DstPorts: []tailcfg.NetPortRange{
|
||||
{
|
||||
IP: "10.0.0.2",
|
||||
Ports: tailcfg.PortRange{
|
||||
First: 0,
|
||||
Last: 65535,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
machine2: &types.Machine{
|
||||
IPAddresses: types.MachineAddresses{
|
||||
netip.MustParseAddr("10.0.0.2"),
|
||||
},
|
||||
},
|
||||
},
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "explicit-m2-to-m1",
|
||||
machine: types.Machine{
|
||||
IPAddresses: types.MachineAddresses{
|
||||
netip.MustParseAddr("10.0.0.1"),
|
||||
},
|
||||
},
|
||||
args: args{
|
||||
filter: []tailcfg.FilterRule{
|
||||
{
|
||||
SrcIPs: []string{"10.0.0.2"},
|
||||
DstPorts: []tailcfg.NetPortRange{
|
||||
{
|
||||
IP: "10.0.0.1",
|
||||
Ports: tailcfg.PortRange{
|
||||
First: 0,
|
||||
Last: 65535,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
machine2: &types.Machine{
|
||||
IPAddresses: types.MachineAddresses{
|
||||
netip.MustParseAddr("10.0.0.2"),
|
||||
},
|
||||
},
|
||||
},
|
||||
want: false,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if got := tt.machine.CanAccess(tt.args.filter, tt.args.machine2); got != tt.want {
|
||||
t.Errorf("Machine.CanAccess() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
@ -1,17 +1,14 @@
|
||||
package hscontrol
|
||||
package db
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"encoding/hex"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
|
||||
"github.com/juanfont/headscale/hscontrol/util"
|
||||
"google.golang.org/protobuf/types/known/timestamppb"
|
||||
"github.com/juanfont/headscale/hscontrol/types"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
@ -23,28 +20,6 @@ var (
|
||||
ErrPreAuthKeyACLTagInvalid = errors.New("AuthKey tag is invalid")
|
||||
)
|
||||
|
||||
// PreAuthKey describes a pre-authorization key usable in a particular user.
|
||||
type PreAuthKey struct {
|
||||
ID uint64 `gorm:"primary_key"`
|
||||
Key string
|
||||
UserID uint
|
||||
User User
|
||||
Reusable bool
|
||||
Ephemeral bool `gorm:"default:false"`
|
||||
Used bool `gorm:"default:false"`
|
||||
ACLTags []PreAuthKeyACLTag
|
||||
|
||||
CreatedAt *time.Time
|
||||
Expiration *time.Time
|
||||
}
|
||||
|
||||
// PreAuthKeyACLTag describes an autmatic tag applied to a node when registered with the associated PreAuthKey.
|
||||
type PreAuthKeyACLTag struct {
|
||||
ID uint64 `gorm:"primary_key"`
|
||||
PreAuthKeyID uint64
|
||||
Tag string
|
||||
}
|
||||
|
||||
// CreatePreAuthKey creates a new PreAuthKey in a user, and returns it.
|
||||
func (hsdb *HSDatabase) CreatePreAuthKey(
|
||||
userName string,
|
||||
@ -52,7 +27,7 @@ func (hsdb *HSDatabase) CreatePreAuthKey(
|
||||
ephemeral bool,
|
||||
expiration *time.Time,
|
||||
aclTags []string,
|
||||
) (*PreAuthKey, error) {
|
||||
) (*types.PreAuthKey, error) {
|
||||
user, err := hsdb.GetUser(userName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@ -74,7 +49,7 @@ func (hsdb *HSDatabase) CreatePreAuthKey(
|
||||
return nil, err
|
||||
}
|
||||
|
||||
key := PreAuthKey{
|
||||
key := types.PreAuthKey{
|
||||
Key: kstr,
|
||||
UserID: user.ID,
|
||||
User: *user,
|
||||
@ -94,7 +69,7 @@ func (hsdb *HSDatabase) CreatePreAuthKey(
|
||||
|
||||
for _, tag := range aclTags {
|
||||
if !seenTags[tag] {
|
||||
if err := db.Save(&PreAuthKeyACLTag{PreAuthKeyID: key.ID, Tag: tag}).Error; err != nil {
|
||||
if err := db.Save(&types.PreAuthKeyACLTag{PreAuthKeyID: key.ID, Tag: tag}).Error; err != nil {
|
||||
return fmt.Errorf(
|
||||
"failed to ceate key tag in the database: %w",
|
||||
err,
|
||||
@ -116,14 +91,14 @@ func (hsdb *HSDatabase) CreatePreAuthKey(
|
||||
}
|
||||
|
||||
// ListPreAuthKeys returns the list of PreAuthKeys for a user.
|
||||
func (hsdb *HSDatabase) ListPreAuthKeys(userName string) ([]PreAuthKey, error) {
|
||||
func (hsdb *HSDatabase) ListPreAuthKeys(userName string) ([]types.PreAuthKey, error) {
|
||||
user, err := hsdb.GetUser(userName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
keys := []PreAuthKey{}
|
||||
if err := hsdb.db.Preload("User").Preload("ACLTags").Where(&PreAuthKey{UserID: user.ID}).Find(&keys).Error; err != nil {
|
||||
keys := []types.PreAuthKey{}
|
||||
if err := hsdb.db.Preload("User").Preload("ACLTags").Where(&types.PreAuthKey{UserID: user.ID}).Find(&keys).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@ -131,8 +106,8 @@ func (hsdb *HSDatabase) ListPreAuthKeys(userName string) ([]PreAuthKey, error) {
|
||||
}
|
||||
|
||||
// GetPreAuthKey returns a PreAuthKey for a given key.
|
||||
func (hsdb *HSDatabase) GetPreAuthKey(user string, key string) (*PreAuthKey, error) {
|
||||
pak, err := hsdb.checkKeyValidity(key)
|
||||
func (hsdb *HSDatabase) GetPreAuthKey(user string, key string) (*types.PreAuthKey, error) {
|
||||
pak, err := hsdb.ValidatePreAuthKey(key)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -146,9 +121,9 @@ func (hsdb *HSDatabase) GetPreAuthKey(user string, key string) (*PreAuthKey, err
|
||||
|
||||
// DestroyPreAuthKey destroys a preauthkey. Returns error if the PreAuthKey
|
||||
// does not exist.
|
||||
func (hsdb *HSDatabase) DestroyPreAuthKey(pak PreAuthKey) error {
|
||||
func (hsdb *HSDatabase) DestroyPreAuthKey(pak types.PreAuthKey) error {
|
||||
return hsdb.db.Transaction(func(db *gorm.DB) error {
|
||||
if result := db.Unscoped().Where(PreAuthKeyACLTag{PreAuthKeyID: pak.ID}).Delete(&PreAuthKeyACLTag{}); result.Error != nil {
|
||||
if result := db.Unscoped().Where(types.PreAuthKeyACLTag{PreAuthKeyID: pak.ID}).Delete(&types.PreAuthKeyACLTag{}); result.Error != nil {
|
||||
return result.Error
|
||||
}
|
||||
|
||||
@ -161,7 +136,7 @@ func (hsdb *HSDatabase) DestroyPreAuthKey(pak PreAuthKey) error {
|
||||
}
|
||||
|
||||
// MarkExpirePreAuthKey marks a PreAuthKey as expired.
|
||||
func (hsdb *HSDatabase) ExpirePreAuthKey(k *PreAuthKey) error {
|
||||
func (hsdb *HSDatabase) ExpirePreAuthKey(k *types.PreAuthKey) error {
|
||||
if err := hsdb.db.Model(&k).Update("Expiration", time.Now()).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
@ -170,7 +145,7 @@ func (hsdb *HSDatabase) ExpirePreAuthKey(k *PreAuthKey) error {
|
||||
}
|
||||
|
||||
// UsePreAuthKey marks a PreAuthKey as used.
|
||||
func (hsdb *HSDatabase) UsePreAuthKey(k *PreAuthKey) error {
|
||||
func (hsdb *HSDatabase) UsePreAuthKey(k *types.PreAuthKey) error {
|
||||
k.Used = true
|
||||
if err := hsdb.db.Save(k).Error; err != nil {
|
||||
return fmt.Errorf("failed to update key used status in the database: %w", err)
|
||||
@ -179,10 +154,10 @@ func (hsdb *HSDatabase) UsePreAuthKey(k *PreAuthKey) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// checkKeyValidity does the heavy lifting for validation of the PreAuthKey coming from a node
|
||||
// ValidatePreAuthKey does the heavy lifting for validation of the PreAuthKey coming from a node
|
||||
// If returns no error and a PreAuthKey, it can be used.
|
||||
func (hsdb *HSDatabase) checkKeyValidity(k string) (*PreAuthKey, error) {
|
||||
pak := PreAuthKey{}
|
||||
func (hsdb *HSDatabase) ValidatePreAuthKey(k string) (*types.PreAuthKey, error) {
|
||||
pak := types.PreAuthKey{}
|
||||
if result := hsdb.db.Preload("User").Preload("ACLTags").First(&pak, "key = ?", k); errors.Is(
|
||||
result.Error,
|
||||
gorm.ErrRecordNotFound,
|
||||
@ -198,8 +173,8 @@ func (hsdb *HSDatabase) checkKeyValidity(k string) (*PreAuthKey, error) {
|
||||
return &pak, nil
|
||||
}
|
||||
|
||||
machines := []Machine{}
|
||||
if err := hsdb.db.Preload("AuthKey").Where(&Machine{AuthKeyID: uint(pak.ID)}).Find(&machines).Error; err != nil {
|
||||
machines := types.Machines{}
|
||||
if err := hsdb.db.Preload("AuthKey").Where(&types.Machine{AuthKeyID: uint(pak.ID)}).Find(&machines).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@ -219,29 +194,3 @@ func (hsdb *HSDatabase) generateKey() (string, error) {
|
||||
|
||||
return hex.EncodeToString(bytes), nil
|
||||
}
|
||||
|
||||
func (key *PreAuthKey) toProto() *v1.PreAuthKey {
|
||||
protoKey := v1.PreAuthKey{
|
||||
User: key.User.Name,
|
||||
Id: strconv.FormatUint(key.ID, util.Base10),
|
||||
Key: key.Key,
|
||||
Ephemeral: key.Ephemeral,
|
||||
Reusable: key.Reusable,
|
||||
Used: key.Used,
|
||||
AclTags: make([]string, len(key.ACLTags)),
|
||||
}
|
||||
|
||||
if key.Expiration != nil {
|
||||
protoKey.Expiration = timestamppb.New(*key.Expiration)
|
||||
}
|
||||
|
||||
if key.CreatedAt != nil {
|
||||
protoKey.CreatedAt = timestamppb.New(*key.CreatedAt)
|
||||
}
|
||||
|
||||
for idx := range key.ACLTags {
|
||||
protoKey.AclTags[idx] = key.ACLTags[idx].Tag
|
||||
}
|
||||
|
||||
return &protoKey
|
||||
}
|
@ -1,20 +1,22 @@
|
||||
package hscontrol
|
||||
package db
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/juanfont/headscale/hscontrol/types"
|
||||
"github.com/juanfont/headscale/hscontrol/util"
|
||||
"gopkg.in/check.v1"
|
||||
)
|
||||
|
||||
func (*Suite) TestCreatePreAuthKey(c *check.C) {
|
||||
_, err := app.db.CreatePreAuthKey("bogus", true, false, nil, nil)
|
||||
_, err := db.CreatePreAuthKey("bogus", true, false, nil, nil)
|
||||
|
||||
c.Assert(err, check.NotNil)
|
||||
|
||||
user, err := app.db.CreateUser("test")
|
||||
user, err := db.CreateUser("test")
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
key, err := app.db.CreatePreAuthKey(user.Name, true, false, nil, nil)
|
||||
key, err := db.CreatePreAuthKey(user.Name, true, false, nil, nil)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
// Did we get a valid key?
|
||||
@ -24,10 +26,10 @@ func (*Suite) TestCreatePreAuthKey(c *check.C) {
|
||||
// Make sure the User association is populated
|
||||
c.Assert(key.User.Name, check.Equals, user.Name)
|
||||
|
||||
_, err = app.db.ListPreAuthKeys("bogus")
|
||||
_, err = db.ListPreAuthKeys("bogus")
|
||||
c.Assert(err, check.NotNil)
|
||||
|
||||
keys, err := app.db.ListPreAuthKeys(user.Name)
|
||||
keys, err := db.ListPreAuthKeys(user.Name)
|
||||
c.Assert(err, check.IsNil)
|
||||
c.Assert(len(keys), check.Equals, 1)
|
||||
|
||||
@ -36,174 +38,176 @@ func (*Suite) TestCreatePreAuthKey(c *check.C) {
|
||||
}
|
||||
|
||||
func (*Suite) TestExpiredPreAuthKey(c *check.C) {
|
||||
user, err := app.db.CreateUser("test2")
|
||||
user, err := db.CreateUser("test2")
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
now := time.Now()
|
||||
pak, err := app.db.CreatePreAuthKey(user.Name, true, false, &now, nil)
|
||||
pak, err := db.CreatePreAuthKey(user.Name, true, false, &now, nil)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
key, err := app.db.checkKeyValidity(pak.Key)
|
||||
key, err := db.ValidatePreAuthKey(pak.Key)
|
||||
c.Assert(err, check.Equals, ErrPreAuthKeyExpired)
|
||||
c.Assert(key, check.IsNil)
|
||||
}
|
||||
|
||||
func (*Suite) TestPreAuthKeyDoesNotExist(c *check.C) {
|
||||
key, err := app.db.checkKeyValidity("potatoKey")
|
||||
key, err := db.ValidatePreAuthKey("potatoKey")
|
||||
c.Assert(err, check.Equals, ErrPreAuthKeyNotFound)
|
||||
c.Assert(key, check.IsNil)
|
||||
}
|
||||
|
||||
func (*Suite) TestValidateKeyOk(c *check.C) {
|
||||
user, err := app.db.CreateUser("test3")
|
||||
user, err := db.CreateUser("test3")
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
pak, err := app.db.CreatePreAuthKey(user.Name, true, false, nil, nil)
|
||||
pak, err := db.CreatePreAuthKey(user.Name, true, false, nil, nil)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
key, err := app.db.checkKeyValidity(pak.Key)
|
||||
key, err := db.ValidatePreAuthKey(pak.Key)
|
||||
c.Assert(err, check.IsNil)
|
||||
c.Assert(key.ID, check.Equals, pak.ID)
|
||||
}
|
||||
|
||||
func (*Suite) TestAlreadyUsedKey(c *check.C) {
|
||||
user, err := app.db.CreateUser("test4")
|
||||
user, err := db.CreateUser("test4")
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
pak, err := app.db.CreatePreAuthKey(user.Name, false, false, nil, nil)
|
||||
pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
machine := Machine{
|
||||
machine := types.Machine{
|
||||
ID: 0,
|
||||
MachineKey: "foo",
|
||||
NodeKey: "bar",
|
||||
DiscoKey: "faa",
|
||||
Hostname: "testest",
|
||||
UserID: user.ID,
|
||||
RegisterMethod: RegisterMethodAuthKey,
|
||||
RegisterMethod: util.RegisterMethodAuthKey,
|
||||
AuthKeyID: uint(pak.ID),
|
||||
}
|
||||
app.db.db.Save(&machine)
|
||||
db.db.Save(&machine)
|
||||
|
||||
key, err := app.db.checkKeyValidity(pak.Key)
|
||||
key, err := db.ValidatePreAuthKey(pak.Key)
|
||||
c.Assert(err, check.Equals, ErrSingleUseAuthKeyHasBeenUsed)
|
||||
c.Assert(key, check.IsNil)
|
||||
}
|
||||
|
||||
func (*Suite) TestReusableBeingUsedKey(c *check.C) {
|
||||
user, err := app.db.CreateUser("test5")
|
||||
user, err := db.CreateUser("test5")
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
pak, err := app.db.CreatePreAuthKey(user.Name, true, false, nil, nil)
|
||||
pak, err := db.CreatePreAuthKey(user.Name, true, false, nil, nil)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
machine := Machine{
|
||||
machine := types.Machine{
|
||||
ID: 1,
|
||||
MachineKey: "foo",
|
||||
NodeKey: "bar",
|
||||
DiscoKey: "faa",
|
||||
Hostname: "testest",
|
||||
UserID: user.ID,
|
||||
RegisterMethod: RegisterMethodAuthKey,
|
||||
RegisterMethod: util.RegisterMethodAuthKey,
|
||||
AuthKeyID: uint(pak.ID),
|
||||
}
|
||||
app.db.db.Save(&machine)
|
||||
db.db.Save(&machine)
|
||||
|
||||
key, err := app.db.checkKeyValidity(pak.Key)
|
||||
key, err := db.ValidatePreAuthKey(pak.Key)
|
||||
c.Assert(err, check.IsNil)
|
||||
c.Assert(key.ID, check.Equals, pak.ID)
|
||||
}
|
||||
|
||||
func (*Suite) TestNotReusableNotBeingUsedKey(c *check.C) {
|
||||
user, err := app.db.CreateUser("test6")
|
||||
user, err := db.CreateUser("test6")
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
pak, err := app.db.CreatePreAuthKey(user.Name, false, false, nil, nil)
|
||||
pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
key, err := app.db.checkKeyValidity(pak.Key)
|
||||
key, err := db.ValidatePreAuthKey(pak.Key)
|
||||
c.Assert(err, check.IsNil)
|
||||
c.Assert(key.ID, check.Equals, pak.ID)
|
||||
}
|
||||
|
||||
func (*Suite) TestEphemeralKey(c *check.C) {
|
||||
user, err := app.db.CreateUser("test7")
|
||||
user, err := db.CreateUser("test7")
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
pak, err := app.db.CreatePreAuthKey(user.Name, false, true, nil, nil)
|
||||
pak, err := db.CreatePreAuthKey(user.Name, false, true, nil, nil)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
now := time.Now()
|
||||
machine := Machine{
|
||||
now := time.Now().Add(-time.Second * 30)
|
||||
machine := types.Machine{
|
||||
ID: 0,
|
||||
MachineKey: "foo",
|
||||
NodeKey: "bar",
|
||||
DiscoKey: "faa",
|
||||
Hostname: "testest",
|
||||
UserID: user.ID,
|
||||
RegisterMethod: RegisterMethodAuthKey,
|
||||
RegisterMethod: util.RegisterMethodAuthKey,
|
||||
LastSeen: &now,
|
||||
AuthKeyID: uint(pak.ID),
|
||||
}
|
||||
app.db.db.Save(&machine)
|
||||
db.db.Save(&machine)
|
||||
|
||||
_, err = app.db.checkKeyValidity(pak.Key)
|
||||
_, err = db.ValidatePreAuthKey(pak.Key)
|
||||
// Ephemeral keys are by definition reusable
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
_, err = app.db.GetMachine("test7", "testest")
|
||||
_, err = db.GetMachine("test7", "testest")
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
app.expireEphemeralNodesWorker()
|
||||
db.ExpireEphemeralMachines(time.Second * 20)
|
||||
|
||||
// The machine record should have been deleted
|
||||
_, err = app.db.GetMachine("test7", "testest")
|
||||
_, err = db.GetMachine("test7", "testest")
|
||||
c.Assert(err, check.NotNil)
|
||||
|
||||
c.Assert(channelUpdates, check.Equals, int32(1))
|
||||
}
|
||||
|
||||
func (*Suite) TestExpirePreauthKey(c *check.C) {
|
||||
user, err := app.db.CreateUser("test3")
|
||||
user, err := db.CreateUser("test3")
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
pak, err := app.db.CreatePreAuthKey(user.Name, true, false, nil, nil)
|
||||
pak, err := db.CreatePreAuthKey(user.Name, true, false, nil, nil)
|
||||
c.Assert(err, check.IsNil)
|
||||
c.Assert(pak.Expiration, check.IsNil)
|
||||
|
||||
err = app.db.ExpirePreAuthKey(pak)
|
||||
err = db.ExpirePreAuthKey(pak)
|
||||
c.Assert(err, check.IsNil)
|
||||
c.Assert(pak.Expiration, check.NotNil)
|
||||
|
||||
key, err := app.db.checkKeyValidity(pak.Key)
|
||||
key, err := db.ValidatePreAuthKey(pak.Key)
|
||||
c.Assert(err, check.Equals, ErrPreAuthKeyExpired)
|
||||
c.Assert(key, check.IsNil)
|
||||
}
|
||||
|
||||
func (*Suite) TestNotReusableMarkedAsUsed(c *check.C) {
|
||||
user, err := app.db.CreateUser("test6")
|
||||
user, err := db.CreateUser("test6")
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
pak, err := app.db.CreatePreAuthKey(user.Name, false, false, nil, nil)
|
||||
pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil)
|
||||
c.Assert(err, check.IsNil)
|
||||
pak.Used = true
|
||||
app.db.db.Save(&pak)
|
||||
db.db.Save(&pak)
|
||||
|
||||
_, err = app.db.checkKeyValidity(pak.Key)
|
||||
_, err = db.ValidatePreAuthKey(pak.Key)
|
||||
c.Assert(err, check.Equals, ErrSingleUseAuthKeyHasBeenUsed)
|
||||
}
|
||||
|
||||
func (*Suite) TestPreAuthKeyACLTags(c *check.C) {
|
||||
user, err := app.db.CreateUser("test8")
|
||||
user, err := db.CreateUser("test8")
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
_, err = app.db.CreatePreAuthKey(user.Name, false, false, nil, []string{"badtag"})
|
||||
_, err = db.CreatePreAuthKey(user.Name, false, false, nil, []string{"badtag"})
|
||||
c.Assert(err, check.NotNil) // Confirm that malformed tags are rejected
|
||||
|
||||
tags := []string{"tag:test1", "tag:test2"}
|
||||
tagsWithDuplicate := []string{"tag:test1", "tag:test2", "tag:test2"}
|
||||
_, err = app.db.CreatePreAuthKey(user.Name, false, false, nil, tagsWithDuplicate)
|
||||
_, err = db.CreatePreAuthKey(user.Name, false, false, nil, tagsWithDuplicate)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
listedPaks, err := app.db.ListPreAuthKeys("test8")
|
||||
listedPaks, err := db.ListPreAuthKeys("test8")
|
||||
c.Assert(err, check.IsNil)
|
||||
c.Assert(listedPaks[0].toProto().AclTags, check.DeepEquals, tags)
|
||||
c.Assert(listedPaks[0].Proto().AclTags, check.DeepEquals, tags)
|
||||
}
|
@ -1,55 +1,19 @@
|
||||
package hscontrol
|
||||
package db
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/netip"
|
||||
|
||||
v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
|
||||
"github.com/juanfont/headscale/hscontrol/policy"
|
||||
"github.com/juanfont/headscale/hscontrol/types"
|
||||
"github.com/rs/zerolog/log"
|
||||
"google.golang.org/protobuf/types/known/timestamppb"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrRouteIsNotAvailable = errors.New("route is not available")
|
||||
ExitRouteV4 = netip.MustParsePrefix("0.0.0.0/0")
|
||||
ExitRouteV6 = netip.MustParsePrefix("::/0")
|
||||
)
|
||||
var ErrRouteIsNotAvailable = errors.New("route is not available")
|
||||
|
||||
type Route struct {
|
||||
gorm.Model
|
||||
|
||||
MachineID uint64
|
||||
Machine Machine
|
||||
Prefix IPPrefix
|
||||
|
||||
Advertised bool
|
||||
Enabled bool
|
||||
IsPrimary bool
|
||||
}
|
||||
|
||||
type Routes []Route
|
||||
|
||||
func (r *Route) String() string {
|
||||
return fmt.Sprintf("%s:%s", r.Machine, netip.Prefix(r.Prefix).String())
|
||||
}
|
||||
|
||||
func (r *Route) isExitRoute() bool {
|
||||
return netip.Prefix(r.Prefix) == ExitRouteV4 || netip.Prefix(r.Prefix) == ExitRouteV6
|
||||
}
|
||||
|
||||
func (rs Routes) toPrefixes() []netip.Prefix {
|
||||
prefixes := make([]netip.Prefix, len(rs))
|
||||
for i, r := range rs {
|
||||
prefixes[i] = netip.Prefix(r.Prefix)
|
||||
}
|
||||
|
||||
return prefixes
|
||||
}
|
||||
|
||||
func (hsdb *HSDatabase) GetRoutes() ([]Route, error) {
|
||||
var routes []Route
|
||||
func (hsdb *HSDatabase) GetRoutes() (types.Routes, error) {
|
||||
var routes types.Routes
|
||||
err := hsdb.db.Preload("Machine").Find(&routes).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@ -58,8 +22,21 @@ func (hsdb *HSDatabase) GetRoutes() ([]Route, error) {
|
||||
return routes, nil
|
||||
}
|
||||
|
||||
func (hsdb *HSDatabase) GetMachineRoutes(m *Machine) ([]Route, error) {
|
||||
var routes []Route
|
||||
func (hsdb *HSDatabase) GetMachineAdvertisedRoutes(machine *types.Machine) (types.Routes, error) {
|
||||
var routes types.Routes
|
||||
err := hsdb.db.
|
||||
Preload("Machine").
|
||||
Where("machine_id = ? AND advertised = true", machine.ID).
|
||||
Find(&routes).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return routes, nil
|
||||
}
|
||||
|
||||
func (hsdb *HSDatabase) GetMachineRoutes(m *types.Machine) (types.Routes, error) {
|
||||
var routes types.Routes
|
||||
err := hsdb.db.
|
||||
Preload("Machine").
|
||||
Where("machine_id = ?", m.ID).
|
||||
@ -71,8 +48,8 @@ func (hsdb *HSDatabase) GetMachineRoutes(m *Machine) ([]Route, error) {
|
||||
return routes, nil
|
||||
}
|
||||
|
||||
func (hsdb *HSDatabase) GetRoute(id uint64) (*Route, error) {
|
||||
var route Route
|
||||
func (hsdb *HSDatabase) GetRoute(id uint64) (*types.Route, error) {
|
||||
var route types.Route
|
||||
err := hsdb.db.Preload("Machine").First(&route, id).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@ -90,8 +67,12 @@ func (hsdb *HSDatabase) EnableRoute(id uint64) error {
|
||||
// Tailscale requires both IPv4 and IPv6 exit routes to
|
||||
// be enabled at the same time, as per
|
||||
// https://github.com/juanfont/headscale/issues/804#issuecomment-1399314002
|
||||
if route.isExitRoute() {
|
||||
return hsdb.enableRoutes(&route.Machine, ExitRouteV4.String(), ExitRouteV6.String())
|
||||
if route.IsExitRoute() {
|
||||
return hsdb.enableRoutes(
|
||||
&route.Machine,
|
||||
types.ExitRouteV4.String(),
|
||||
types.ExitRouteV6.String(),
|
||||
)
|
||||
}
|
||||
|
||||
return hsdb.enableRoutes(&route.Machine, netip.Prefix(route.Prefix).String())
|
||||
@ -106,7 +87,7 @@ func (hsdb *HSDatabase) DisableRoute(id uint64) error {
|
||||
// Tailscale requires both IPv4 and IPv6 exit routes to
|
||||
// be enabled at the same time, as per
|
||||
// https://github.com/juanfont/headscale/issues/804#issuecomment-1399314002
|
||||
if !route.isExitRoute() {
|
||||
if !route.IsExitRoute() {
|
||||
route.Enabled = false
|
||||
route.IsPrimary = false
|
||||
err = hsdb.db.Save(route).Error
|
||||
@ -114,7 +95,7 @@ func (hsdb *HSDatabase) DisableRoute(id uint64) error {
|
||||
return err
|
||||
}
|
||||
|
||||
return hsdb.handlePrimarySubnetFailover()
|
||||
return hsdb.HandlePrimarySubnetFailover()
|
||||
}
|
||||
|
||||
routes, err := hsdb.GetMachineRoutes(&route.Machine)
|
||||
@ -123,7 +104,7 @@ func (hsdb *HSDatabase) DisableRoute(id uint64) error {
|
||||
}
|
||||
|
||||
for i := range routes {
|
||||
if routes[i].isExitRoute() {
|
||||
if routes[i].IsExitRoute() {
|
||||
routes[i].Enabled = false
|
||||
routes[i].IsPrimary = false
|
||||
err = hsdb.db.Save(&routes[i]).Error
|
||||
@ -133,7 +114,7 @@ func (hsdb *HSDatabase) DisableRoute(id uint64) error {
|
||||
}
|
||||
}
|
||||
|
||||
return hsdb.handlePrimarySubnetFailover()
|
||||
return hsdb.HandlePrimarySubnetFailover()
|
||||
}
|
||||
|
||||
func (hsdb *HSDatabase) DeleteRoute(id uint64) error {
|
||||
@ -145,12 +126,12 @@ func (hsdb *HSDatabase) DeleteRoute(id uint64) error {
|
||||
// Tailscale requires both IPv4 and IPv6 exit routes to
|
||||
// be enabled at the same time, as per
|
||||
// https://github.com/juanfont/headscale/issues/804#issuecomment-1399314002
|
||||
if !route.isExitRoute() {
|
||||
if !route.IsExitRoute() {
|
||||
if err := hsdb.db.Unscoped().Delete(&route).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return hsdb.handlePrimarySubnetFailover()
|
||||
return hsdb.HandlePrimarySubnetFailover()
|
||||
}
|
||||
|
||||
routes, err := hsdb.GetMachineRoutes(&route.Machine)
|
||||
@ -158,9 +139,9 @@ func (hsdb *HSDatabase) DeleteRoute(id uint64) error {
|
||||
return err
|
||||
}
|
||||
|
||||
routesToDelete := []Route{}
|
||||
routesToDelete := types.Routes{}
|
||||
for _, r := range routes {
|
||||
if r.isExitRoute() {
|
||||
if r.IsExitRoute() {
|
||||
routesToDelete = append(routesToDelete, r)
|
||||
}
|
||||
}
|
||||
@ -169,10 +150,10 @@ func (hsdb *HSDatabase) DeleteRoute(id uint64) error {
|
||||
return err
|
||||
}
|
||||
|
||||
return hsdb.handlePrimarySubnetFailover()
|
||||
return hsdb.HandlePrimarySubnetFailover()
|
||||
}
|
||||
|
||||
func (hsdb *HSDatabase) DeleteMachineRoutes(m *Machine) error {
|
||||
func (hsdb *HSDatabase) DeleteMachineRoutes(m *types.Machine) error {
|
||||
routes, err := hsdb.GetMachineRoutes(m)
|
||||
if err != nil {
|
||||
return err
|
||||
@ -184,14 +165,14 @@ func (hsdb *HSDatabase) DeleteMachineRoutes(m *Machine) error {
|
||||
}
|
||||
}
|
||||
|
||||
return hsdb.handlePrimarySubnetFailover()
|
||||
return hsdb.HandlePrimarySubnetFailover()
|
||||
}
|
||||
|
||||
// isUniquePrefix returns if there is another machine providing the same route already.
|
||||
func (hsdb *HSDatabase) isUniquePrefix(route Route) bool {
|
||||
func (hsdb *HSDatabase) isUniquePrefix(route types.Route) bool {
|
||||
var count int64
|
||||
hsdb.db.
|
||||
Model(&Route{}).
|
||||
Model(&types.Route{}).
|
||||
Where("prefix = ? AND machine_id != ? AND advertised = ? AND enabled = ?",
|
||||
route.Prefix,
|
||||
route.MachineID,
|
||||
@ -200,11 +181,11 @@ func (hsdb *HSDatabase) isUniquePrefix(route Route) bool {
|
||||
return count == 0
|
||||
}
|
||||
|
||||
func (hsdb *HSDatabase) getPrimaryRoute(prefix netip.Prefix) (*Route, error) {
|
||||
var route Route
|
||||
func (hsdb *HSDatabase) getPrimaryRoute(prefix netip.Prefix) (*types.Route, error) {
|
||||
var route types.Route
|
||||
err := hsdb.db.
|
||||
Preload("Machine").
|
||||
Where("prefix = ? AND advertised = ? AND enabled = ? AND is_primary = ?", IPPrefix(prefix), true, true, true).
|
||||
Where("prefix = ? AND advertised = ? AND enabled = ? AND is_primary = ?", types.IPPrefix(prefix), true, true, true).
|
||||
First(&route).Error
|
||||
if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, err
|
||||
@ -219,8 +200,8 @@ func (hsdb *HSDatabase) getPrimaryRoute(prefix netip.Prefix) (*Route, error) {
|
||||
|
||||
// getMachinePrimaryRoutes returns the routes that are enabled and marked as primary (for subnet failover)
|
||||
// Exit nodes are not considered for this, as they are never marked as Primary.
|
||||
func (hsdb *HSDatabase) getMachinePrimaryRoutes(m *Machine) ([]Route, error) {
|
||||
var routes []Route
|
||||
func (hsdb *HSDatabase) GetMachinePrimaryRoutes(m *types.Machine) (types.Routes, error) {
|
||||
var routes types.Routes
|
||||
err := hsdb.db.
|
||||
Preload("Machine").
|
||||
Where("machine_id = ? AND advertised = ? AND enabled = ? AND is_primary = ?", m.ID, true, true, true).
|
||||
@ -232,8 +213,8 @@ func (hsdb *HSDatabase) getMachinePrimaryRoutes(m *Machine) ([]Route, error) {
|
||||
return routes, nil
|
||||
}
|
||||
|
||||
func (hsdb *HSDatabase) processMachineRoutes(machine *Machine) error {
|
||||
currentRoutes := []Route{}
|
||||
func (hsdb *HSDatabase) ProcessMachineRoutes(machine *types.Machine) error {
|
||||
currentRoutes := types.Routes{}
|
||||
err := hsdb.db.Where("machine_id = ?", machine.ID).Find(¤tRoutes).Error
|
||||
if err != nil {
|
||||
return err
|
||||
@ -266,9 +247,9 @@ func (hsdb *HSDatabase) processMachineRoutes(machine *Machine) error {
|
||||
|
||||
for prefix, exists := range advertisedRoutes {
|
||||
if !exists {
|
||||
route := Route{
|
||||
route := types.Route{
|
||||
MachineID: machine.ID,
|
||||
Prefix: IPPrefix(prefix),
|
||||
Prefix: types.IPPrefix(prefix),
|
||||
Advertised: true,
|
||||
Enabled: false,
|
||||
}
|
||||
@ -282,9 +263,9 @@ func (hsdb *HSDatabase) processMachineRoutes(machine *Machine) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (hsdb *HSDatabase) handlePrimarySubnetFailover() error {
|
||||
func (hsdb *HSDatabase) HandlePrimarySubnetFailover() error {
|
||||
// first, get all the enabled routes
|
||||
var routes []Route
|
||||
var routes types.Routes
|
||||
err := hsdb.db.
|
||||
Preload("Machine").
|
||||
Where("advertised = ? AND enabled = ?", true, true).
|
||||
@ -295,7 +276,7 @@ func (hsdb *HSDatabase) handlePrimarySubnetFailover() error {
|
||||
|
||||
routesChanged := false
|
||||
for pos, route := range routes {
|
||||
if route.isExitRoute() {
|
||||
if route.IsExitRoute() {
|
||||
continue
|
||||
}
|
||||
|
||||
@ -321,7 +302,7 @@ func (hsdb *HSDatabase) handlePrimarySubnetFailover() error {
|
||||
}
|
||||
|
||||
if route.IsPrimary {
|
||||
if route.Machine.isOnline() {
|
||||
if route.Machine.IsOnline() {
|
||||
continue
|
||||
}
|
||||
|
||||
@ -332,7 +313,7 @@ func (hsdb *HSDatabase) handlePrimarySubnetFailover() error {
|
||||
Msgf("machine offline, finding a new primary subnet")
|
||||
|
||||
// find a new primary route
|
||||
var newPrimaryRoutes []Route
|
||||
var newPrimaryRoutes types.Routes
|
||||
err := hsdb.db.
|
||||
Preload("Machine").
|
||||
Where("prefix = ? AND machine_id != ? AND advertised = ? AND enabled = ?",
|
||||
@ -346,9 +327,9 @@ func (hsdb *HSDatabase) handlePrimarySubnetFailover() error {
|
||||
return err
|
||||
}
|
||||
|
||||
var newPrimaryRoute *Route
|
||||
var newPrimaryRoute *types.Route
|
||||
for pos, r := range newPrimaryRoutes {
|
||||
if r.Machine.isOnline() {
|
||||
if r.Machine.IsOnline() {
|
||||
newPrimaryRoute = &newPrimaryRoutes[pos]
|
||||
|
||||
break
|
||||
@ -399,27 +380,78 @@ func (hsdb *HSDatabase) handlePrimarySubnetFailover() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (rs Routes) toProto() []*v1.Route {
|
||||
protoRoutes := []*v1.Route{}
|
||||
|
||||
for _, route := range rs {
|
||||
protoRoute := v1.Route{
|
||||
Id: uint64(route.ID),
|
||||
Machine: route.Machine.toProto(),
|
||||
Prefix: netip.Prefix(route.Prefix).String(),
|
||||
Advertised: route.Advertised,
|
||||
Enabled: route.Enabled,
|
||||
IsPrimary: route.IsPrimary,
|
||||
CreatedAt: timestamppb.New(route.CreatedAt),
|
||||
UpdatedAt: timestamppb.New(route.UpdatedAt),
|
||||
}
|
||||
|
||||
if route.DeletedAt.Valid {
|
||||
protoRoute.DeletedAt = timestamppb.New(route.DeletedAt.Time)
|
||||
}
|
||||
|
||||
protoRoutes = append(protoRoutes, &protoRoute)
|
||||
// EnableAutoApprovedRoutes enables any routes advertised by a machine that match the ACL autoApprovers policy.
|
||||
func (hsdb *HSDatabase) EnableAutoApprovedRoutes(
|
||||
aclPolicy *policy.ACLPolicy,
|
||||
machine *types.Machine,
|
||||
) error {
|
||||
if len(machine.IPAddresses) == 0 {
|
||||
return nil // This machine has no IPAddresses, so can't possibly match any autoApprovers ACLs
|
||||
}
|
||||
|
||||
return protoRoutes
|
||||
routes, err := hsdb.GetMachineAdvertisedRoutes(machine)
|
||||
if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
log.Error().
|
||||
Caller().
|
||||
Err(err).
|
||||
Str("machine", machine.Hostname).
|
||||
Msg("Could not get advertised routes for machine")
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
approvedRoutes := types.Routes{}
|
||||
|
||||
for _, advertisedRoute := range routes {
|
||||
if advertisedRoute.Enabled {
|
||||
continue
|
||||
}
|
||||
|
||||
routeApprovers, err := aclPolicy.AutoApprovers.GetRouteApprovers(
|
||||
netip.Prefix(advertisedRoute.Prefix),
|
||||
)
|
||||
if err != nil {
|
||||
log.Err(err).
|
||||
Str("advertisedRoute", advertisedRoute.String()).
|
||||
Uint64("machineId", machine.ID).
|
||||
Msg("Failed to resolve autoApprovers for advertised route")
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
for _, approvedAlias := range routeApprovers {
|
||||
if approvedAlias == machine.User.Name {
|
||||
approvedRoutes = append(approvedRoutes, advertisedRoute)
|
||||
} else {
|
||||
// TODO(kradalby): figure out how to get this to depend on less stuff
|
||||
approvedIps, err := aclPolicy.ExpandAlias(types.Machines{*machine}, approvedAlias, hsdb.stripEmailDomain)
|
||||
if err != nil {
|
||||
log.Err(err).
|
||||
Str("alias", approvedAlias).
|
||||
Msg("Failed to expand alias when processing autoApprovers policy")
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
// approvedIPs should contain all of machine's IPs if it matches the rule, so check for first
|
||||
if approvedIps.Contains(machine.IPAddresses[0]) {
|
||||
approvedRoutes = append(approvedRoutes, advertisedRoute)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for _, approvedRoute := range approvedRoutes {
|
||||
err := hsdb.EnableRoute(uint64(approvedRoute.ID))
|
||||
if err != nil {
|
||||
log.Err(err).
|
||||
Str("approvedRoute", approvedRoute.String()).
|
||||
Uint64("machineId", machine.ID).
|
||||
Msg("Failed to enable approved route")
|
||||
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
@ -1,9 +1,11 @@
|
||||
package hscontrol
|
||||
package db
|
||||
|
||||
import (
|
||||
"net/netip"
|
||||
"time"
|
||||
|
||||
"github.com/juanfont/headscale/hscontrol/policy"
|
||||
"github.com/juanfont/headscale/hscontrol/types"
|
||||
"github.com/juanfont/headscale/hscontrol/util"
|
||||
"gopkg.in/check.v1"
|
||||
"tailscale.com/tailcfg"
|
||||
@ -11,13 +13,13 @@ import (
|
||||
)
|
||||
|
||||
func (s *Suite) TestGetRoutes(c *check.C) {
|
||||
user, err := app.db.CreateUser("test")
|
||||
user, err := db.CreateUser("test")
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
pak, err := app.db.CreatePreAuthKey(user.Name, false, false, nil, nil)
|
||||
pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
_, err = app.db.GetMachine("test", "test_get_route_machine")
|
||||
_, err = db.GetMachine("test", "test_get_route_machine")
|
||||
c.Assert(err, check.NotNil)
|
||||
|
||||
route, err := netip.ParsePrefix("10.0.0.0/24")
|
||||
@ -27,41 +29,43 @@ func (s *Suite) TestGetRoutes(c *check.C) {
|
||||
RoutableIPs: []netip.Prefix{route},
|
||||
}
|
||||
|
||||
machine := Machine{
|
||||
machine := types.Machine{
|
||||
ID: 0,
|
||||
MachineKey: "foo",
|
||||
NodeKey: "bar",
|
||||
DiscoKey: "faa",
|
||||
Hostname: "test_get_route_machine",
|
||||
UserID: user.ID,
|
||||
RegisterMethod: RegisterMethodAuthKey,
|
||||
RegisterMethod: util.RegisterMethodAuthKey,
|
||||
AuthKeyID: uint(pak.ID),
|
||||
HostInfo: HostInfo(hostInfo),
|
||||
HostInfo: types.HostInfo(hostInfo),
|
||||
}
|
||||
app.db.db.Save(&machine)
|
||||
db.db.Save(&machine)
|
||||
|
||||
err = app.db.processMachineRoutes(&machine)
|
||||
err = db.ProcessMachineRoutes(&machine)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
advertisedRoutes, err := app.db.GetAdvertisedRoutes(&machine)
|
||||
advertisedRoutes, err := db.GetAdvertisedRoutes(&machine)
|
||||
c.Assert(err, check.IsNil)
|
||||
c.Assert(len(advertisedRoutes), check.Equals, 1)
|
||||
|
||||
err = app.db.enableRoutes(&machine, "192.168.0.0/24")
|
||||
err = db.enableRoutes(&machine, "192.168.0.0/24")
|
||||
c.Assert(err, check.NotNil)
|
||||
|
||||
err = app.db.enableRoutes(&machine, "10.0.0.0/24")
|
||||
err = db.enableRoutes(&machine, "10.0.0.0/24")
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
c.Assert(channelUpdates, check.Equals, int32(0))
|
||||
}
|
||||
|
||||
func (s *Suite) TestGetEnableRoutes(c *check.C) {
|
||||
user, err := app.db.CreateUser("test")
|
||||
user, err := db.CreateUser("test")
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
pak, err := app.db.CreatePreAuthKey(user.Name, false, false, nil, nil)
|
||||
pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
_, err = app.db.GetMachine("test", "test_enable_route_machine")
|
||||
_, err = db.GetMachine("test", "test_enable_route_machine")
|
||||
c.Assert(err, check.NotNil)
|
||||
|
||||
route, err := netip.ParsePrefix(
|
||||
@ -78,65 +82,67 @@ func (s *Suite) TestGetEnableRoutes(c *check.C) {
|
||||
RoutableIPs: []netip.Prefix{route, route2},
|
||||
}
|
||||
|
||||
machine := Machine{
|
||||
machine := types.Machine{
|
||||
ID: 0,
|
||||
MachineKey: "foo",
|
||||
NodeKey: "bar",
|
||||
DiscoKey: "faa",
|
||||
Hostname: "test_enable_route_machine",
|
||||
UserID: user.ID,
|
||||
RegisterMethod: RegisterMethodAuthKey,
|
||||
RegisterMethod: util.RegisterMethodAuthKey,
|
||||
AuthKeyID: uint(pak.ID),
|
||||
HostInfo: HostInfo(hostInfo),
|
||||
HostInfo: types.HostInfo(hostInfo),
|
||||
}
|
||||
app.db.db.Save(&machine)
|
||||
db.db.Save(&machine)
|
||||
|
||||
err = app.db.processMachineRoutes(&machine)
|
||||
err = db.ProcessMachineRoutes(&machine)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
availableRoutes, err := app.db.GetAdvertisedRoutes(&machine)
|
||||
availableRoutes, err := db.GetAdvertisedRoutes(&machine)
|
||||
c.Assert(err, check.IsNil)
|
||||
c.Assert(err, check.IsNil)
|
||||
c.Assert(len(availableRoutes), check.Equals, 2)
|
||||
|
||||
noEnabledRoutes, err := app.db.GetEnabledRoutes(&machine)
|
||||
noEnabledRoutes, err := db.GetEnabledRoutes(&machine)
|
||||
c.Assert(err, check.IsNil)
|
||||
c.Assert(len(noEnabledRoutes), check.Equals, 0)
|
||||
|
||||
err = app.db.enableRoutes(&machine, "192.168.0.0/24")
|
||||
err = db.enableRoutes(&machine, "192.168.0.0/24")
|
||||
c.Assert(err, check.NotNil)
|
||||
|
||||
err = app.db.enableRoutes(&machine, "10.0.0.0/24")
|
||||
err = db.enableRoutes(&machine, "10.0.0.0/24")
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
enabledRoutes, err := app.db.GetEnabledRoutes(&machine)
|
||||
enabledRoutes, err := db.GetEnabledRoutes(&machine)
|
||||
c.Assert(err, check.IsNil)
|
||||
c.Assert(len(enabledRoutes), check.Equals, 1)
|
||||
|
||||
// Adding it twice will just let it pass through
|
||||
err = app.db.enableRoutes(&machine, "10.0.0.0/24")
|
||||
err = db.enableRoutes(&machine, "10.0.0.0/24")
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
enableRoutesAfterDoubleApply, err := app.db.GetEnabledRoutes(&machine)
|
||||
enableRoutesAfterDoubleApply, err := db.GetEnabledRoutes(&machine)
|
||||
c.Assert(err, check.IsNil)
|
||||
c.Assert(len(enableRoutesAfterDoubleApply), check.Equals, 1)
|
||||
|
||||
err = app.db.enableRoutes(&machine, "150.0.10.0/25")
|
||||
err = db.enableRoutes(&machine, "150.0.10.0/25")
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
enabledRoutesWithAdditionalRoute, err := app.db.GetEnabledRoutes(&machine)
|
||||
enabledRoutesWithAdditionalRoute, err := db.GetEnabledRoutes(&machine)
|
||||
c.Assert(err, check.IsNil)
|
||||
c.Assert(len(enabledRoutesWithAdditionalRoute), check.Equals, 2)
|
||||
|
||||
c.Assert(channelUpdates, check.Equals, int32(3))
|
||||
}
|
||||
|
||||
func (s *Suite) TestIsUniquePrefix(c *check.C) {
|
||||
user, err := app.db.CreateUser("test")
|
||||
user, err := db.CreateUser("test")
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
pak, err := app.db.CreatePreAuthKey(user.Name, false, false, nil, nil)
|
||||
pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
_, err = app.db.GetMachine("test", "test_enable_route_machine")
|
||||
_, err = db.GetMachine("test", "test_enable_route_machine")
|
||||
c.Assert(err, check.NotNil)
|
||||
|
||||
route, err := netip.ParsePrefix(
|
||||
@ -152,75 +158,77 @@ func (s *Suite) TestIsUniquePrefix(c *check.C) {
|
||||
hostInfo1 := tailcfg.Hostinfo{
|
||||
RoutableIPs: []netip.Prefix{route, route2},
|
||||
}
|
||||
machine1 := Machine{
|
||||
machine1 := types.Machine{
|
||||
ID: 1,
|
||||
MachineKey: "foo",
|
||||
NodeKey: "bar",
|
||||
DiscoKey: "faa",
|
||||
Hostname: "test_enable_route_machine",
|
||||
UserID: user.ID,
|
||||
RegisterMethod: RegisterMethodAuthKey,
|
||||
RegisterMethod: util.RegisterMethodAuthKey,
|
||||
AuthKeyID: uint(pak.ID),
|
||||
HostInfo: HostInfo(hostInfo1),
|
||||
HostInfo: types.HostInfo(hostInfo1),
|
||||
}
|
||||
app.db.db.Save(&machine1)
|
||||
db.db.Save(&machine1)
|
||||
|
||||
err = app.db.processMachineRoutes(&machine1)
|
||||
err = db.ProcessMachineRoutes(&machine1)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
err = app.db.enableRoutes(&machine1, route.String())
|
||||
err = db.enableRoutes(&machine1, route.String())
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
err = app.db.enableRoutes(&machine1, route2.String())
|
||||
err = db.enableRoutes(&machine1, route2.String())
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
hostInfo2 := tailcfg.Hostinfo{
|
||||
RoutableIPs: []netip.Prefix{route2},
|
||||
}
|
||||
machine2 := Machine{
|
||||
machine2 := types.Machine{
|
||||
ID: 2,
|
||||
MachineKey: "foo",
|
||||
NodeKey: "bar",
|
||||
DiscoKey: "faa",
|
||||
Hostname: "test_enable_route_machine",
|
||||
UserID: user.ID,
|
||||
RegisterMethod: RegisterMethodAuthKey,
|
||||
RegisterMethod: util.RegisterMethodAuthKey,
|
||||
AuthKeyID: uint(pak.ID),
|
||||
HostInfo: HostInfo(hostInfo2),
|
||||
HostInfo: types.HostInfo(hostInfo2),
|
||||
}
|
||||
app.db.db.Save(&machine2)
|
||||
db.db.Save(&machine2)
|
||||
|
||||
err = app.db.processMachineRoutes(&machine2)
|
||||
err = db.ProcessMachineRoutes(&machine2)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
err = app.db.enableRoutes(&machine2, route2.String())
|
||||
err = db.enableRoutes(&machine2, route2.String())
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
enabledRoutes1, err := app.db.GetEnabledRoutes(&machine1)
|
||||
enabledRoutes1, err := db.GetEnabledRoutes(&machine1)
|
||||
c.Assert(err, check.IsNil)
|
||||
c.Assert(len(enabledRoutes1), check.Equals, 2)
|
||||
|
||||
enabledRoutes2, err := app.db.GetEnabledRoutes(&machine2)
|
||||
enabledRoutes2, err := db.GetEnabledRoutes(&machine2)
|
||||
c.Assert(err, check.IsNil)
|
||||
c.Assert(len(enabledRoutes2), check.Equals, 1)
|
||||
|
||||
routes, err := app.db.getMachinePrimaryRoutes(&machine1)
|
||||
routes, err := db.GetMachinePrimaryRoutes(&machine1)
|
||||
c.Assert(err, check.IsNil)
|
||||
c.Assert(len(routes), check.Equals, 2)
|
||||
|
||||
routes, err = app.db.getMachinePrimaryRoutes(&machine2)
|
||||
routes, err = db.GetMachinePrimaryRoutes(&machine2)
|
||||
c.Assert(err, check.IsNil)
|
||||
c.Assert(len(routes), check.Equals, 0)
|
||||
|
||||
c.Assert(channelUpdates, check.Equals, int32(3))
|
||||
}
|
||||
|
||||
func (s *Suite) TestSubnetFailover(c *check.C) {
|
||||
user, err := app.db.CreateUser("test")
|
||||
user, err := db.CreateUser("test")
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
pak, err := app.db.CreatePreAuthKey(user.Name, false, false, nil, nil)
|
||||
pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
_, err = app.db.GetMachine("test", "test_enable_route_machine")
|
||||
_, err = db.GetMachine("test", "test_enable_route_machine")
|
||||
c.Assert(err, check.NotNil)
|
||||
|
||||
prefix, err := netip.ParsePrefix(
|
||||
@ -238,134 +246,136 @@ func (s *Suite) TestSubnetFailover(c *check.C) {
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
machine1 := Machine{
|
||||
machine1 := types.Machine{
|
||||
ID: 1,
|
||||
MachineKey: "foo",
|
||||
NodeKey: "bar",
|
||||
DiscoKey: "faa",
|
||||
Hostname: "test_enable_route_machine",
|
||||
UserID: user.ID,
|
||||
RegisterMethod: RegisterMethodAuthKey,
|
||||
RegisterMethod: util.RegisterMethodAuthKey,
|
||||
AuthKeyID: uint(pak.ID),
|
||||
HostInfo: HostInfo(hostInfo1),
|
||||
HostInfo: types.HostInfo(hostInfo1),
|
||||
LastSeen: &now,
|
||||
}
|
||||
app.db.db.Save(&machine1)
|
||||
db.db.Save(&machine1)
|
||||
|
||||
err = app.db.processMachineRoutes(&machine1)
|
||||
err = db.ProcessMachineRoutes(&machine1)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
err = app.db.enableRoutes(&machine1, prefix.String())
|
||||
err = db.enableRoutes(&machine1, prefix.String())
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
err = app.db.enableRoutes(&machine1, prefix2.String())
|
||||
err = db.enableRoutes(&machine1, prefix2.String())
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
err = app.db.handlePrimarySubnetFailover()
|
||||
err = db.HandlePrimarySubnetFailover()
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
enabledRoutes1, err := app.db.GetEnabledRoutes(&machine1)
|
||||
enabledRoutes1, err := db.GetEnabledRoutes(&machine1)
|
||||
c.Assert(err, check.IsNil)
|
||||
c.Assert(len(enabledRoutes1), check.Equals, 2)
|
||||
|
||||
route, err := app.db.getPrimaryRoute(prefix)
|
||||
route, err := db.getPrimaryRoute(prefix)
|
||||
c.Assert(err, check.IsNil)
|
||||
c.Assert(route.MachineID, check.Equals, machine1.ID)
|
||||
|
||||
hostInfo2 := tailcfg.Hostinfo{
|
||||
RoutableIPs: []netip.Prefix{prefix2},
|
||||
}
|
||||
machine2 := Machine{
|
||||
machine2 := types.Machine{
|
||||
ID: 2,
|
||||
MachineKey: "foo",
|
||||
NodeKey: "bar",
|
||||
DiscoKey: "faa",
|
||||
Hostname: "test_enable_route_machine",
|
||||
UserID: user.ID,
|
||||
RegisterMethod: RegisterMethodAuthKey,
|
||||
RegisterMethod: util.RegisterMethodAuthKey,
|
||||
AuthKeyID: uint(pak.ID),
|
||||
HostInfo: HostInfo(hostInfo2),
|
||||
HostInfo: types.HostInfo(hostInfo2),
|
||||
LastSeen: &now,
|
||||
}
|
||||
app.db.db.Save(&machine2)
|
||||
db.db.Save(&machine2)
|
||||
|
||||
err = app.db.processMachineRoutes(&machine2)
|
||||
err = db.ProcessMachineRoutes(&machine2)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
err = app.db.enableRoutes(&machine2, prefix2.String())
|
||||
err = db.enableRoutes(&machine2, prefix2.String())
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
err = app.db.handlePrimarySubnetFailover()
|
||||
err = db.HandlePrimarySubnetFailover()
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
enabledRoutes1, err = app.db.GetEnabledRoutes(&machine1)
|
||||
enabledRoutes1, err = db.GetEnabledRoutes(&machine1)
|
||||
c.Assert(err, check.IsNil)
|
||||
c.Assert(len(enabledRoutes1), check.Equals, 2)
|
||||
|
||||
enabledRoutes2, err := app.db.GetEnabledRoutes(&machine2)
|
||||
enabledRoutes2, err := db.GetEnabledRoutes(&machine2)
|
||||
c.Assert(err, check.IsNil)
|
||||
c.Assert(len(enabledRoutes2), check.Equals, 1)
|
||||
|
||||
routes, err := app.db.getMachinePrimaryRoutes(&machine1)
|
||||
routes, err := db.GetMachinePrimaryRoutes(&machine1)
|
||||
c.Assert(err, check.IsNil)
|
||||
c.Assert(len(routes), check.Equals, 2)
|
||||
|
||||
routes, err = app.db.getMachinePrimaryRoutes(&machine2)
|
||||
routes, err = db.GetMachinePrimaryRoutes(&machine2)
|
||||
c.Assert(err, check.IsNil)
|
||||
c.Assert(len(routes), check.Equals, 0)
|
||||
|
||||
// lets make machine1 lastseen 10 mins ago
|
||||
before := now.Add(-10 * time.Minute)
|
||||
machine1.LastSeen = &before
|
||||
err = app.db.db.Save(&machine1).Error
|
||||
err = db.db.Save(&machine1).Error
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
err = app.db.handlePrimarySubnetFailover()
|
||||
err = db.HandlePrimarySubnetFailover()
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
routes, err = app.db.getMachinePrimaryRoutes(&machine1)
|
||||
routes, err = db.GetMachinePrimaryRoutes(&machine1)
|
||||
c.Assert(err, check.IsNil)
|
||||
c.Assert(len(routes), check.Equals, 1)
|
||||
|
||||
routes, err = app.db.getMachinePrimaryRoutes(&machine2)
|
||||
routes, err = db.GetMachinePrimaryRoutes(&machine2)
|
||||
c.Assert(err, check.IsNil)
|
||||
c.Assert(len(routes), check.Equals, 1)
|
||||
|
||||
machine2.HostInfo = HostInfo(tailcfg.Hostinfo{
|
||||
machine2.HostInfo = types.HostInfo(tailcfg.Hostinfo{
|
||||
RoutableIPs: []netip.Prefix{prefix, prefix2},
|
||||
})
|
||||
err = app.db.db.Save(&machine2).Error
|
||||
err = db.db.Save(&machine2).Error
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
err = app.db.processMachineRoutes(&machine2)
|
||||
err = db.ProcessMachineRoutes(&machine2)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
err = app.db.enableRoutes(&machine2, prefix.String())
|
||||
err = db.enableRoutes(&machine2, prefix.String())
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
err = app.db.handlePrimarySubnetFailover()
|
||||
err = db.HandlePrimarySubnetFailover()
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
routes, err = app.db.getMachinePrimaryRoutes(&machine1)
|
||||
routes, err = db.GetMachinePrimaryRoutes(&machine1)
|
||||
c.Assert(err, check.IsNil)
|
||||
c.Assert(len(routes), check.Equals, 0)
|
||||
|
||||
routes, err = app.db.getMachinePrimaryRoutes(&machine2)
|
||||
routes, err = db.GetMachinePrimaryRoutes(&machine2)
|
||||
c.Assert(err, check.IsNil)
|
||||
c.Assert(len(routes), check.Equals, 2)
|
||||
|
||||
c.Assert(channelUpdates, check.Equals, int32(6))
|
||||
}
|
||||
|
||||
// TestAllowedIPRoutes tests that the AllowedIPs are correctly set for a node,
|
||||
// including both the primary routes the node is responsible for, and the
|
||||
// exit node routes if enabled.
|
||||
func (s *Suite) TestAllowedIPRoutes(c *check.C) {
|
||||
user, err := app.db.CreateUser("test")
|
||||
user, err := db.CreateUser("test")
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
pak, err := app.db.CreatePreAuthKey(user.Name, false, false, nil, nil)
|
||||
pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
_, err = app.db.GetMachine("test", "test_enable_route_machine")
|
||||
_, err = db.GetMachine("test", "test_enable_route_machine")
|
||||
c.Assert(err, check.NotNil)
|
||||
|
||||
prefix, err := netip.ParsePrefix(
|
||||
@ -397,35 +407,35 @@ func (s *Suite) TestAllowedIPRoutes(c *check.C) {
|
||||
machineKey := key.NewMachine()
|
||||
|
||||
now := time.Now()
|
||||
machine1 := Machine{
|
||||
machine1 := types.Machine{
|
||||
ID: 1,
|
||||
MachineKey: util.MachinePublicKeyStripPrefix(machineKey.Public()),
|
||||
NodeKey: util.NodePublicKeyStripPrefix(nodeKey.Public()),
|
||||
DiscoKey: util.DiscoPublicKeyStripPrefix(discoKey.Public()),
|
||||
Hostname: "test_enable_route_machine",
|
||||
UserID: user.ID,
|
||||
RegisterMethod: RegisterMethodAuthKey,
|
||||
RegisterMethod: util.RegisterMethodAuthKey,
|
||||
AuthKeyID: uint(pak.ID),
|
||||
HostInfo: HostInfo(hostInfo1),
|
||||
HostInfo: types.HostInfo(hostInfo1),
|
||||
LastSeen: &now,
|
||||
}
|
||||
app.db.db.Save(&machine1)
|
||||
db.db.Save(&machine1)
|
||||
|
||||
err = app.db.processMachineRoutes(&machine1)
|
||||
err = db.ProcessMachineRoutes(&machine1)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
err = app.db.enableRoutes(&machine1, prefix.String())
|
||||
err = db.enableRoutes(&machine1, prefix.String())
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
// We do not enable this one on purpose to test that it is not enabled
|
||||
// err = app.db.enableRoutes(&machine1, prefix2.String())
|
||||
// err = db.enableRoutes(&machine1, prefix2.String())
|
||||
// c.Assert(err, check.IsNil)
|
||||
|
||||
routes, err := app.db.GetMachineRoutes(&machine1)
|
||||
routes, err := db.GetMachineRoutes(&machine1)
|
||||
c.Assert(err, check.IsNil)
|
||||
for _, route := range routes {
|
||||
if route.isExitRoute() {
|
||||
err = app.db.EnableRoute(uint64(route.ID))
|
||||
if route.IsExitRoute() {
|
||||
err = db.EnableRoute(uint64(route.ID))
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
// We only enable one exit route, so we can test that both are enabled
|
||||
@ -433,14 +443,14 @@ func (s *Suite) TestAllowedIPRoutes(c *check.C) {
|
||||
}
|
||||
}
|
||||
|
||||
err = app.db.handlePrimarySubnetFailover()
|
||||
err = db.HandlePrimarySubnetFailover()
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
enabledRoutes1, err := app.db.GetEnabledRoutes(&machine1)
|
||||
enabledRoutes1, err := db.GetEnabledRoutes(&machine1)
|
||||
c.Assert(err, check.IsNil)
|
||||
c.Assert(len(enabledRoutes1), check.Equals, 3)
|
||||
|
||||
peer, err := app.db.toNode(machine1, app.aclPolicy, "headscale.net", nil)
|
||||
peer, err := db.TailNode(machine1, &policy.ACLPolicy{}, nil)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
c.Assert(len(peer.AllowedIPs), check.Equals, 3)
|
||||
@ -461,44 +471,46 @@ func (s *Suite) TestAllowedIPRoutes(c *check.C) {
|
||||
|
||||
// Now we disable only one of the exit routes
|
||||
// and we see if both are disabled
|
||||
var exitRouteV4 Route
|
||||
var exitRouteV4 types.Route
|
||||
for _, route := range routes {
|
||||
if route.isExitRoute() && netip.Prefix(route.Prefix) == prefixExitNodeV4 {
|
||||
if route.IsExitRoute() && netip.Prefix(route.Prefix) == prefixExitNodeV4 {
|
||||
exitRouteV4 = route
|
||||
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
err = app.db.DisableRoute(uint64(exitRouteV4.ID))
|
||||
err = db.DisableRoute(uint64(exitRouteV4.ID))
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
enabledRoutes1, err = app.db.GetEnabledRoutes(&machine1)
|
||||
enabledRoutes1, err = db.GetEnabledRoutes(&machine1)
|
||||
c.Assert(err, check.IsNil)
|
||||
c.Assert(len(enabledRoutes1), check.Equals, 1)
|
||||
|
||||
// and now we delete only one of the exit routes
|
||||
// and we check if both are deleted
|
||||
routes, err = app.db.GetMachineRoutes(&machine1)
|
||||
routes, err = db.GetMachineRoutes(&machine1)
|
||||
c.Assert(err, check.IsNil)
|
||||
c.Assert(len(routes), check.Equals, 4)
|
||||
|
||||
err = app.db.DeleteRoute(uint64(exitRouteV4.ID))
|
||||
err = db.DeleteRoute(uint64(exitRouteV4.ID))
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
routes, err = app.db.GetMachineRoutes(&machine1)
|
||||
routes, err = db.GetMachineRoutes(&machine1)
|
||||
c.Assert(err, check.IsNil)
|
||||
c.Assert(len(routes), check.Equals, 2)
|
||||
|
||||
c.Assert(channelUpdates, check.Equals, int32(2))
|
||||
}
|
||||
|
||||
func (s *Suite) TestDeleteRoutes(c *check.C) {
|
||||
user, err := app.db.CreateUser("test")
|
||||
user, err := db.CreateUser("test")
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
pak, err := app.db.CreatePreAuthKey(user.Name, false, false, nil, nil)
|
||||
pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
_, err = app.db.GetMachine("test", "test_enable_route_machine")
|
||||
_, err = db.GetMachine("test", "test_enable_route_machine")
|
||||
c.Assert(err, check.NotNil)
|
||||
|
||||
prefix, err := netip.ParsePrefix(
|
||||
@ -516,36 +528,38 @@ func (s *Suite) TestDeleteRoutes(c *check.C) {
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
machine1 := Machine{
|
||||
machine1 := types.Machine{
|
||||
ID: 1,
|
||||
MachineKey: "foo",
|
||||
NodeKey: "bar",
|
||||
DiscoKey: "faa",
|
||||
Hostname: "test_enable_route_machine",
|
||||
UserID: user.ID,
|
||||
RegisterMethod: RegisterMethodAuthKey,
|
||||
RegisterMethod: util.RegisterMethodAuthKey,
|
||||
AuthKeyID: uint(pak.ID),
|
||||
HostInfo: HostInfo(hostInfo1),
|
||||
HostInfo: types.HostInfo(hostInfo1),
|
||||
LastSeen: &now,
|
||||
}
|
||||
app.db.db.Save(&machine1)
|
||||
db.db.Save(&machine1)
|
||||
|
||||
err = app.db.processMachineRoutes(&machine1)
|
||||
err = db.ProcessMachineRoutes(&machine1)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
err = app.db.enableRoutes(&machine1, prefix.String())
|
||||
err = db.enableRoutes(&machine1, prefix.String())
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
err = app.db.enableRoutes(&machine1, prefix2.String())
|
||||
err = db.enableRoutes(&machine1, prefix2.String())
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
routes, err := app.db.GetMachineRoutes(&machine1)
|
||||
routes, err := db.GetMachineRoutes(&machine1)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
err = app.db.DeleteRoute(uint64(routes[0].ID))
|
||||
err = db.DeleteRoute(uint64(routes[0].ID))
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
enabledRoutes1, err := app.db.GetEnabledRoutes(&machine1)
|
||||
enabledRoutes1, err := db.GetEnabledRoutes(&machine1)
|
||||
c.Assert(err, check.IsNil)
|
||||
c.Assert(len(enabledRoutes1), check.Equals, 1)
|
||||
|
||||
c.Assert(channelUpdates, check.Equals, int32(2))
|
||||
}
|
74
hscontrol/db/suite_test.go
Normal file
74
hscontrol/db/suite_test.go
Normal file
@ -0,0 +1,74 @@
|
||||
package db
|
||||
|
||||
import (
|
||||
"net/netip"
|
||||
"os"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
|
||||
"gopkg.in/check.v1"
|
||||
)
|
||||
|
||||
func Test(t *testing.T) {
|
||||
check.TestingT(t)
|
||||
}
|
||||
|
||||
var _ = check.Suite(&Suite{})
|
||||
|
||||
type Suite struct{}
|
||||
|
||||
var (
|
||||
tmpDir string
|
||||
db *HSDatabase
|
||||
|
||||
// channelUpdates counts the number of times
|
||||
// either of the channels was notified.
|
||||
channelUpdates int32
|
||||
)
|
||||
|
||||
func (s *Suite) SetUpTest(c *check.C) {
|
||||
atomic.StoreInt32(&channelUpdates, 0)
|
||||
s.ResetDB(c)
|
||||
}
|
||||
|
||||
func (s *Suite) TearDownTest(c *check.C) {
|
||||
os.RemoveAll(tmpDir)
|
||||
}
|
||||
|
||||
func notificationSink(c <-chan struct{}) {
|
||||
for {
|
||||
<-c
|
||||
atomic.AddInt32(&channelUpdates, 1)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Suite) ResetDB(c *check.C) {
|
||||
if len(tmpDir) != 0 {
|
||||
os.RemoveAll(tmpDir)
|
||||
}
|
||||
var err error
|
||||
tmpDir, err = os.MkdirTemp("", "autoygg-client-test")
|
||||
if err != nil {
|
||||
c.Fatal(err)
|
||||
}
|
||||
|
||||
sink := make(chan struct{})
|
||||
|
||||
go notificationSink(sink)
|
||||
|
||||
db, err = NewHeadscaleDatabase(
|
||||
"sqlite3",
|
||||
tmpDir+"/headscale_test.db",
|
||||
false,
|
||||
false,
|
||||
sink,
|
||||
sink,
|
||||
[]netip.Prefix{
|
||||
netip.MustParsePrefix("10.27.0.0/23"),
|
||||
},
|
||||
"",
|
||||
)
|
||||
if err != nil {
|
||||
c.Fatal(err)
|
||||
}
|
||||
}
|
@ -1,17 +1,12 @@
|
||||
package hscontrol
|
||||
package db
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
|
||||
"github.com/juanfont/headscale/hscontrol/types"
|
||||
"github.com/juanfont/headscale/hscontrol/util"
|
||||
"github.com/rs/zerolog/log"
|
||||
"google.golang.org/protobuf/types/known/timestamppb"
|
||||
"gorm.io/gorm"
|
||||
"tailscale.com/tailcfg"
|
||||
)
|
||||
@ -20,33 +15,16 @@ var (
|
||||
ErrUserExists = errors.New("user already exists")
|
||||
ErrUserNotFound = errors.New("user not found")
|
||||
ErrUserStillHasNodes = errors.New("user not empty: node(s) found")
|
||||
ErrInvalidUserName = errors.New("invalid user name")
|
||||
)
|
||||
|
||||
const (
|
||||
// value related to RFC 1123 and 952.
|
||||
labelHostnameLength = 63
|
||||
)
|
||||
|
||||
var invalidCharsInUserRegex = regexp.MustCompile("[^a-z0-9-.]+")
|
||||
|
||||
// User is the way Headscale implements the concept of users in Tailscale
|
||||
//
|
||||
// At the end of the day, users in Tailscale are some kind of 'bubbles' or users
|
||||
// that contain our machines.
|
||||
type User struct {
|
||||
gorm.Model
|
||||
Name string `gorm:"unique"`
|
||||
}
|
||||
|
||||
// CreateUser creates a new User. Returns error if could not be created
|
||||
// or another user already exists.
|
||||
func (hsdb *HSDatabase) CreateUser(name string) (*User, error) {
|
||||
err := CheckForFQDNRules(name)
|
||||
func (hsdb *HSDatabase) CreateUser(name string) (*types.User, error) {
|
||||
err := util.CheckForFQDNRules(name)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
user := User{}
|
||||
user := types.User{}
|
||||
if err := hsdb.db.Where("name = ?", name).First(&user).Error; err == nil {
|
||||
return nil, ErrUserExists
|
||||
}
|
||||
@ -105,7 +83,7 @@ func (hsdb *HSDatabase) RenameUser(oldName, newName string) error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = CheckForFQDNRules(newName)
|
||||
err = util.CheckForFQDNRules(newName)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@ -127,8 +105,8 @@ func (hsdb *HSDatabase) RenameUser(oldName, newName string) error {
|
||||
}
|
||||
|
||||
// GetUser fetches a user by name.
|
||||
func (hsdb *HSDatabase) GetUser(name string) (*User, error) {
|
||||
user := User{}
|
||||
func (hsdb *HSDatabase) GetUser(name string) (*types.User, error) {
|
||||
user := types.User{}
|
||||
if result := hsdb.db.First(&user, "name = ?", name); errors.Is(
|
||||
result.Error,
|
||||
gorm.ErrRecordNotFound,
|
||||
@ -140,8 +118,8 @@ func (hsdb *HSDatabase) GetUser(name string) (*User, error) {
|
||||
}
|
||||
|
||||
// ListUsers gets all the existing users.
|
||||
func (hsdb *HSDatabase) ListUsers() ([]User, error) {
|
||||
users := []User{}
|
||||
func (hsdb *HSDatabase) ListUsers() ([]types.User, error) {
|
||||
users := []types.User{}
|
||||
if err := hsdb.db.Find(&users).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -150,8 +128,8 @@ func (hsdb *HSDatabase) ListUsers() ([]User, error) {
|
||||
}
|
||||
|
||||
// ListMachinesByUser gets all the nodes in a given user.
|
||||
func (hsdb *HSDatabase) ListMachinesByUser(name string) ([]Machine, error) {
|
||||
err := CheckForFQDNRules(name)
|
||||
func (hsdb *HSDatabase) ListMachinesByUser(name string) (types.Machines, error) {
|
||||
err := util.CheckForFQDNRules(name)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -160,8 +138,8 @@ func (hsdb *HSDatabase) ListMachinesByUser(name string) ([]Machine, error) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
machines := []Machine{}
|
||||
if err := hsdb.db.Preload("AuthKey").Preload("AuthKey.User").Preload("User").Where(&Machine{UserID: user.ID}).Find(&machines).Error; err != nil {
|
||||
machines := types.Machines{}
|
||||
if err := hsdb.db.Preload("AuthKey").Preload("AuthKey.User").Preload("User").Where(&types.Machine{UserID: user.ID}).Find(&machines).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@ -169,8 +147,8 @@ func (hsdb *HSDatabase) ListMachinesByUser(name string) ([]Machine, error) {
|
||||
}
|
||||
|
||||
// SetMachineUser assigns a Machine to a user.
|
||||
func (hsdb *HSDatabase) SetMachineUser(machine *Machine, username string) error {
|
||||
err := CheckForFQDNRules(username)
|
||||
func (hsdb *HSDatabase) SetMachineUser(machine *types.Machine, username string) error {
|
||||
err := util.CheckForFQDNRules(username)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@ -186,37 +164,11 @@ func (hsdb *HSDatabase) SetMachineUser(machine *Machine, username string) error
|
||||
return nil
|
||||
}
|
||||
|
||||
func (n *User) toTailscaleUser() *tailcfg.User {
|
||||
user := tailcfg.User{
|
||||
ID: tailcfg.UserID(n.ID),
|
||||
LoginName: n.Name,
|
||||
DisplayName: n.Name,
|
||||
ProfilePicURL: "",
|
||||
Domain: "headscale.net",
|
||||
Logins: []tailcfg.LoginID{},
|
||||
Created: time.Time{},
|
||||
}
|
||||
|
||||
return &user
|
||||
}
|
||||
|
||||
func (n *User) toTailscaleLogin() *tailcfg.Login {
|
||||
login := tailcfg.Login{
|
||||
ID: tailcfg.LoginID(n.ID),
|
||||
LoginName: n.Name,
|
||||
DisplayName: n.Name,
|
||||
ProfilePicURL: "",
|
||||
Domain: "headscale.net",
|
||||
}
|
||||
|
||||
return &login
|
||||
}
|
||||
|
||||
func (hsdb *HSDatabase) getMapResponseUserProfiles(
|
||||
machine Machine,
|
||||
peers Machines,
|
||||
func (hsdb *HSDatabase) GetMapResponseUserProfiles(
|
||||
machine types.Machine,
|
||||
peers types.Machines,
|
||||
) []tailcfg.UserProfile {
|
||||
userMap := make(map[string]User)
|
||||
userMap := make(map[string]types.User)
|
||||
userMap[machine.User.Name] = machine.User
|
||||
for _, peer := range peers {
|
||||
userMap[peer.User.Name] = peer.User // not worth checking if already is there
|
||||
@ -240,63 +192,3 @@ func (hsdb *HSDatabase) getMapResponseUserProfiles(
|
||||
|
||||
return profiles
|
||||
}
|
||||
|
||||
func (n *User) toProto() *v1.User {
|
||||
return &v1.User{
|
||||
Id: strconv.FormatUint(uint64(n.ID), util.Base10),
|
||||
Name: n.Name,
|
||||
CreatedAt: timestamppb.New(n.CreatedAt),
|
||||
}
|
||||
}
|
||||
|
||||
// NormalizeToFQDNRules will replace forbidden chars in user
|
||||
// it can also return an error if the user doesn't respect RFC 952 and 1123.
|
||||
func NormalizeToFQDNRules(name string, stripEmailDomain bool) (string, error) {
|
||||
name = strings.ToLower(name)
|
||||
name = strings.ReplaceAll(name, "'", "")
|
||||
atIdx := strings.Index(name, "@")
|
||||
if stripEmailDomain && atIdx > 0 {
|
||||
name = name[:atIdx]
|
||||
} else {
|
||||
name = strings.ReplaceAll(name, "@", ".")
|
||||
}
|
||||
name = invalidCharsInUserRegex.ReplaceAllString(name, "-")
|
||||
|
||||
for _, elt := range strings.Split(name, ".") {
|
||||
if len(elt) > labelHostnameLength {
|
||||
return "", fmt.Errorf(
|
||||
"label %v is more than 63 chars: %w",
|
||||
elt,
|
||||
ErrInvalidUserName,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
return name, nil
|
||||
}
|
||||
|
||||
func CheckForFQDNRules(name string) error {
|
||||
if len(name) > labelHostnameLength {
|
||||
return fmt.Errorf(
|
||||
"DNS segment must not be over 63 chars. %v doesn't comply with this rule: %w",
|
||||
name,
|
||||
ErrInvalidUserName,
|
||||
)
|
||||
}
|
||||
if strings.ToLower(name) != name {
|
||||
return fmt.Errorf(
|
||||
"DNS segment should be lowercase. %v doesn't comply with this rule: %w",
|
||||
name,
|
||||
ErrInvalidUserName,
|
||||
)
|
||||
}
|
||||
if invalidCharsInUserRegex.MatchString(name) {
|
||||
return fmt.Errorf(
|
||||
"DNS segment should only be composed of lowercase ASCII letters numbers, hyphen and dots. %v doesn't comply with theses rules: %w",
|
||||
name,
|
||||
ErrInvalidUserName,
|
||||
)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
277
hscontrol/db/users_test.go
Normal file
277
hscontrol/db/users_test.go
Normal file
@ -0,0 +1,277 @@
|
||||
package db
|
||||
|
||||
import (
|
||||
"net/netip"
|
||||
|
||||
"github.com/juanfont/headscale/hscontrol/types"
|
||||
"github.com/juanfont/headscale/hscontrol/util"
|
||||
"gopkg.in/check.v1"
|
||||
"gorm.io/gorm"
|
||||
"tailscale.com/tailcfg"
|
||||
)
|
||||
|
||||
func (s *Suite) TestCreateAndDestroyUser(c *check.C) {
|
||||
user, err := db.CreateUser("test")
|
||||
c.Assert(err, check.IsNil)
|
||||
c.Assert(user.Name, check.Equals, "test")
|
||||
|
||||
users, err := db.ListUsers()
|
||||
c.Assert(err, check.IsNil)
|
||||
c.Assert(len(users), check.Equals, 1)
|
||||
|
||||
err = db.DestroyUser("test")
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
_, err = db.GetUser("test")
|
||||
c.Assert(err, check.NotNil)
|
||||
}
|
||||
|
||||
func (s *Suite) TestDestroyUserErrors(c *check.C) {
|
||||
err := db.DestroyUser("test")
|
||||
c.Assert(err, check.Equals, ErrUserNotFound)
|
||||
|
||||
user, err := db.CreateUser("test")
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
err = db.DestroyUser("test")
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
result := db.db.Preload("User").First(&pak, "key = ?", pak.Key)
|
||||
// destroying a user also deletes all associated preauthkeys
|
||||
c.Assert(result.Error, check.Equals, gorm.ErrRecordNotFound)
|
||||
|
||||
user, err = db.CreateUser("test")
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
pak, err = db.CreatePreAuthKey(user.Name, false, false, nil, nil)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
machine := types.Machine{
|
||||
ID: 0,
|
||||
MachineKey: "foo",
|
||||
NodeKey: "bar",
|
||||
DiscoKey: "faa",
|
||||
Hostname: "testmachine",
|
||||
UserID: user.ID,
|
||||
RegisterMethod: util.RegisterMethodAuthKey,
|
||||
AuthKeyID: uint(pak.ID),
|
||||
}
|
||||
db.db.Save(&machine)
|
||||
|
||||
err = db.DestroyUser("test")
|
||||
c.Assert(err, check.Equals, ErrUserStillHasNodes)
|
||||
}
|
||||
|
||||
func (s *Suite) TestRenameUser(c *check.C) {
|
||||
userTest, err := db.CreateUser("test")
|
||||
c.Assert(err, check.IsNil)
|
||||
c.Assert(userTest.Name, check.Equals, "test")
|
||||
|
||||
users, err := db.ListUsers()
|
||||
c.Assert(err, check.IsNil)
|
||||
c.Assert(len(users), check.Equals, 1)
|
||||
|
||||
err = db.RenameUser("test", "test-renamed")
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
_, err = db.GetUser("test")
|
||||
c.Assert(err, check.Equals, ErrUserNotFound)
|
||||
|
||||
_, err = db.GetUser("test-renamed")
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
err = db.RenameUser("test-does-not-exit", "test")
|
||||
c.Assert(err, check.Equals, ErrUserNotFound)
|
||||
|
||||
userTest2, err := db.CreateUser("test2")
|
||||
c.Assert(err, check.IsNil)
|
||||
c.Assert(userTest2.Name, check.Equals, "test2")
|
||||
|
||||
err = db.RenameUser("test2", "test-renamed")
|
||||
c.Assert(err, check.Equals, ErrUserExists)
|
||||
}
|
||||
|
||||
func (s *Suite) TestGetMapResponseUserProfiles(c *check.C) {
|
||||
userShared1, err := db.CreateUser("shared1")
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
userShared2, err := db.CreateUser("shared2")
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
userShared3, err := db.CreateUser("shared3")
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
preAuthKeyShared1, err := db.CreatePreAuthKey(
|
||||
userShared1.Name,
|
||||
false,
|
||||
false,
|
||||
nil,
|
||||
nil,
|
||||
)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
preAuthKeyShared2, err := db.CreatePreAuthKey(
|
||||
userShared2.Name,
|
||||
false,
|
||||
false,
|
||||
nil,
|
||||
nil,
|
||||
)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
preAuthKeyShared3, err := db.CreatePreAuthKey(
|
||||
userShared3.Name,
|
||||
false,
|
||||
false,
|
||||
nil,
|
||||
nil,
|
||||
)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
preAuthKey2Shared1, err := db.CreatePreAuthKey(
|
||||
userShared1.Name,
|
||||
false,
|
||||
false,
|
||||
nil,
|
||||
nil,
|
||||
)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
_, err = db.GetMachine(userShared1.Name, "test_get_shared_nodes_1")
|
||||
c.Assert(err, check.NotNil)
|
||||
|
||||
machineInShared1 := &types.Machine{
|
||||
ID: 1,
|
||||
MachineKey: "686824e749f3b7f2a5927ee6c1e422aee5292592d9179a271ed7b3e659b44a66",
|
||||
NodeKey: "686824e749f3b7f2a5927ee6c1e422aee5292592d9179a271ed7b3e659b44a66",
|
||||
DiscoKey: "686824e749f3b7f2a5927ee6c1e422aee5292592d9179a271ed7b3e659b44a66",
|
||||
Hostname: "test_get_shared_nodes_1",
|
||||
UserID: userShared1.ID,
|
||||
User: *userShared1,
|
||||
RegisterMethod: util.RegisterMethodAuthKey,
|
||||
IPAddresses: []netip.Addr{netip.MustParseAddr("100.64.0.1")},
|
||||
AuthKeyID: uint(preAuthKeyShared1.ID),
|
||||
}
|
||||
db.db.Save(machineInShared1)
|
||||
|
||||
_, err = db.GetMachine(userShared1.Name, machineInShared1.Hostname)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
machineInShared2 := &types.Machine{
|
||||
ID: 2,
|
||||
MachineKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
|
||||
NodeKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
|
||||
DiscoKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
|
||||
Hostname: "test_get_shared_nodes_2",
|
||||
UserID: userShared2.ID,
|
||||
User: *userShared2,
|
||||
RegisterMethod: util.RegisterMethodAuthKey,
|
||||
IPAddresses: []netip.Addr{netip.MustParseAddr("100.64.0.2")},
|
||||
AuthKeyID: uint(preAuthKeyShared2.ID),
|
||||
}
|
||||
db.db.Save(machineInShared2)
|
||||
|
||||
_, err = db.GetMachine(userShared2.Name, machineInShared2.Hostname)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
machineInShared3 := &types.Machine{
|
||||
ID: 3,
|
||||
MachineKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
|
||||
NodeKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
|
||||
DiscoKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
|
||||
Hostname: "test_get_shared_nodes_3",
|
||||
UserID: userShared3.ID,
|
||||
User: *userShared3,
|
||||
RegisterMethod: util.RegisterMethodAuthKey,
|
||||
IPAddresses: []netip.Addr{netip.MustParseAddr("100.64.0.3")},
|
||||
AuthKeyID: uint(preAuthKeyShared3.ID),
|
||||
}
|
||||
db.db.Save(machineInShared3)
|
||||
|
||||
_, err = db.GetMachine(userShared3.Name, machineInShared3.Hostname)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
machine2InShared1 := &types.Machine{
|
||||
ID: 4,
|
||||
MachineKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
|
||||
NodeKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
|
||||
DiscoKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
|
||||
Hostname: "test_get_shared_nodes_4",
|
||||
UserID: userShared1.ID,
|
||||
User: *userShared1,
|
||||
RegisterMethod: util.RegisterMethodAuthKey,
|
||||
IPAddresses: []netip.Addr{netip.MustParseAddr("100.64.0.4")},
|
||||
AuthKeyID: uint(preAuthKey2Shared1.ID),
|
||||
}
|
||||
db.db.Save(machine2InShared1)
|
||||
|
||||
peersOfMachine1InShared1, err := db.getPeers([]tailcfg.FilterRule{}, machineInShared1)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
userProfiles := db.GetMapResponseUserProfiles(
|
||||
*machineInShared1,
|
||||
peersOfMachine1InShared1,
|
||||
)
|
||||
|
||||
c.Assert(len(userProfiles), check.Equals, 3)
|
||||
|
||||
found := false
|
||||
for _, userProfiles := range userProfiles {
|
||||
if userProfiles.DisplayName == userShared1.Name {
|
||||
found = true
|
||||
|
||||
break
|
||||
}
|
||||
}
|
||||
c.Assert(found, check.Equals, true)
|
||||
|
||||
found = false
|
||||
for _, userProfile := range userProfiles {
|
||||
if userProfile.DisplayName == userShared2.Name {
|
||||
found = true
|
||||
|
||||
break
|
||||
}
|
||||
}
|
||||
c.Assert(found, check.Equals, true)
|
||||
}
|
||||
|
||||
func (s *Suite) TestSetMachineUser(c *check.C) {
|
||||
oldUser, err := db.CreateUser("old")
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
newUser, err := db.CreateUser("new")
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
pak, err := db.CreatePreAuthKey(oldUser.Name, false, false, nil, nil)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
machine := types.Machine{
|
||||
ID: 0,
|
||||
MachineKey: "foo",
|
||||
NodeKey: "bar",
|
||||
DiscoKey: "faa",
|
||||
Hostname: "testmachine",
|
||||
UserID: oldUser.ID,
|
||||
RegisterMethod: util.RegisterMethodAuthKey,
|
||||
AuthKeyID: uint(pak.ID),
|
||||
}
|
||||
db.db.Save(&machine)
|
||||
c.Assert(machine.UserID, check.Equals, oldUser.ID)
|
||||
|
||||
err = db.SetMachineUser(&machine, newUser.Name)
|
||||
c.Assert(err, check.IsNil)
|
||||
c.Assert(machine.UserID, check.Equals, newUser.ID)
|
||||
c.Assert(machine.User.Name, check.Equals, newUser.Name)
|
||||
|
||||
err = db.SetMachineUser(&machine, "non-existing-user")
|
||||
c.Assert(err, check.Equals, ErrUserNotFound)
|
||||
|
||||
err = db.SetMachineUser(&machine, newUser.Name)
|
||||
c.Assert(err, check.IsNil)
|
||||
c.Assert(machine.UserID, check.Equals, newUser.ID)
|
||||
c.Assert(machine.User.Name, check.Equals, newUser.Name)
|
||||
}
|
@ -7,6 +7,7 @@ import (
|
||||
"strings"
|
||||
|
||||
mapset "github.com/deckarep/golang-set/v2"
|
||||
"github.com/juanfont/headscale/hscontrol/types"
|
||||
"go4.org/netipx"
|
||||
"tailscale.com/tailcfg"
|
||||
"tailscale.com/types/dnstype"
|
||||
@ -165,7 +166,7 @@ func generateIPv6DNSRootDomain(ipPrefix netip.Prefix) []dnsname.FQDN {
|
||||
//
|
||||
// This will produce a resolver like:
|
||||
// `https://dns.nextdns.io/<nextdns-id>?device_name=node-name&device_model=linux&device_ip=100.64.0.1`
|
||||
func addNextDNSMetadata(resolvers []*dnstype.Resolver, machine Machine) {
|
||||
func addNextDNSMetadata(resolvers []*dnstype.Resolver, machine types.Machine) {
|
||||
for _, resolver := range resolvers {
|
||||
if strings.HasPrefix(resolver.Addr, nextDNSDoHPrefix) {
|
||||
attrs := url.Values{
|
||||
@ -185,8 +186,8 @@ func addNextDNSMetadata(resolvers []*dnstype.Resolver, machine Machine) {
|
||||
func getMapResponseDNSConfig(
|
||||
dnsConfigOrig *tailcfg.DNSConfig,
|
||||
baseDomain string,
|
||||
machine Machine,
|
||||
peers Machines,
|
||||
machine types.Machine,
|
||||
peers types.Machines,
|
||||
) *tailcfg.DNSConfig {
|
||||
var dnsConfig *tailcfg.DNSConfig = dnsConfigOrig.Clone()
|
||||
if dnsConfigOrig != nil && dnsConfigOrig.Proxied { // if MagicDNS is enabled
|
||||
@ -200,7 +201,7 @@ func getMapResponseDNSConfig(
|
||||
),
|
||||
)
|
||||
|
||||
userSet := mapset.NewSet[User]()
|
||||
userSet := mapset.NewSet[types.User]()
|
||||
userSet.Add(machine.User)
|
||||
for _, p := range peers {
|
||||
userSet.Add(p.User)
|
||||
|
@ -4,6 +4,8 @@ import (
|
||||
"fmt"
|
||||
"net/netip"
|
||||
|
||||
"github.com/juanfont/headscale/hscontrol/types"
|
||||
"github.com/juanfont/headscale/hscontrol/util"
|
||||
"gopkg.in/check.v1"
|
||||
"tailscale.com/tailcfg"
|
||||
"tailscale.com/types/dnstype"
|
||||
@ -160,7 +162,7 @@ func (s *Suite) TestDNSConfigMapResponseWithMagicDNS(c *check.C) {
|
||||
_, err = app.db.GetMachine(userShared1.Name, "test_get_shared_nodes_1")
|
||||
c.Assert(err, check.NotNil)
|
||||
|
||||
machineInShared1 := &Machine{
|
||||
machineInShared1 := &types.Machine{
|
||||
ID: 1,
|
||||
MachineKey: "686824e749f3b7f2a5927ee6c1e422aee5292592d9179a271ed7b3e659b44a66",
|
||||
NodeKey: "686824e749f3b7f2a5927ee6c1e422aee5292592d9179a271ed7b3e659b44a66",
|
||||
@ -168,16 +170,17 @@ func (s *Suite) TestDNSConfigMapResponseWithMagicDNS(c *check.C) {
|
||||
Hostname: "test_get_shared_nodes_1",
|
||||
UserID: userShared1.ID,
|
||||
User: *userShared1,
|
||||
RegisterMethod: RegisterMethodAuthKey,
|
||||
RegisterMethod: util.RegisterMethodAuthKey,
|
||||
IPAddresses: []netip.Addr{netip.MustParseAddr("100.64.0.1")},
|
||||
AuthKeyID: uint(preAuthKeyInShared1.ID),
|
||||
}
|
||||
app.db.db.Save(machineInShared1)
|
||||
err = app.db.MachineSave(machineInShared1)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
_, err = app.db.GetMachine(userShared1.Name, machineInShared1.Hostname)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
machineInShared2 := &Machine{
|
||||
machineInShared2 := &types.Machine{
|
||||
ID: 2,
|
||||
MachineKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
|
||||
NodeKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
|
||||
@ -185,16 +188,17 @@ func (s *Suite) TestDNSConfigMapResponseWithMagicDNS(c *check.C) {
|
||||
Hostname: "test_get_shared_nodes_2",
|
||||
UserID: userShared2.ID,
|
||||
User: *userShared2,
|
||||
RegisterMethod: RegisterMethodAuthKey,
|
||||
RegisterMethod: util.RegisterMethodAuthKey,
|
||||
IPAddresses: []netip.Addr{netip.MustParseAddr("100.64.0.2")},
|
||||
AuthKeyID: uint(preAuthKeyInShared2.ID),
|
||||
}
|
||||
app.db.db.Save(machineInShared2)
|
||||
err = app.db.MachineSave(machineInShared2)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
_, err = app.db.GetMachine(userShared2.Name, machineInShared2.Hostname)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
machineInShared3 := &Machine{
|
||||
machineInShared3 := &types.Machine{
|
||||
ID: 3,
|
||||
MachineKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
|
||||
NodeKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
|
||||
@ -202,16 +206,17 @@ func (s *Suite) TestDNSConfigMapResponseWithMagicDNS(c *check.C) {
|
||||
Hostname: "test_get_shared_nodes_3",
|
||||
UserID: userShared3.ID,
|
||||
User: *userShared3,
|
||||
RegisterMethod: RegisterMethodAuthKey,
|
||||
RegisterMethod: util.RegisterMethodAuthKey,
|
||||
IPAddresses: []netip.Addr{netip.MustParseAddr("100.64.0.3")},
|
||||
AuthKeyID: uint(preAuthKeyInShared3.ID),
|
||||
}
|
||||
app.db.db.Save(machineInShared3)
|
||||
err = app.db.MachineSave(machineInShared3)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
_, err = app.db.GetMachine(userShared3.Name, machineInShared3.Hostname)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
machine2InShared1 := &Machine{
|
||||
machine2InShared1 := &types.Machine{
|
||||
ID: 4,
|
||||
MachineKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
|
||||
NodeKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
|
||||
@ -219,11 +224,12 @@ func (s *Suite) TestDNSConfigMapResponseWithMagicDNS(c *check.C) {
|
||||
Hostname: "test_get_shared_nodes_4",
|
||||
UserID: userShared1.ID,
|
||||
User: *userShared1,
|
||||
RegisterMethod: RegisterMethodAuthKey,
|
||||
RegisterMethod: util.RegisterMethodAuthKey,
|
||||
IPAddresses: []netip.Addr{netip.MustParseAddr("100.64.0.4")},
|
||||
AuthKeyID: uint(PreAuthKey2InShared1.ID),
|
||||
}
|
||||
app.db.db.Save(machine2InShared1)
|
||||
err = app.db.MachineSave(machine2InShared1)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
baseDomain := "foobar.headscale.net"
|
||||
dnsConfigOrig := tailcfg.DNSConfig{
|
||||
@ -232,7 +238,7 @@ func (s *Suite) TestDNSConfigMapResponseWithMagicDNS(c *check.C) {
|
||||
Proxied: true,
|
||||
}
|
||||
|
||||
peersOfMachineInShared1, err := app.db.getPeers(app.aclPolicy, app.aclRules, machineInShared1)
|
||||
peersOfMachineInShared1, err := app.db.GetValidPeers(app.aclRules, machineInShared1)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
dnsConfig := getMapResponseDNSConfig(
|
||||
@ -307,7 +313,7 @@ func (s *Suite) TestDNSConfigMapResponseWithoutMagicDNS(c *check.C) {
|
||||
_, err = app.db.GetMachine(userShared1.Name, "test_get_shared_nodes_1")
|
||||
c.Assert(err, check.NotNil)
|
||||
|
||||
machineInShared1 := &Machine{
|
||||
machineInShared1 := &types.Machine{
|
||||
ID: 1,
|
||||
MachineKey: "686824e749f3b7f2a5927ee6c1e422aee5292592d9179a271ed7b3e659b44a66",
|
||||
NodeKey: "686824e749f3b7f2a5927ee6c1e422aee5292592d9179a271ed7b3e659b44a66",
|
||||
@ -315,16 +321,17 @@ func (s *Suite) TestDNSConfigMapResponseWithoutMagicDNS(c *check.C) {
|
||||
Hostname: "test_get_shared_nodes_1",
|
||||
UserID: userShared1.ID,
|
||||
User: *userShared1,
|
||||
RegisterMethod: RegisterMethodAuthKey,
|
||||
RegisterMethod: util.RegisterMethodAuthKey,
|
||||
IPAddresses: []netip.Addr{netip.MustParseAddr("100.64.0.1")},
|
||||
AuthKeyID: uint(preAuthKeyInShared1.ID),
|
||||
}
|
||||
app.db.db.Save(machineInShared1)
|
||||
err = app.db.MachineSave(machineInShared1)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
_, err = app.db.GetMachine(userShared1.Name, machineInShared1.Hostname)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
machineInShared2 := &Machine{
|
||||
machineInShared2 := &types.Machine{
|
||||
ID: 2,
|
||||
MachineKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
|
||||
NodeKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
|
||||
@ -332,16 +339,17 @@ func (s *Suite) TestDNSConfigMapResponseWithoutMagicDNS(c *check.C) {
|
||||
Hostname: "test_get_shared_nodes_2",
|
||||
UserID: userShared2.ID,
|
||||
User: *userShared2,
|
||||
RegisterMethod: RegisterMethodAuthKey,
|
||||
RegisterMethod: util.RegisterMethodAuthKey,
|
||||
IPAddresses: []netip.Addr{netip.MustParseAddr("100.64.0.2")},
|
||||
AuthKeyID: uint(preAuthKeyInShared2.ID),
|
||||
}
|
||||
app.db.db.Save(machineInShared2)
|
||||
err = app.db.MachineSave(machineInShared2)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
_, err = app.db.GetMachine(userShared2.Name, machineInShared2.Hostname)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
machineInShared3 := &Machine{
|
||||
machineInShared3 := &types.Machine{
|
||||
ID: 3,
|
||||
MachineKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
|
||||
NodeKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
|
||||
@ -349,16 +357,17 @@ func (s *Suite) TestDNSConfigMapResponseWithoutMagicDNS(c *check.C) {
|
||||
Hostname: "test_get_shared_nodes_3",
|
||||
UserID: userShared3.ID,
|
||||
User: *userShared3,
|
||||
RegisterMethod: RegisterMethodAuthKey,
|
||||
RegisterMethod: util.RegisterMethodAuthKey,
|
||||
IPAddresses: []netip.Addr{netip.MustParseAddr("100.64.0.3")},
|
||||
AuthKeyID: uint(preAuthKeyInShared3.ID),
|
||||
}
|
||||
app.db.db.Save(machineInShared3)
|
||||
err = app.db.MachineSave(machineInShared3)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
_, err = app.db.GetMachine(userShared3.Name, machineInShared3.Hostname)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
machine2InShared1 := &Machine{
|
||||
machine2InShared1 := &types.Machine{
|
||||
ID: 4,
|
||||
MachineKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
|
||||
NodeKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
|
||||
@ -366,11 +375,12 @@ func (s *Suite) TestDNSConfigMapResponseWithoutMagicDNS(c *check.C) {
|
||||
Hostname: "test_get_shared_nodes_4",
|
||||
UserID: userShared1.ID,
|
||||
User: *userShared1,
|
||||
RegisterMethod: RegisterMethodAuthKey,
|
||||
RegisterMethod: util.RegisterMethodAuthKey,
|
||||
IPAddresses: []netip.Addr{netip.MustParseAddr("100.64.0.4")},
|
||||
AuthKeyID: uint(preAuthKey2InShared1.ID),
|
||||
}
|
||||
app.db.db.Save(machine2InShared1)
|
||||
err = app.db.MachineSave(machine2InShared1)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
baseDomain := "foobar.headscale.net"
|
||||
dnsConfigOrig := tailcfg.DNSConfig{
|
||||
@ -379,7 +389,7 @@ func (s *Suite) TestDNSConfigMapResponseWithoutMagicDNS(c *check.C) {
|
||||
Proxied: false,
|
||||
}
|
||||
|
||||
peersOfMachine1Shared1, err := app.db.getPeers(app.aclPolicy, app.aclRules, machineInShared1)
|
||||
peersOfMachine1Shared1, err := app.db.GetValidPeers(app.aclRules, machineInShared1)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
dnsConfig := getMapResponseDNSConfig(
|
||||
|
@ -8,6 +8,7 @@ import (
|
||||
"time"
|
||||
|
||||
v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
|
||||
"github.com/juanfont/headscale/hscontrol/types"
|
||||
"github.com/juanfont/headscale/hscontrol/util"
|
||||
"github.com/rs/zerolog/log"
|
||||
"google.golang.org/grpc/codes"
|
||||
@ -36,7 +37,7 @@ func (api headscaleV1APIServer) GetUser(
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &v1.GetUserResponse{User: user.toProto()}, nil
|
||||
return &v1.GetUserResponse{User: user.Proto()}, nil
|
||||
}
|
||||
|
||||
func (api headscaleV1APIServer) CreateUser(
|
||||
@ -48,7 +49,7 @@ func (api headscaleV1APIServer) CreateUser(
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &v1.CreateUserResponse{User: user.toProto()}, nil
|
||||
return &v1.CreateUserResponse{User: user.Proto()}, nil
|
||||
}
|
||||
|
||||
func (api headscaleV1APIServer) RenameUser(
|
||||
@ -65,7 +66,7 @@ func (api headscaleV1APIServer) RenameUser(
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &v1.RenameUserResponse{User: user.toProto()}, nil
|
||||
return &v1.RenameUserResponse{User: user.Proto()}, nil
|
||||
}
|
||||
|
||||
func (api headscaleV1APIServer) DeleteUser(
|
||||
@ -91,7 +92,7 @@ func (api headscaleV1APIServer) ListUsers(
|
||||
|
||||
response := make([]*v1.User, len(users))
|
||||
for index, user := range users {
|
||||
response[index] = user.toProto()
|
||||
response[index] = user.Proto()
|
||||
}
|
||||
|
||||
log.Trace().Caller().Interface("users", response).Msg("")
|
||||
@ -128,7 +129,7 @@ func (api headscaleV1APIServer) CreatePreAuthKey(
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &v1.CreatePreAuthKeyResponse{PreAuthKey: preAuthKey.toProto()}, nil
|
||||
return &v1.CreatePreAuthKeyResponse{PreAuthKey: preAuthKey.Proto()}, nil
|
||||
}
|
||||
|
||||
func (api headscaleV1APIServer) ExpirePreAuthKey(
|
||||
@ -159,7 +160,7 @@ func (api headscaleV1APIServer) ListPreAuthKeys(
|
||||
|
||||
response := make([]*v1.PreAuthKey, len(preAuthKeys))
|
||||
for index, key := range preAuthKeys {
|
||||
response[index] = key.toProto()
|
||||
response[index] = key.Proto()
|
||||
}
|
||||
|
||||
return &v1.ListPreAuthKeysResponse{PreAuthKeys: response}, nil
|
||||
@ -179,13 +180,13 @@ func (api headscaleV1APIServer) RegisterMachine(
|
||||
request.GetKey(),
|
||||
request.GetUser(),
|
||||
nil,
|
||||
RegisterMethodCLI,
|
||||
util.RegisterMethodCLI,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &v1.RegisterMachineResponse{Machine: machine.toProto()}, nil
|
||||
return &v1.RegisterMachineResponse{Machine: machine.Proto()}, nil
|
||||
}
|
||||
|
||||
func (api headscaleV1APIServer) GetMachine(
|
||||
@ -197,7 +198,7 @@ func (api headscaleV1APIServer) GetMachine(
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &v1.GetMachineResponse{Machine: machine.toProto()}, nil
|
||||
return &v1.GetMachineResponse{Machine: machine.Proto()}, nil
|
||||
}
|
||||
|
||||
func (api headscaleV1APIServer) SetTags(
|
||||
@ -218,7 +219,7 @@ func (api headscaleV1APIServer) SetTags(
|
||||
}
|
||||
}
|
||||
|
||||
err = api.h.db.SetTags(machine, request.GetTags(), api.h.UpdateACLRules)
|
||||
err = api.h.db.SetTags(machine, request.GetTags())
|
||||
if err != nil {
|
||||
return &v1.SetTagsResponse{
|
||||
Machine: nil,
|
||||
@ -230,7 +231,7 @@ func (api headscaleV1APIServer) SetTags(
|
||||
Strs("tags", request.GetTags()).
|
||||
Msg("Changing tags of machine")
|
||||
|
||||
return &v1.SetTagsResponse{Machine: machine.toProto()}, nil
|
||||
return &v1.SetTagsResponse{Machine: machine.Proto()}, nil
|
||||
}
|
||||
|
||||
func validateTag(tag string) error {
|
||||
@ -283,7 +284,7 @@ func (api headscaleV1APIServer) ExpireMachine(
|
||||
Time("expiry", *machine.Expiry).
|
||||
Msg("machine expired")
|
||||
|
||||
return &v1.ExpireMachineResponse{Machine: machine.toProto()}, nil
|
||||
return &v1.ExpireMachineResponse{Machine: machine.Proto()}, nil
|
||||
}
|
||||
|
||||
func (api headscaleV1APIServer) RenameMachine(
|
||||
@ -308,7 +309,7 @@ func (api headscaleV1APIServer) RenameMachine(
|
||||
Str("new_name", request.GetNewName()).
|
||||
Msg("machine renamed")
|
||||
|
||||
return &v1.RenameMachineResponse{Machine: machine.toProto()}, nil
|
||||
return &v1.RenameMachineResponse{Machine: machine.Proto()}, nil
|
||||
}
|
||||
|
||||
func (api headscaleV1APIServer) ListMachines(
|
||||
@ -323,7 +324,7 @@ func (api headscaleV1APIServer) ListMachines(
|
||||
|
||||
response := make([]*v1.Machine, len(machines))
|
||||
for index, machine := range machines {
|
||||
response[index] = machine.toProto()
|
||||
response[index] = machine.Proto()
|
||||
}
|
||||
|
||||
return &v1.ListMachinesResponse{Machines: response}, nil
|
||||
@ -336,9 +337,8 @@ func (api headscaleV1APIServer) ListMachines(
|
||||
|
||||
response := make([]*v1.Machine, len(machines))
|
||||
for index, machine := range machines {
|
||||
m := machine.toProto()
|
||||
validTags, invalidTags := getTags(
|
||||
api.h.aclPolicy,
|
||||
m := machine.Proto()
|
||||
validTags, invalidTags := api.h.ACLPolicy.GetTagsOfMachine(
|
||||
machine,
|
||||
api.h.cfg.OIDC.StripEmaildomain,
|
||||
)
|
||||
@ -364,7 +364,7 @@ func (api headscaleV1APIServer) MoveMachine(
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &v1.MoveMachineResponse{Machine: machine.toProto()}, nil
|
||||
return &v1.MoveMachineResponse{Machine: machine.Proto()}, nil
|
||||
}
|
||||
|
||||
func (api headscaleV1APIServer) GetRoutes(
|
||||
@ -377,7 +377,7 @@ func (api headscaleV1APIServer) GetRoutes(
|
||||
}
|
||||
|
||||
return &v1.GetRoutesResponse{
|
||||
Routes: Routes(routes).toProto(),
|
||||
Routes: types.Routes(routes).Proto(),
|
||||
}, nil
|
||||
}
|
||||
|
||||
@ -420,7 +420,7 @@ func (api headscaleV1APIServer) GetMachineRoutes(
|
||||
}
|
||||
|
||||
return &v1.GetMachineRoutesResponse{
|
||||
Routes: Routes(routes).toProto(),
|
||||
Routes: types.Routes(routes).Proto(),
|
||||
}, nil
|
||||
}
|
||||
|
||||
@ -459,7 +459,7 @@ func (api headscaleV1APIServer) ExpireApiKey(
|
||||
ctx context.Context,
|
||||
request *v1.ExpireApiKeyRequest,
|
||||
) (*v1.ExpireApiKeyResponse, error) {
|
||||
var apiKey *APIKey
|
||||
var apiKey *types.APIKey
|
||||
var err error
|
||||
|
||||
apiKey, err = api.h.db.GetAPIKey(request.Prefix)
|
||||
@ -486,7 +486,7 @@ func (api headscaleV1APIServer) ListApiKeys(
|
||||
|
||||
response := make([]*v1.ApiKey, len(apiKeys))
|
||||
for index, key := range apiKeys {
|
||||
response[index] = key.toProto()
|
||||
response[index] = key.Proto()
|
||||
}
|
||||
|
||||
return &v1.ListApiKeysResponse{ApiKeys: response}, nil
|
||||
@ -524,7 +524,7 @@ func (api headscaleV1APIServer) DebugCreateMachine(
|
||||
return nil, err
|
||||
}
|
||||
|
||||
newMachine := Machine{
|
||||
newMachine := types.Machine{
|
||||
MachineKey: request.GetKey(),
|
||||
Hostname: request.GetName(),
|
||||
GivenName: givenName,
|
||||
@ -534,7 +534,7 @@ func (api headscaleV1APIServer) DebugCreateMachine(
|
||||
LastSeen: &time.Time{},
|
||||
LastSuccessfulUpdate: &time.Time{},
|
||||
|
||||
HostInfo: HostInfo(hostinfo),
|
||||
HostInfo: types.HostInfo(hostinfo),
|
||||
}
|
||||
|
||||
nodeKey := key.NodePublic{}
|
||||
@ -549,7 +549,7 @@ func (api headscaleV1APIServer) DebugCreateMachine(
|
||||
registerCacheExpiration,
|
||||
)
|
||||
|
||||
return &v1.DebugCreateMachineResponse{Machine: newMachine.toProto()}, nil
|
||||
return &v1.DebugCreateMachineResponse{Machine: newMachine.Proto()}, nil
|
||||
}
|
||||
|
||||
func (api headscaleV1APIServer) mustEmbedUnimplementedHeadscaleServiceServer() {}
|
||||
|
File diff suppressed because it is too large
Load Diff
@ -1,142 +0,0 @@
|
||||
package hscontrol
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"strings"
|
||||
|
||||
"go4.org/netipx"
|
||||
"tailscale.com/tailcfg"
|
||||
)
|
||||
|
||||
// This is borrowed from, and updated to use IPSet
|
||||
// https://github.com/tailscale/tailscale/blob/71029cea2ddf82007b80f465b256d027eab0f02d/wgengine/filter/tailcfg.go#L97-L162
|
||||
// TODO(kradalby): contribute upstream and make public.
|
||||
var (
|
||||
zeroIP4 = netip.AddrFrom4([4]byte{})
|
||||
zeroIP6 = netip.AddrFrom16([16]byte{})
|
||||
)
|
||||
|
||||
// parseIPSet parses arg as one:
|
||||
//
|
||||
// - an IP address (IPv4 or IPv6)
|
||||
// - the string "*" to match everything (both IPv4 & IPv6)
|
||||
// - a CIDR (e.g. "192.168.0.0/16")
|
||||
// - a range of two IPs, inclusive, separated by hyphen ("2eff::1-2eff::0800")
|
||||
//
|
||||
// bits, if non-nil, is the legacy SrcBits CIDR length to make a IP
|
||||
// address (without a slash) treated as a CIDR of *bits length.
|
||||
// nolint
|
||||
func parseIPSet(arg string, bits *int) (*netipx.IPSet, error) {
|
||||
var ipSet netipx.IPSetBuilder
|
||||
if arg == "*" {
|
||||
ipSet.AddPrefix(netip.PrefixFrom(zeroIP4, 0))
|
||||
ipSet.AddPrefix(netip.PrefixFrom(zeroIP6, 0))
|
||||
|
||||
return ipSet.IPSet()
|
||||
}
|
||||
if strings.Contains(arg, "/") {
|
||||
pfx, err := netip.ParsePrefix(arg)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if pfx != pfx.Masked() {
|
||||
return nil, fmt.Errorf("%v contains non-network bits set", pfx)
|
||||
}
|
||||
|
||||
ipSet.AddPrefix(pfx)
|
||||
|
||||
return ipSet.IPSet()
|
||||
}
|
||||
if strings.Count(arg, "-") == 1 {
|
||||
ip1s, ip2s, _ := strings.Cut(arg, "-")
|
||||
|
||||
ip1, err := netip.ParseAddr(ip1s)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
ip2, err := netip.ParseAddr(ip2s)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
r := netipx.IPRangeFrom(ip1, ip2)
|
||||
if !r.IsValid() {
|
||||
return nil, fmt.Errorf("invalid IP range %q", arg)
|
||||
}
|
||||
|
||||
for _, prefix := range r.Prefixes() {
|
||||
ipSet.AddPrefix(prefix)
|
||||
}
|
||||
|
||||
return ipSet.IPSet()
|
||||
}
|
||||
ip, err := netip.ParseAddr(arg)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid IP address %q", arg)
|
||||
}
|
||||
bits8 := uint8(ip.BitLen())
|
||||
if bits != nil {
|
||||
if *bits < 0 || *bits > int(bits8) {
|
||||
return nil, fmt.Errorf("invalid CIDR size %d for IP %q", *bits, arg)
|
||||
}
|
||||
bits8 = uint8(*bits)
|
||||
}
|
||||
|
||||
ipSet.AddPrefix(netip.PrefixFrom(ip, int(bits8)))
|
||||
|
||||
return ipSet.IPSet()
|
||||
}
|
||||
|
||||
type Match struct {
|
||||
Srcs *netipx.IPSet
|
||||
Dests *netipx.IPSet
|
||||
}
|
||||
|
||||
func MatchFromFilterRule(rule tailcfg.FilterRule) Match {
|
||||
srcs := new(netipx.IPSetBuilder)
|
||||
dests := new(netipx.IPSetBuilder)
|
||||
|
||||
for _, srcIP := range rule.SrcIPs {
|
||||
set, _ := parseIPSet(srcIP, nil)
|
||||
|
||||
srcs.AddSet(set)
|
||||
}
|
||||
|
||||
for _, dest := range rule.DstPorts {
|
||||
set, _ := parseIPSet(dest.IP, nil)
|
||||
|
||||
dests.AddSet(set)
|
||||
}
|
||||
|
||||
srcsSet, _ := srcs.IPSet()
|
||||
destsSet, _ := dests.IPSet()
|
||||
|
||||
match := Match{
|
||||
Srcs: srcsSet,
|
||||
Dests: destsSet,
|
||||
}
|
||||
|
||||
return match
|
||||
}
|
||||
|
||||
func (m *Match) SrcsContainsIPs(ips []netip.Addr) bool {
|
||||
for _, ip := range ips {
|
||||
if m.Srcs.Contains(ip) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func (m *Match) DestsContainsIP(ips []netip.Addr) bool {
|
||||
for _, ip := range ips {
|
||||
if m.Dests.Contains(ip) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
@ -14,6 +14,8 @@ import (
|
||||
|
||||
"github.com/coreos/go-oidc/v3/oidc"
|
||||
"github.com/gorilla/mux"
|
||||
"github.com/juanfont/headscale/hscontrol/db"
|
||||
"github.com/juanfont/headscale/hscontrol/types"
|
||||
"github.com/juanfont/headscale/hscontrol/util"
|
||||
"github.com/rs/zerolog/log"
|
||||
"golang.org/x/oauth2"
|
||||
@ -638,7 +640,7 @@ func getUserName(
|
||||
claims *IDTokenClaims,
|
||||
stripEmaildomain bool,
|
||||
) (string, error) {
|
||||
userName, err := NormalizeToFQDNRules(
|
||||
userName, err := util.NormalizeToFQDNRules(
|
||||
claims.Email,
|
||||
stripEmaildomain,
|
||||
)
|
||||
@ -663,9 +665,9 @@ func getUserName(
|
||||
func (h *Headscale) findOrCreateNewUserForOIDCCallback(
|
||||
writer http.ResponseWriter,
|
||||
userName string,
|
||||
) (*User, error) {
|
||||
) (*types.User, error) {
|
||||
user, err := h.db.GetUser(userName)
|
||||
if errors.Is(err, ErrUserNotFound) {
|
||||
if errors.Is(err, db.ErrUserNotFound) {
|
||||
user, err = h.db.CreateUser(userName)
|
||||
|
||||
if err != nil {
|
||||
@ -709,7 +711,7 @@ func (h *Headscale) findOrCreateNewUserForOIDCCallback(
|
||||
|
||||
func (h *Headscale) registerMachineForOIDCCallback(
|
||||
writer http.ResponseWriter,
|
||||
user *User,
|
||||
user *types.User,
|
||||
nodeKey *key.NodePublic,
|
||||
expiry time.Time,
|
||||
) error {
|
||||
@ -719,7 +721,7 @@ func (h *Headscale) registerMachineForOIDCCallback(
|
||||
nodeKey.String(),
|
||||
user.Name,
|
||||
&expiry,
|
||||
RegisterMethodOIDC,
|
||||
util.RegisterMethodOIDC,
|
||||
); err != nil {
|
||||
log.Error().
|
||||
Caller().
|
||||
|
@ -1,4 +1,4 @@
|
||||
package hscontrol
|
||||
package policy
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
@ -12,6 +12,7 @@ import (
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/juanfont/headscale/hscontrol/types"
|
||||
"github.com/juanfont/headscale/hscontrol/util"
|
||||
"github.com/rs/zerolog/log"
|
||||
"github.com/tailscale/hujson"
|
||||
@ -22,12 +23,12 @@ import (
|
||||
)
|
||||
|
||||
var (
|
||||
errEmptyPolicy = errors.New("empty policy")
|
||||
errInvalidAction = errors.New("invalid action")
|
||||
errInvalidGroup = errors.New("invalid group")
|
||||
errInvalidTag = errors.New("invalid tag")
|
||||
errInvalidPortFormat = errors.New("invalid port format")
|
||||
errWildcardIsNeeded = errors.New("wildcard as port is required for the protocol")
|
||||
ErrEmptyPolicy = errors.New("empty policy")
|
||||
ErrInvalidAction = errors.New("invalid action")
|
||||
ErrInvalidGroup = errors.New("invalid group")
|
||||
ErrInvalidTag = errors.New("invalid tag")
|
||||
ErrInvalidPortFormat = errors.New("invalid port format")
|
||||
ErrWildcardIsNeeded = errors.New("wildcard as port is required for the protocol")
|
||||
)
|
||||
|
||||
const (
|
||||
@ -56,7 +57,7 @@ const (
|
||||
var featureEnableSSH = envknob.RegisterBool("HEADSCALE_EXPERIMENTAL_FEATURE_SSH")
|
||||
|
||||
// LoadACLPolicyFromPath loads the ACL policy from the specify path, and generates the ACL rules.
|
||||
func (h *Headscale) LoadACLPolicyFromPath(path string) error {
|
||||
func LoadACLPolicyFromPath(path string) (*ACLPolicy, error) {
|
||||
log.Debug().
|
||||
Str("func", "LoadACLPolicy").
|
||||
Str("path", path).
|
||||
@ -64,13 +65,13 @@ func (h *Headscale) LoadACLPolicyFromPath(path string) error {
|
||||
|
||||
policyFile, err := os.Open(path)
|
||||
if err != nil {
|
||||
return err
|
||||
return nil, err
|
||||
}
|
||||
defer policyFile.Close()
|
||||
|
||||
policyBytes, err := io.ReadAll(policyFile)
|
||||
if err != nil {
|
||||
return err
|
||||
return nil, err
|
||||
}
|
||||
|
||||
log.Debug().
|
||||
@ -80,90 +81,90 @@ func (h *Headscale) LoadACLPolicyFromPath(path string) error {
|
||||
|
||||
switch filepath.Ext(path) {
|
||||
case ".yml", ".yaml":
|
||||
return h.LoadACLPolicyFromBytes(policyBytes, "yaml")
|
||||
return LoadACLPolicyFromBytes(policyBytes, "yaml")
|
||||
}
|
||||
|
||||
return h.LoadACLPolicyFromBytes(policyBytes, "hujson")
|
||||
return LoadACLPolicyFromBytes(policyBytes, "hujson")
|
||||
}
|
||||
|
||||
func (h *Headscale) LoadACLPolicyFromBytes(acl []byte, format string) error {
|
||||
func LoadACLPolicyFromBytes(acl []byte, format string) (*ACLPolicy, error) {
|
||||
var policy ACLPolicy
|
||||
switch format {
|
||||
case "yaml":
|
||||
err := yaml.Unmarshal(acl, &policy)
|
||||
if err != nil {
|
||||
return err
|
||||
return nil, err
|
||||
}
|
||||
|
||||
default:
|
||||
ast, err := hujson.Parse(acl)
|
||||
if err != nil {
|
||||
return err
|
||||
return nil, err
|
||||
}
|
||||
|
||||
ast.Standardize()
|
||||
acl = ast.Pack()
|
||||
err = json.Unmarshal(acl, &policy)
|
||||
if err != nil {
|
||||
return err
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
if policy.IsZero() {
|
||||
return errEmptyPolicy
|
||||
return nil, ErrEmptyPolicy
|
||||
}
|
||||
|
||||
h.aclPolicy = &policy
|
||||
|
||||
return h.UpdateACLRules()
|
||||
return &policy, nil
|
||||
}
|
||||
|
||||
func (h *Headscale) UpdateACLRules() error {
|
||||
machines, err := h.db.ListMachines()
|
||||
if err != nil {
|
||||
return err
|
||||
// TODO(kradalby): This needs to be replace with something that generates
|
||||
// the rules as needed and not stores it on the global object, rules are
|
||||
// per node and that should be taken into account.
|
||||
func GenerateFilterRules(
|
||||
policy *ACLPolicy,
|
||||
machines types.Machines,
|
||||
stripEmailDomain bool,
|
||||
) ([]tailcfg.FilterRule, *tailcfg.SSHPolicy, error) {
|
||||
if policy == nil {
|
||||
return []tailcfg.FilterRule{}, &tailcfg.SSHPolicy{}, ErrEmptyPolicy
|
||||
}
|
||||
|
||||
if h.aclPolicy == nil {
|
||||
return errEmptyPolicy
|
||||
}
|
||||
|
||||
rules, err := h.aclPolicy.generateFilterRules(machines, h.cfg.OIDC.StripEmaildomain)
|
||||
rules, err := policy.generateFilterRules(machines, stripEmailDomain)
|
||||
if err != nil {
|
||||
return err
|
||||
return []tailcfg.FilterRule{}, &tailcfg.SSHPolicy{}, err
|
||||
}
|
||||
|
||||
log.Trace().Interface("ACL", rules).Msg("ACL rules generated")
|
||||
h.aclRules = rules
|
||||
|
||||
var sshPolicy *tailcfg.SSHPolicy
|
||||
if featureEnableSSH() {
|
||||
sshRules, err := h.generateSSHRules()
|
||||
sshRules, err := generateSSHRules(policy, machines, stripEmailDomain)
|
||||
if err != nil {
|
||||
return err
|
||||
return []tailcfg.FilterRule{}, &tailcfg.SSHPolicy{}, err
|
||||
}
|
||||
log.Trace().Interface("SSH", sshRules).Msg("SSH rules generated")
|
||||
if h.sshPolicy == nil {
|
||||
h.sshPolicy = &tailcfg.SSHPolicy{}
|
||||
if sshPolicy == nil {
|
||||
sshPolicy = &tailcfg.SSHPolicy{}
|
||||
}
|
||||
h.sshPolicy.Rules = sshRules
|
||||
} else if h.aclPolicy != nil && len(h.aclPolicy.SSHs) > 0 {
|
||||
sshPolicy.Rules = sshRules
|
||||
} else if policy != nil && len(policy.SSHs) > 0 {
|
||||
log.Info().Msg("SSH ACLs has been defined, but HEADSCALE_EXPERIMENTAL_FEATURE_SSH is not enabled, this is a unstable feature, check docs before activating")
|
||||
}
|
||||
|
||||
return nil
|
||||
return rules, sshPolicy, nil
|
||||
}
|
||||
|
||||
// generateFilterRules takes a set of machines and an ACLPolicy and generates a
|
||||
// set of Tailscale compatible FilterRules used to allow traffic on clients.
|
||||
func (pol *ACLPolicy) generateFilterRules(
|
||||
machines []Machine,
|
||||
machines types.Machines,
|
||||
stripEmailDomain bool,
|
||||
) ([]tailcfg.FilterRule, error) {
|
||||
rules := []tailcfg.FilterRule{}
|
||||
|
||||
for index, acl := range pol.ACLs {
|
||||
if acl.Action != "accept" {
|
||||
return nil, errInvalidAction
|
||||
return nil, ErrInvalidAction
|
||||
}
|
||||
|
||||
srcIPs := []string{}
|
||||
@ -219,16 +220,15 @@ func (pol *ACLPolicy) generateFilterRules(
|
||||
return rules, nil
|
||||
}
|
||||
|
||||
func (h *Headscale) generateSSHRules() ([]*tailcfg.SSHRule, error) {
|
||||
func generateSSHRules(
|
||||
policy *ACLPolicy,
|
||||
machines types.Machines,
|
||||
stripEmailDomain bool,
|
||||
) ([]*tailcfg.SSHRule, error) {
|
||||
rules := []*tailcfg.SSHRule{}
|
||||
|
||||
if h.aclPolicy == nil {
|
||||
return nil, errEmptyPolicy
|
||||
}
|
||||
|
||||
machines, err := h.db.ListMachines()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
if policy == nil {
|
||||
return nil, ErrEmptyPolicy
|
||||
}
|
||||
|
||||
acceptAction := tailcfg.SSHAction{
|
||||
@ -251,7 +251,7 @@ func (h *Headscale) generateSSHRules() ([]*tailcfg.SSHRule, error) {
|
||||
AllowLocalPortForwarding: false,
|
||||
}
|
||||
|
||||
for index, sshACL := range h.aclPolicy.SSHs {
|
||||
for index, sshACL := range policy.SSHs {
|
||||
action := rejectAction
|
||||
switch sshACL.Action {
|
||||
case "accept":
|
||||
@ -266,9 +266,9 @@ func (h *Headscale) generateSSHRules() ([]*tailcfg.SSHRule, error) {
|
||||
}
|
||||
default:
|
||||
log.Error().
|
||||
Msgf("Error parsing SSH %d, unknown action '%s'", index, sshACL.Action)
|
||||
Msgf("Error parsing SSH %d, unknown action '%s', skipping", index, sshACL.Action)
|
||||
|
||||
return nil, err
|
||||
continue
|
||||
}
|
||||
|
||||
principals := make([]*tailcfg.SSHPrincipal, 0, len(sshACL.Sources))
|
||||
@ -278,7 +278,7 @@ func (h *Headscale) generateSSHRules() ([]*tailcfg.SSHRule, error) {
|
||||
Any: true,
|
||||
})
|
||||
} else if isGroup(rawSrc) {
|
||||
users, err := h.aclPolicy.getUsersInGroup(rawSrc, h.cfg.OIDC.StripEmaildomain)
|
||||
users, err := policy.getUsersInGroup(rawSrc, stripEmailDomain)
|
||||
if err != nil {
|
||||
log.Error().
|
||||
Msgf("Error parsing SSH %d, Source %d", index, innerIndex)
|
||||
@ -292,10 +292,10 @@ func (h *Headscale) generateSSHRules() ([]*tailcfg.SSHRule, error) {
|
||||
})
|
||||
}
|
||||
} else {
|
||||
expandedSrcs, err := h.aclPolicy.expandAlias(
|
||||
expandedSrcs, err := policy.ExpandAlias(
|
||||
machines,
|
||||
rawSrc,
|
||||
h.cfg.OIDC.StripEmaildomain,
|
||||
stripEmailDomain,
|
||||
)
|
||||
if err != nil {
|
||||
log.Error().
|
||||
@ -346,10 +346,10 @@ func sshCheckAction(duration string) (*tailcfg.SSHAction, error) {
|
||||
// with the given src alias.
|
||||
func (pol *ACLPolicy) getIPsFromSource(
|
||||
src string,
|
||||
machines []Machine,
|
||||
machines types.Machines,
|
||||
stripEmaildomain bool,
|
||||
) ([]string, error) {
|
||||
ipSet, err := pol.expandAlias(machines, src, stripEmaildomain)
|
||||
ipSet, err := pol.ExpandAlias(machines, src, stripEmaildomain)
|
||||
if err != nil {
|
||||
return []string{}, err
|
||||
}
|
||||
@ -367,7 +367,7 @@ func (pol *ACLPolicy) getIPsFromSource(
|
||||
// which are associated with the dest alias.
|
||||
func (pol *ACLPolicy) getNetPortRangeFromDestination(
|
||||
dest string,
|
||||
machines []Machine,
|
||||
machines types.Machines,
|
||||
needsWildcard bool,
|
||||
stripEmaildomain bool,
|
||||
) ([]tailcfg.NetPortRange, error) {
|
||||
@ -390,7 +390,7 @@ func (pol *ACLPolicy) getNetPortRangeFromDestination(
|
||||
return nil, fmt.Errorf(
|
||||
"failed to parse destination, tokens %v: %w",
|
||||
tokens,
|
||||
errInvalidPortFormat,
|
||||
ErrInvalidPortFormat,
|
||||
)
|
||||
} else {
|
||||
tokens = []string{maybeIPv6Str, port}
|
||||
@ -414,7 +414,7 @@ func (pol *ACLPolicy) getNetPortRangeFromDestination(
|
||||
alias = fmt.Sprintf("%s:%s", tokens[0], tokens[1])
|
||||
}
|
||||
|
||||
expanded, err := pol.expandAlias(
|
||||
expanded, err := pol.ExpandAlias(
|
||||
machines,
|
||||
alias,
|
||||
stripEmaildomain,
|
||||
@ -499,13 +499,13 @@ func parseProtocol(protocol string) ([]int, bool, error) {
|
||||
// - an ip
|
||||
// - a cidr
|
||||
// and transform these in IPAddresses.
|
||||
func (pol *ACLPolicy) expandAlias(
|
||||
machines Machines,
|
||||
func (pol *ACLPolicy) ExpandAlias(
|
||||
machines types.Machines,
|
||||
alias string,
|
||||
stripEmailDomain bool,
|
||||
) (*netipx.IPSet, error) {
|
||||
if isWildcard(alias) {
|
||||
return parseIPSet("*", nil)
|
||||
return util.ParseIPSet("*", nil)
|
||||
}
|
||||
|
||||
build := netipx.IPSetBuilder{}
|
||||
@ -532,9 +532,9 @@ func (pol *ACLPolicy) expandAlias(
|
||||
// if alias is an host
|
||||
// Note, this is recursive.
|
||||
if h, ok := pol.Hosts[alias]; ok {
|
||||
log.Trace().Str("host", h.String()).Msg("expandAlias got hosts entry")
|
||||
log.Trace().Str("host", h.String()).Msg("ExpandAlias got hosts entry")
|
||||
|
||||
return pol.expandAlias(machines, h.String(), stripEmailDomain)
|
||||
return pol.ExpandAlias(machines, h.String(), stripEmailDomain)
|
||||
}
|
||||
|
||||
// if alias is an IP
|
||||
@ -557,11 +557,11 @@ func (pol *ACLPolicy) expandAlias(
|
||||
// we assume in this function that we only have nodes from 1 user.
|
||||
func excludeCorrectlyTaggedNodes(
|
||||
aclPolicy *ACLPolicy,
|
||||
nodes []Machine,
|
||||
nodes types.Machines,
|
||||
user string,
|
||||
stripEmailDomain bool,
|
||||
) []Machine {
|
||||
out := []Machine{}
|
||||
) types.Machines {
|
||||
out := types.Machines{}
|
||||
tags := []string{}
|
||||
for tag := range aclPolicy.TagOwners {
|
||||
owners, _ := getTagOwners(aclPolicy, user, stripEmailDomain)
|
||||
@ -601,7 +601,7 @@ func expandPorts(portsStr string, needsWildcard bool) (*[]tailcfg.PortRange, err
|
||||
}
|
||||
|
||||
if needsWildcard {
|
||||
return nil, errWildcardIsNeeded
|
||||
return nil, ErrWildcardIsNeeded
|
||||
}
|
||||
|
||||
ports := []tailcfg.PortRange{}
|
||||
@ -634,15 +634,15 @@ func expandPorts(portsStr string, needsWildcard bool) (*[]tailcfg.PortRange, err
|
||||
})
|
||||
|
||||
default:
|
||||
return nil, errInvalidPortFormat
|
||||
return nil, ErrInvalidPortFormat
|
||||
}
|
||||
}
|
||||
|
||||
return &ports, nil
|
||||
}
|
||||
|
||||
func filterMachinesByUser(machines []Machine, user string) []Machine {
|
||||
out := []Machine{}
|
||||
func filterMachinesByUser(machines types.Machines, user string) types.Machines {
|
||||
out := types.Machines{}
|
||||
for _, machine := range machines {
|
||||
if machine.User.Name == user {
|
||||
out = append(out, machine)
|
||||
@ -664,7 +664,7 @@ func getTagOwners(
|
||||
if !ok {
|
||||
return []string{}, fmt.Errorf(
|
||||
"%w. %v isn't owned by a TagOwner. Please add one first. https://tailscale.com/kb/1018/acls/#tag-owners",
|
||||
errInvalidTag,
|
||||
ErrInvalidTag,
|
||||
tag,
|
||||
)
|
||||
}
|
||||
@ -696,22 +696,22 @@ func (pol *ACLPolicy) getUsersInGroup(
|
||||
return []string{}, fmt.Errorf(
|
||||
"group %v isn't registered. %w",
|
||||
group,
|
||||
errInvalidGroup,
|
||||
ErrInvalidGroup,
|
||||
)
|
||||
}
|
||||
for _, group := range aclGroups {
|
||||
if isGroup(group) {
|
||||
return []string{}, fmt.Errorf(
|
||||
"%w. A group cannot be composed of groups. https://tailscale.com/kb/1018/acls/#groups",
|
||||
errInvalidGroup,
|
||||
ErrInvalidGroup,
|
||||
)
|
||||
}
|
||||
grp, err := NormalizeToFQDNRules(group, stripEmailDomain)
|
||||
grp, err := util.NormalizeToFQDNRules(group, stripEmailDomain)
|
||||
if err != nil {
|
||||
return []string{}, fmt.Errorf(
|
||||
"failed to normalize group %q, err: %w",
|
||||
group,
|
||||
errInvalidGroup,
|
||||
ErrInvalidGroup,
|
||||
)
|
||||
}
|
||||
users = append(users, grp)
|
||||
@ -722,7 +722,7 @@ func (pol *ACLPolicy) getUsersInGroup(
|
||||
|
||||
func (pol *ACLPolicy) getIPsFromGroup(
|
||||
group string,
|
||||
machines Machines,
|
||||
machines types.Machines,
|
||||
stripEmailDomain bool,
|
||||
) (*netipx.IPSet, error) {
|
||||
build := netipx.IPSetBuilder{}
|
||||
@ -743,7 +743,7 @@ func (pol *ACLPolicy) getIPsFromGroup(
|
||||
|
||||
func (pol *ACLPolicy) getIPsFromTag(
|
||||
alias string,
|
||||
machines Machines,
|
||||
machines types.Machines,
|
||||
stripEmailDomain bool,
|
||||
) (*netipx.IPSet, error) {
|
||||
build := netipx.IPSetBuilder{}
|
||||
@ -758,12 +758,12 @@ func (pol *ACLPolicy) getIPsFromTag(
|
||||
// find tag owners
|
||||
owners, err := getTagOwners(pol, alias, stripEmailDomain)
|
||||
if err != nil {
|
||||
if errors.Is(err, errInvalidTag) {
|
||||
if errors.Is(err, ErrInvalidTag) {
|
||||
ipSet, _ := build.IPSet()
|
||||
if len(ipSet.Prefixes()) == 0 {
|
||||
return ipSet, fmt.Errorf(
|
||||
"%w. %v isn't owned by a TagOwner and no forced tags are defined",
|
||||
errInvalidTag,
|
||||
ErrInvalidTag,
|
||||
alias,
|
||||
)
|
||||
}
|
||||
@ -790,7 +790,7 @@ func (pol *ACLPolicy) getIPsFromTag(
|
||||
|
||||
func (pol *ACLPolicy) getIPsForUser(
|
||||
user string,
|
||||
machines Machines,
|
||||
machines types.Machines,
|
||||
stripEmailDomain bool,
|
||||
) (*netipx.IPSet, error) {
|
||||
build := netipx.IPSetBuilder{}
|
||||
@ -812,9 +812,9 @@ func (pol *ACLPolicy) getIPsForUser(
|
||||
|
||||
func (pol *ACLPolicy) getIPsFromSingleIP(
|
||||
ip netip.Addr,
|
||||
machines Machines,
|
||||
machines types.Machines,
|
||||
) (*netipx.IPSet, error) {
|
||||
log.Trace().Str("ip", ip.String()).Msg("expandAlias got ip")
|
||||
log.Trace().Str("ip", ip.String()).Msg("ExpandAlias got ip")
|
||||
|
||||
matches := machines.FilterByIP(ip)
|
||||
|
||||
@ -830,7 +830,7 @@ func (pol *ACLPolicy) getIPsFromSingleIP(
|
||||
|
||||
func (pol *ACLPolicy) getIPsFromIPPrefix(
|
||||
prefix netip.Prefix,
|
||||
machines Machines,
|
||||
machines types.Machines,
|
||||
) (*netipx.IPSet, error) {
|
||||
log.Trace().Str("prefix", prefix.String()).Msg("expandAlias got prefix")
|
||||
build := netipx.IPSetBuilder{}
|
||||
@ -862,3 +862,65 @@ func isGroup(str string) bool {
|
||||
func isTag(str string) bool {
|
||||
return strings.HasPrefix(str, "tag:")
|
||||
}
|
||||
|
||||
// getTags will return the tags of the current machine.
|
||||
// Invalid tags are tags added by a user on a node, and that user doesn't have authority to add this tag.
|
||||
// Valid tags are tags added by a user that is allowed in the ACL policy to add this tag.
|
||||
func (pol *ACLPolicy) GetTagsOfMachine(
|
||||
machine types.Machine,
|
||||
stripEmailDomain bool,
|
||||
) ([]string, []string) {
|
||||
validTags := make([]string, 0)
|
||||
invalidTags := make([]string, 0)
|
||||
|
||||
validTagMap := make(map[string]bool)
|
||||
invalidTagMap := make(map[string]bool)
|
||||
for _, tag := range machine.HostInfo.RequestTags {
|
||||
owners, err := getTagOwners(pol, tag, stripEmailDomain)
|
||||
if errors.Is(err, ErrInvalidTag) {
|
||||
invalidTagMap[tag] = true
|
||||
|
||||
continue
|
||||
}
|
||||
var found bool
|
||||
for _, owner := range owners {
|
||||
if machine.User.Name == owner {
|
||||
found = true
|
||||
}
|
||||
}
|
||||
if found {
|
||||
validTagMap[tag] = true
|
||||
} else {
|
||||
invalidTagMap[tag] = true
|
||||
}
|
||||
}
|
||||
for tag := range invalidTagMap {
|
||||
invalidTags = append(invalidTags, tag)
|
||||
}
|
||||
for tag := range validTagMap {
|
||||
validTags = append(validTags, tag)
|
||||
}
|
||||
|
||||
return validTags, invalidTags
|
||||
}
|
||||
|
||||
// FilterMachinesByACL returns the list of peers authorized to be accessed from a given machine.
|
||||
func FilterMachinesByACL(
|
||||
machine *types.Machine,
|
||||
machines types.Machines,
|
||||
filter []tailcfg.FilterRule,
|
||||
) types.Machines {
|
||||
result := types.Machines{}
|
||||
|
||||
for index, peer := range machines {
|
||||
if peer.ID == machine.ID {
|
||||
continue
|
||||
}
|
||||
|
||||
if machine.CanAccess(filter, &machines[index]) || peer.CanAccess(filter, machine) {
|
||||
result = append(result, peer)
|
||||
}
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
File diff suppressed because it is too large
Load Diff
@ -1,4 +1,4 @@
|
||||
package hscontrol
|
||||
package policy
|
||||
|
||||
import (
|
||||
"encoding/json"
|
61
hscontrol/policy/matcher/matcher.go
Normal file
61
hscontrol/policy/matcher/matcher.go
Normal file
@ -0,0 +1,61 @@
|
||||
package matcher
|
||||
|
||||
import (
|
||||
"net/netip"
|
||||
|
||||
"github.com/juanfont/headscale/hscontrol/util"
|
||||
"go4.org/netipx"
|
||||
"tailscale.com/tailcfg"
|
||||
)
|
||||
|
||||
type Match struct {
|
||||
Srcs *netipx.IPSet
|
||||
Dests *netipx.IPSet
|
||||
}
|
||||
|
||||
func MatchFromFilterRule(rule tailcfg.FilterRule) Match {
|
||||
srcs := new(netipx.IPSetBuilder)
|
||||
dests := new(netipx.IPSetBuilder)
|
||||
|
||||
for _, srcIP := range rule.SrcIPs {
|
||||
set, _ := util.ParseIPSet(srcIP, nil)
|
||||
|
||||
srcs.AddSet(set)
|
||||
}
|
||||
|
||||
for _, dest := range rule.DstPorts {
|
||||
set, _ := util.ParseIPSet(dest.IP, nil)
|
||||
|
||||
dests.AddSet(set)
|
||||
}
|
||||
|
||||
srcsSet, _ := srcs.IPSet()
|
||||
destsSet, _ := dests.IPSet()
|
||||
|
||||
match := Match{
|
||||
Srcs: srcsSet,
|
||||
Dests: destsSet,
|
||||
}
|
||||
|
||||
return match
|
||||
}
|
||||
|
||||
func (m *Match) SrcsContainsIPs(ips []netip.Addr) bool {
|
||||
for _, ip := range ips {
|
||||
if m.Srcs.Contains(ip) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func (m *Match) DestsContainsIP(ips []netip.Addr) bool {
|
||||
for _, ip := range ips {
|
||||
if m.Dests.Contains(ip) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
1
hscontrol/policy/matcher/matcher_test.go
Normal file
1
hscontrol/policy/matcher/matcher_test.go
Normal file
@ -0,0 +1 @@
|
||||
package matcher
|
@ -9,6 +9,7 @@ import (
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/juanfont/headscale/hscontrol/types"
|
||||
"github.com/juanfont/headscale/hscontrol/util"
|
||||
"github.com/rs/zerolog/log"
|
||||
"gorm.io/gorm"
|
||||
@ -171,7 +172,7 @@ func (h *Headscale) handleRegisterCommon(
|
||||
// that we rely on a method that calls back some how (OpenID or CLI)
|
||||
// We create the machine and then keep it around until a callback
|
||||
// happens
|
||||
newMachine := Machine{
|
||||
newMachine := types.Machine{
|
||||
MachineKey: util.MachinePublicKeyStripPrefix(machineKey),
|
||||
Hostname: registerRequest.Hostinfo.Hostname,
|
||||
GivenName: givenName,
|
||||
@ -214,8 +215,7 @@ func (h *Headscale) handleRegisterCommon(
|
||||
[]byte(util.MachinePublicKeyEnsurePrefix(machine.MachineKey)),
|
||||
)
|
||||
if err != nil || storedMachineKey.IsZero() {
|
||||
machine.MachineKey = util.MachinePublicKeyStripPrefix(machineKey)
|
||||
if err := h.db.db.Save(&machine).Error; err != nil {
|
||||
if err := h.db.MachineSetMachineKey(machine, machineKey); err != nil {
|
||||
log.Error().
|
||||
Caller().
|
||||
Str("func", "RegistrationHandler").
|
||||
@ -244,7 +244,7 @@ func (h *Headscale) handleRegisterCommon(
|
||||
|
||||
// If machine is not expired, and it is register, we have a already accepted this machine,
|
||||
// let it proceed with a valid registration
|
||||
if !machine.isExpired() {
|
||||
if !machine.IsExpired() {
|
||||
h.handleMachineValidRegistrationCommon(writer, *machine, machineKey, isNoise)
|
||||
|
||||
return
|
||||
@ -253,7 +253,7 @@ func (h *Headscale) handleRegisterCommon(
|
||||
|
||||
// The NodeKey we have matches OldNodeKey, which means this is a refresh after a key expiration
|
||||
if machine.NodeKey == util.NodePublicKeyStripPrefix(registerRequest.OldNodeKey) &&
|
||||
!machine.isExpired() {
|
||||
!machine.IsExpired() {
|
||||
h.handleMachineRefreshKeyCommon(
|
||||
writer,
|
||||
registerRequest,
|
||||
@ -312,7 +312,7 @@ func (h *Headscale) handleAuthKeyCommon(
|
||||
Msgf("Processing auth key for %s", registerRequest.Hostinfo.Hostname)
|
||||
resp := tailcfg.RegisterResponse{}
|
||||
|
||||
pak, err := h.db.checkKeyValidity(registerRequest.Auth.AuthKey)
|
||||
pak, err := h.db.ValidatePreAuthKey(registerRequest.Auth.AuthKey)
|
||||
if err != nil {
|
||||
log.Error().
|
||||
Caller().
|
||||
@ -333,7 +333,7 @@ func (h *Headscale) handleAuthKeyCommon(
|
||||
Err(err).
|
||||
Msg("Cannot encode message")
|
||||
http.Error(writer, "Internal server error", http.StatusInternalServerError)
|
||||
machineRegistrations.WithLabelValues("new", RegisterMethodAuthKey, "error", pak.User.Name).
|
||||
machineRegistrations.WithLabelValues("new", util.RegisterMethodAuthKey, "error", pak.User.Name).
|
||||
Inc()
|
||||
|
||||
return
|
||||
@ -358,10 +358,10 @@ func (h *Headscale) handleAuthKeyCommon(
|
||||
Msg("Failed authentication via AuthKey")
|
||||
|
||||
if pak != nil {
|
||||
machineRegistrations.WithLabelValues("new", RegisterMethodAuthKey, "error", pak.User.Name).
|
||||
machineRegistrations.WithLabelValues("new", util.RegisterMethodAuthKey, "error", pak.User.Name).
|
||||
Inc()
|
||||
} else {
|
||||
machineRegistrations.WithLabelValues("new", RegisterMethodAuthKey, "error", "unknown").Inc()
|
||||
machineRegistrations.WithLabelValues("new", util.RegisterMethodAuthKey, "error", "unknown").Inc()
|
||||
}
|
||||
|
||||
return
|
||||
@ -401,10 +401,10 @@ func (h *Headscale) handleAuthKeyCommon(
|
||||
return
|
||||
}
|
||||
|
||||
aclTags := pak.toProto().AclTags
|
||||
aclTags := pak.Proto().AclTags
|
||||
if len(aclTags) > 0 {
|
||||
// This conditional preserves the existing behaviour, although SaaS would reset the tags on auth-key login
|
||||
err = h.db.SetTags(machine, aclTags, h.UpdateACLRules)
|
||||
err = h.db.SetTags(machine, aclTags)
|
||||
|
||||
if err != nil {
|
||||
log.Error().
|
||||
@ -433,17 +433,17 @@ func (h *Headscale) handleAuthKeyCommon(
|
||||
return
|
||||
}
|
||||
|
||||
machineToRegister := Machine{
|
||||
machineToRegister := types.Machine{
|
||||
Hostname: registerRequest.Hostinfo.Hostname,
|
||||
GivenName: givenName,
|
||||
UserID: pak.User.ID,
|
||||
MachineKey: util.MachinePublicKeyStripPrefix(machineKey),
|
||||
RegisterMethod: RegisterMethodAuthKey,
|
||||
RegisterMethod: util.RegisterMethodAuthKey,
|
||||
Expiry: ®isterRequest.Expiry,
|
||||
NodeKey: nodeKey,
|
||||
LastSeen: &now,
|
||||
AuthKeyID: uint(pak.ID),
|
||||
ForcedTags: pak.toProto().AclTags,
|
||||
ForcedTags: pak.Proto().AclTags,
|
||||
}
|
||||
|
||||
machine, err = h.db.RegisterMachine(
|
||||
@ -455,7 +455,7 @@ func (h *Headscale) handleAuthKeyCommon(
|
||||
Bool("noise", isNoise).
|
||||
Err(err).
|
||||
Msg("could not register machine")
|
||||
machineRegistrations.WithLabelValues("new", RegisterMethodAuthKey, "error", pak.User.Name).
|
||||
machineRegistrations.WithLabelValues("new", util.RegisterMethodAuthKey, "error", pak.User.Name).
|
||||
Inc()
|
||||
http.Error(writer, "Internal server error", http.StatusInternalServerError)
|
||||
|
||||
@ -470,7 +470,7 @@ func (h *Headscale) handleAuthKeyCommon(
|
||||
Bool("noise", isNoise).
|
||||
Err(err).
|
||||
Msg("Failed to use pre-auth key")
|
||||
machineRegistrations.WithLabelValues("new", RegisterMethodAuthKey, "error", pak.User.Name).
|
||||
machineRegistrations.WithLabelValues("new", util.RegisterMethodAuthKey, "error", pak.User.Name).
|
||||
Inc()
|
||||
http.Error(writer, "Internal server error", http.StatusInternalServerError)
|
||||
|
||||
@ -478,10 +478,10 @@ func (h *Headscale) handleAuthKeyCommon(
|
||||
}
|
||||
|
||||
resp.MachineAuthorized = true
|
||||
resp.User = *pak.User.toTailscaleUser()
|
||||
resp.User = *pak.User.TailscaleUser()
|
||||
// Provide LoginName when registering with pre-auth key
|
||||
// Otherwise it will need to exec `tailscale up` twice to fetch the *LoginName*
|
||||
resp.Login = *pak.User.toTailscaleLogin()
|
||||
resp.Login = *pak.User.TailscaleLogin()
|
||||
|
||||
respBody, err := h.marshalResponse(resp, machineKey, isNoise)
|
||||
if err != nil {
|
||||
@ -492,13 +492,13 @@ func (h *Headscale) handleAuthKeyCommon(
|
||||
Str("machine", registerRequest.Hostinfo.Hostname).
|
||||
Err(err).
|
||||
Msg("Cannot encode message")
|
||||
machineRegistrations.WithLabelValues("new", RegisterMethodAuthKey, "error", pak.User.Name).
|
||||
machineRegistrations.WithLabelValues("new", util.RegisterMethodAuthKey, "error", pak.User.Name).
|
||||
Inc()
|
||||
http.Error(writer, "Internal server error", http.StatusInternalServerError)
|
||||
|
||||
return
|
||||
}
|
||||
machineRegistrations.WithLabelValues("new", RegisterMethodAuthKey, "success", pak.User.Name).
|
||||
machineRegistrations.WithLabelValues("new", util.RegisterMethodAuthKey, "success", pak.User.Name).
|
||||
Inc()
|
||||
writer.Header().Set("Content-Type", "application/json; charset=utf-8")
|
||||
writer.WriteHeader(http.StatusOK)
|
||||
@ -581,7 +581,7 @@ func (h *Headscale) handleNewMachineCommon(
|
||||
|
||||
func (h *Headscale) handleMachineLogOutCommon(
|
||||
writer http.ResponseWriter,
|
||||
machine Machine,
|
||||
machine types.Machine,
|
||||
machineKey key.MachinePublic,
|
||||
isNoise bool,
|
||||
) {
|
||||
@ -608,7 +608,7 @@ func (h *Headscale) handleMachineLogOutCommon(
|
||||
resp.AuthURL = ""
|
||||
resp.MachineAuthorized = false
|
||||
resp.NodeKeyExpired = true
|
||||
resp.User = *machine.User.toTailscaleUser()
|
||||
resp.User = *machine.User.TailscaleUser()
|
||||
respBody, err := h.marshalResponse(resp, machineKey, isNoise)
|
||||
if err != nil {
|
||||
log.Error().
|
||||
@ -634,7 +634,7 @@ func (h *Headscale) handleMachineLogOutCommon(
|
||||
return
|
||||
}
|
||||
|
||||
if machine.isEphemeral() {
|
||||
if machine.IsEphemeral() {
|
||||
err = h.db.HardDeleteMachine(&machine)
|
||||
if err != nil {
|
||||
log.Error().
|
||||
@ -655,7 +655,7 @@ func (h *Headscale) handleMachineLogOutCommon(
|
||||
|
||||
func (h *Headscale) handleMachineValidRegistrationCommon(
|
||||
writer http.ResponseWriter,
|
||||
machine Machine,
|
||||
machine types.Machine,
|
||||
machineKey key.MachinePublic,
|
||||
isNoise bool,
|
||||
) {
|
||||
@ -670,8 +670,8 @@ func (h *Headscale) handleMachineValidRegistrationCommon(
|
||||
|
||||
resp.AuthURL = ""
|
||||
resp.MachineAuthorized = true
|
||||
resp.User = *machine.User.toTailscaleUser()
|
||||
resp.Login = *machine.User.toTailscaleLogin()
|
||||
resp.User = *machine.User.TailscaleUser()
|
||||
resp.Login = *machine.User.TailscaleLogin()
|
||||
|
||||
respBody, err := h.marshalResponse(resp, machineKey, isNoise)
|
||||
if err != nil {
|
||||
@ -710,7 +710,7 @@ func (h *Headscale) handleMachineValidRegistrationCommon(
|
||||
func (h *Headscale) handleMachineRefreshKeyCommon(
|
||||
writer http.ResponseWriter,
|
||||
registerRequest tailcfg.RegisterRequest,
|
||||
machine Machine,
|
||||
machine types.Machine,
|
||||
machineKey key.MachinePublic,
|
||||
isNoise bool,
|
||||
) {
|
||||
@ -721,9 +721,9 @@ func (h *Headscale) handleMachineRefreshKeyCommon(
|
||||
Bool("noise", isNoise).
|
||||
Str("machine", machine.Hostname).
|
||||
Msg("We have the OldNodeKey in the database. This is a key refresh")
|
||||
machine.NodeKey = util.NodePublicKeyStripPrefix(registerRequest.NodeKey)
|
||||
|
||||
if err := h.db.db.Save(&machine).Error; err != nil {
|
||||
err := h.db.MachineSetNodeKey(&machine, registerRequest.NodeKey)
|
||||
if err != nil {
|
||||
log.Error().
|
||||
Caller().
|
||||
Err(err).
|
||||
@ -734,7 +734,7 @@ func (h *Headscale) handleMachineRefreshKeyCommon(
|
||||
}
|
||||
|
||||
resp.AuthURL = ""
|
||||
resp.User = *machine.User.toTailscaleUser()
|
||||
resp.User = *machine.User.TailscaleUser()
|
||||
respBody, err := h.marshalResponse(resp, machineKey, isNoise)
|
||||
if err != nil {
|
||||
log.Error().
|
||||
@ -770,7 +770,7 @@ func (h *Headscale) handleMachineRefreshKeyCommon(
|
||||
func (h *Headscale) handleMachineExpiredOrLoggedOutCommon(
|
||||
writer http.ResponseWriter,
|
||||
registerRequest tailcfg.RegisterRequest,
|
||||
machine Machine,
|
||||
machine types.Machine,
|
||||
machineKey key.MachinePublic,
|
||||
isNoise bool,
|
||||
) {
|
||||
|
@ -6,6 +6,7 @@ import (
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/juanfont/headscale/hscontrol/types"
|
||||
"github.com/juanfont/headscale/hscontrol/util"
|
||||
"github.com/rs/zerolog/log"
|
||||
"tailscale.com/tailcfg"
|
||||
@ -24,16 +25,16 @@ const machineNameContextKey = contextKey("machineName")
|
||||
func (h *Headscale) handlePollCommon(
|
||||
writer http.ResponseWriter,
|
||||
ctx context.Context,
|
||||
machine *Machine,
|
||||
machine *types.Machine,
|
||||
mapRequest tailcfg.MapRequest,
|
||||
isNoise bool,
|
||||
) {
|
||||
machine.Hostname = mapRequest.Hostinfo.Hostname
|
||||
machine.HostInfo = HostInfo(*mapRequest.Hostinfo)
|
||||
machine.HostInfo = types.HostInfo(*mapRequest.Hostinfo)
|
||||
machine.DiscoKey = util.DiscoPublicKeyStripPrefix(mapRequest.DiscoKey)
|
||||
now := time.Now().UTC()
|
||||
|
||||
err := h.db.processMachineRoutes(machine)
|
||||
err := h.db.ProcessMachineRoutes(machine)
|
||||
if err != nil {
|
||||
log.Error().
|
||||
Caller().
|
||||
@ -43,18 +44,13 @@ func (h *Headscale) handlePollCommon(
|
||||
}
|
||||
|
||||
// update ACLRules with peer informations (to update server tags if necessary)
|
||||
if h.aclPolicy != nil {
|
||||
err := h.UpdateACLRules()
|
||||
if err != nil {
|
||||
log.Error().
|
||||
Caller().
|
||||
Bool("noise", isNoise).
|
||||
Str("machine", machine.Hostname).
|
||||
Err(err)
|
||||
}
|
||||
if h.ACLPolicy != nil {
|
||||
// TODO(kradalby): Since this is not blocking, I might have introduced a bug here.
|
||||
// It will be resolved later as we change up the policy stuff.
|
||||
h.policyUpdateChan <- struct{}{}
|
||||
|
||||
// update routes with peer information
|
||||
err = h.db.EnableAutoApprovedRoutes(h.aclPolicy, machine)
|
||||
err = h.db.EnableAutoApprovedRoutes(h.ACLPolicy, machine)
|
||||
if err != nil {
|
||||
log.Error().
|
||||
Caller().
|
||||
@ -78,19 +74,17 @@ func (h *Headscale) handlePollCommon(
|
||||
machine.LastSeen = &now
|
||||
}
|
||||
|
||||
if err := h.db.db.Updates(machine).Error; err != nil {
|
||||
if err != nil {
|
||||
log.Error().
|
||||
Str("handler", "PollNetMap").
|
||||
Bool("noise", isNoise).
|
||||
Str("node_key", machine.NodeKey).
|
||||
Str("machine", machine.Hostname).
|
||||
Err(err).
|
||||
Msg("Failed to persist/update machine in the database")
|
||||
http.Error(writer, "", http.StatusInternalServerError)
|
||||
if err := h.db.MachineSave(machine); err != nil {
|
||||
log.Error().
|
||||
Str("handler", "PollNetMap").
|
||||
Bool("noise", isNoise).
|
||||
Str("node_key", machine.NodeKey).
|
||||
Str("machine", machine.Hostname).
|
||||
Err(err).
|
||||
Msg("Failed to persist/update machine in the database")
|
||||
http.Error(writer, "", http.StatusInternalServerError)
|
||||
|
||||
return
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
mapResp, err := h.getMapResponseData(mapRequest, machine, isNoise)
|
||||
@ -244,7 +238,7 @@ func (h *Headscale) handlePollCommon(
|
||||
func (h *Headscale) pollNetMapStream(
|
||||
writer http.ResponseWriter,
|
||||
ctxReq context.Context,
|
||||
machine *Machine,
|
||||
machine *types.Machine,
|
||||
mapRequest tailcfg.MapRequest,
|
||||
pollDataChan chan []byte,
|
||||
keepAliveChan chan []byte,
|
||||
@ -457,7 +451,7 @@ func (h *Headscale) pollNetMapStream(
|
||||
updateRequestsReceivedOnChannel.WithLabelValues(machine.User.Name, machine.Hostname).
|
||||
Inc()
|
||||
|
||||
if h.db.isOutdated(machine, h.getLastStateChange()) {
|
||||
if h.db.IsOutdated(machine, h.getLastStateChange()) {
|
||||
var lastUpdate time.Time
|
||||
if machine.LastSuccessfulUpdate != nil {
|
||||
lastUpdate = *machine.LastSuccessfulUpdate
|
||||
@ -626,7 +620,7 @@ func (h *Headscale) scheduledPollWorker(
|
||||
updateChan chan struct{},
|
||||
keepAliveChan chan []byte,
|
||||
mapRequest tailcfg.MapRequest,
|
||||
machine *Machine,
|
||||
machine *types.Machine,
|
||||
isNoise bool,
|
||||
) {
|
||||
keepAliveTicker := time.NewTicker(keepAliveInterval)
|
||||
|
@ -5,6 +5,7 @@ import (
|
||||
"encoding/json"
|
||||
"sync"
|
||||
|
||||
"github.com/juanfont/headscale/hscontrol/types"
|
||||
"github.com/juanfont/headscale/hscontrol/util"
|
||||
"github.com/klauspost/compress/zstd"
|
||||
"github.com/rs/zerolog/log"
|
||||
@ -15,7 +16,7 @@ import (
|
||||
|
||||
func (h *Headscale) getMapResponseData(
|
||||
mapRequest tailcfg.MapRequest,
|
||||
machine *Machine,
|
||||
machine *types.Machine,
|
||||
isNoise bool,
|
||||
) ([]byte, error) {
|
||||
mapResponse, err := h.generateMapResponse(mapRequest, machine)
|
||||
@ -43,7 +44,7 @@ func (h *Headscale) getMapResponseData(
|
||||
|
||||
func (h *Headscale) getMapKeepAliveResponseData(
|
||||
mapRequest tailcfg.MapRequest,
|
||||
machine *Machine,
|
||||
machine *types.Machine,
|
||||
isNoise bool,
|
||||
) ([]byte, error) {
|
||||
keepAliveResponse := tailcfg.MapResponse{
|
||||
|
@ -18,7 +18,7 @@ type Suite struct{}
|
||||
|
||||
var (
|
||||
tmpDir string
|
||||
app Headscale
|
||||
app *Headscale
|
||||
)
|
||||
|
||||
func (s *Suite) SetUpTest(c *check.C) {
|
||||
@ -34,11 +34,15 @@ func (s *Suite) ResetDB(c *check.C) {
|
||||
os.RemoveAll(tmpDir)
|
||||
}
|
||||
var err error
|
||||
tmpDir, err = os.MkdirTemp("", "autoygg-client-test")
|
||||
tmpDir, err = os.MkdirTemp("", "autoygg-client-test2")
|
||||
if err != nil {
|
||||
c.Fatal(err)
|
||||
}
|
||||
cfg := Config{
|
||||
PrivateKeyPath: tmpDir + "/private.key",
|
||||
NoisePrivateKeyPath: tmpDir + "/noise_private.key",
|
||||
DBtype: "sqlite3",
|
||||
DBpath: tmpDir + "/headscale_test.db",
|
||||
IPPrefixes: []netip.Prefix{
|
||||
netip.MustParsePrefix("10.27.0.0/23"),
|
||||
},
|
||||
@ -47,29 +51,8 @@ func (s *Suite) ResetDB(c *check.C) {
|
||||
},
|
||||
}
|
||||
|
||||
// TODO(kradalby): make this use NewHeadscale properly so it doesnt drift
|
||||
app = Headscale{
|
||||
cfg: &cfg,
|
||||
dbType: "sqlite3",
|
||||
dbString: tmpDir + "/headscale_test.db",
|
||||
|
||||
stateUpdateChan: make(chan struct{}),
|
||||
cancelStateUpdateChan: make(chan struct{}),
|
||||
}
|
||||
|
||||
go app.watchStateChannel()
|
||||
|
||||
db, err := NewHeadscaleDatabase(
|
||||
app.dbType,
|
||||
app.dbString,
|
||||
cfg.OIDC.StripEmaildomain,
|
||||
false,
|
||||
app.stateUpdateChan,
|
||||
cfg.IPPrefixes,
|
||||
"",
|
||||
)
|
||||
app, err = NewHeadscale(&cfg)
|
||||
if err != nil {
|
||||
c.Fatal(err)
|
||||
}
|
||||
app.db = db
|
||||
}
|
41
hscontrol/types/api_key.go
Normal file
41
hscontrol/types/api_key.go
Normal file
@ -0,0 +1,41 @@
|
||||
package types
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
|
||||
"google.golang.org/protobuf/types/known/timestamppb"
|
||||
)
|
||||
|
||||
// APIKey describes the datamodel for API keys used to remotely authenticate with
|
||||
// headscale.
|
||||
type APIKey struct {
|
||||
ID uint64 `gorm:"primary_key"`
|
||||
Prefix string `gorm:"uniqueIndex"`
|
||||
Hash []byte
|
||||
|
||||
CreatedAt *time.Time
|
||||
Expiration *time.Time
|
||||
LastSeen *time.Time
|
||||
}
|
||||
|
||||
func (key *APIKey) Proto() *v1.ApiKey {
|
||||
protoKey := v1.ApiKey{
|
||||
Id: key.ID,
|
||||
Prefix: key.Prefix,
|
||||
}
|
||||
|
||||
if key.Expiration != nil {
|
||||
protoKey.Expiration = timestamppb.New(*key.Expiration)
|
||||
}
|
||||
|
||||
if key.CreatedAt != nil {
|
||||
protoKey.CreatedAt = timestamppb.New(*key.CreatedAt)
|
||||
}
|
||||
|
||||
if key.LastSeen != nil {
|
||||
protoKey.LastSeen = timestamppb.New(*key.LastSeen)
|
||||
}
|
||||
|
||||
return &protoKey
|
||||
}
|
108
hscontrol/types/common.go
Normal file
108
hscontrol/types/common.go
Normal file
@ -0,0 +1,108 @@
|
||||
package types
|
||||
|
||||
import (
|
||||
"database/sql/driver"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/netip"
|
||||
|
||||
"tailscale.com/tailcfg"
|
||||
)
|
||||
|
||||
var ErrCannotParsePrefix = errors.New("cannot parse prefix")
|
||||
|
||||
// This is a "wrapper" type around tailscales
|
||||
// Hostinfo to allow us to add database "serialization"
|
||||
// methods. This allows us to use a typed values throughout
|
||||
// the code and not have to marshal/unmarshal and error
|
||||
// check all over the code.
|
||||
type HostInfo tailcfg.Hostinfo
|
||||
|
||||
func (hi *HostInfo) Scan(destination interface{}) error {
|
||||
switch value := destination.(type) {
|
||||
case []byte:
|
||||
return json.Unmarshal(value, hi)
|
||||
|
||||
case string:
|
||||
return json.Unmarshal([]byte(value), hi)
|
||||
|
||||
default:
|
||||
return fmt.Errorf("%w: unexpected data type %T", ErrMachineAddressesInvalid, destination)
|
||||
}
|
||||
}
|
||||
|
||||
// Value return json value, implement driver.Valuer interface.
|
||||
func (hi HostInfo) Value() (driver.Value, error) {
|
||||
bytes, err := json.Marshal(hi)
|
||||
|
||||
return string(bytes), err
|
||||
}
|
||||
|
||||
type IPPrefix netip.Prefix
|
||||
|
||||
func (i *IPPrefix) Scan(destination interface{}) error {
|
||||
switch value := destination.(type) {
|
||||
case string:
|
||||
prefix, err := netip.ParsePrefix(value)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
*i = IPPrefix(prefix)
|
||||
|
||||
return nil
|
||||
default:
|
||||
return fmt.Errorf("%w: unexpected data type %T", ErrCannotParsePrefix, destination)
|
||||
}
|
||||
}
|
||||
|
||||
// Value return json value, implement driver.Valuer interface.
|
||||
func (i IPPrefix) Value() (driver.Value, error) {
|
||||
prefixStr := netip.Prefix(i).String()
|
||||
|
||||
return prefixStr, nil
|
||||
}
|
||||
|
||||
type IPPrefixes []netip.Prefix
|
||||
|
||||
func (i *IPPrefixes) Scan(destination interface{}) error {
|
||||
switch value := destination.(type) {
|
||||
case []byte:
|
||||
return json.Unmarshal(value, i)
|
||||
|
||||
case string:
|
||||
return json.Unmarshal([]byte(value), i)
|
||||
|
||||
default:
|
||||
return fmt.Errorf("%w: unexpected data type %T", ErrMachineAddressesInvalid, destination)
|
||||
}
|
||||
}
|
||||
|
||||
// Value return json value, implement driver.Valuer interface.
|
||||
func (i IPPrefixes) Value() (driver.Value, error) {
|
||||
bytes, err := json.Marshal(i)
|
||||
|
||||
return string(bytes), err
|
||||
}
|
||||
|
||||
type StringList []string
|
||||
|
||||
func (i *StringList) Scan(destination interface{}) error {
|
||||
switch value := destination.(type) {
|
||||
case []byte:
|
||||
return json.Unmarshal(value, i)
|
||||
|
||||
case string:
|
||||
return json.Unmarshal([]byte(value), i)
|
||||
|
||||
default:
|
||||
return fmt.Errorf("%w: unexpected data type %T", ErrMachineAddressesInvalid, destination)
|
||||
}
|
||||
}
|
||||
|
||||
// Value return json value, implement driver.Valuer interface.
|
||||
func (i StringList) Value() (driver.Value, error) {
|
||||
bytes, err := json.Marshal(i)
|
||||
|
||||
return string(bytes), err
|
||||
}
|
254
hscontrol/types/machine.go
Normal file
254
hscontrol/types/machine.go
Normal file
@ -0,0 +1,254 @@
|
||||
package types
|
||||
|
||||
import (
|
||||
"database/sql/driver"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
|
||||
"github.com/juanfont/headscale/hscontrol/policy/matcher"
|
||||
"go4.org/netipx"
|
||||
"google.golang.org/protobuf/types/known/timestamppb"
|
||||
"tailscale.com/tailcfg"
|
||||
)
|
||||
|
||||
const (
|
||||
// TODO(kradalby): Move out of here when we got circdeps under control.
|
||||
keepAliveInterval = 60 * time.Second
|
||||
)
|
||||
|
||||
var ErrMachineAddressesInvalid = errors.New("failed to parse machine addresses")
|
||||
|
||||
// Machine is a Headscale client.
|
||||
type Machine struct {
|
||||
ID uint64 `gorm:"primary_key"`
|
||||
MachineKey string `gorm:"type:varchar(64);unique_index"`
|
||||
NodeKey string
|
||||
DiscoKey string
|
||||
IPAddresses MachineAddresses
|
||||
|
||||
// Hostname represents the name given by the Tailscale
|
||||
// client during registration
|
||||
Hostname string
|
||||
|
||||
// Givenname represents either:
|
||||
// a DNS normalized version of Hostname
|
||||
// a valid name set by the User
|
||||
//
|
||||
// GivenName is the name used in all DNS related
|
||||
// parts of headscale.
|
||||
GivenName string `gorm:"type:varchar(63);unique_index"`
|
||||
UserID uint
|
||||
User User `gorm:"foreignKey:UserID"`
|
||||
|
||||
RegisterMethod string
|
||||
|
||||
ForcedTags StringList
|
||||
|
||||
// TODO(kradalby): This seems like irrelevant information?
|
||||
AuthKeyID uint
|
||||
AuthKey *PreAuthKey
|
||||
|
||||
LastSeen *time.Time
|
||||
LastSuccessfulUpdate *time.Time
|
||||
Expiry *time.Time
|
||||
|
||||
HostInfo HostInfo
|
||||
Endpoints StringList
|
||||
|
||||
CreatedAt time.Time
|
||||
UpdatedAt time.Time
|
||||
DeletedAt *time.Time
|
||||
}
|
||||
|
||||
type (
|
||||
Machines []Machine
|
||||
MachinesP []*Machine
|
||||
)
|
||||
|
||||
type MachineAddresses []netip.Addr
|
||||
|
||||
func (ma MachineAddresses) ToStringSlice() []string {
|
||||
strSlice := make([]string, 0, len(ma))
|
||||
for _, addr := range ma {
|
||||
strSlice = append(strSlice, addr.String())
|
||||
}
|
||||
|
||||
return strSlice
|
||||
}
|
||||
|
||||
// AppendToIPSet adds the individual ips in MachineAddresses to a
|
||||
// given netipx.IPSetBuilder.
|
||||
func (ma MachineAddresses) AppendToIPSet(build *netipx.IPSetBuilder) {
|
||||
for _, ip := range ma {
|
||||
build.Add(ip)
|
||||
}
|
||||
}
|
||||
|
||||
func (ma *MachineAddresses) Scan(destination interface{}) error {
|
||||
switch value := destination.(type) {
|
||||
case string:
|
||||
addresses := strings.Split(value, ",")
|
||||
*ma = (*ma)[:0]
|
||||
for _, addr := range addresses {
|
||||
if len(addr) < 1 {
|
||||
continue
|
||||
}
|
||||
parsed, err := netip.ParseAddr(addr)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
*ma = append(*ma, parsed)
|
||||
}
|
||||
|
||||
return nil
|
||||
|
||||
default:
|
||||
return fmt.Errorf("%w: unexpected data type %T", ErrMachineAddressesInvalid, destination)
|
||||
}
|
||||
}
|
||||
|
||||
// Value return json value, implement driver.Valuer interface.
|
||||
func (ma MachineAddresses) Value() (driver.Value, error) {
|
||||
addresses := strings.Join(ma.ToStringSlice(), ",")
|
||||
|
||||
return addresses, nil
|
||||
}
|
||||
|
||||
// IsExpired returns whether the machine registration has expired.
|
||||
func (machine Machine) IsExpired() bool {
|
||||
// If Expiry is not set, the client has not indicated that
|
||||
// it wants an expiry time, it is therefor considered
|
||||
// to mean "not expired"
|
||||
if machine.Expiry == nil || machine.Expiry.IsZero() {
|
||||
return false
|
||||
}
|
||||
|
||||
return time.Now().UTC().After(*machine.Expiry)
|
||||
}
|
||||
|
||||
// IsOnline returns if the machine is connected to Headscale.
|
||||
// This is really a naive implementation, as we don't really see
|
||||
// if there is a working connection between the client and the server.
|
||||
func (machine *Machine) IsOnline() bool {
|
||||
if machine.LastSeen == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
if machine.IsExpired() {
|
||||
return false
|
||||
}
|
||||
|
||||
return machine.LastSeen.After(time.Now().Add(-keepAliveInterval))
|
||||
}
|
||||
|
||||
// IsEphemeral returns if the machine is registered as an Ephemeral node.
|
||||
// https://tailscale.com/kb/1111/ephemeral-nodes/
|
||||
func (machine *Machine) IsEphemeral() bool {
|
||||
return machine.AuthKey != nil && machine.AuthKey.Ephemeral
|
||||
}
|
||||
|
||||
func (machine *Machine) CanAccess(filter []tailcfg.FilterRule, machine2 *Machine) bool {
|
||||
for _, rule := range filter {
|
||||
// TODO(kradalby): Cache or pregen this
|
||||
matcher := matcher.MatchFromFilterRule(rule)
|
||||
|
||||
if !matcher.SrcsContainsIPs([]netip.Addr(machine.IPAddresses)) {
|
||||
continue
|
||||
}
|
||||
|
||||
if matcher.DestsContainsIP([]netip.Addr(machine2.IPAddresses)) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func (machines Machines) FilterByIP(ip netip.Addr) Machines {
|
||||
found := make(Machines, 0)
|
||||
|
||||
for _, machine := range machines {
|
||||
for _, mIP := range machine.IPAddresses {
|
||||
if ip == mIP {
|
||||
found = append(found, machine)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return found
|
||||
}
|
||||
|
||||
func (machine *Machine) Proto() *v1.Machine {
|
||||
machineProto := &v1.Machine{
|
||||
Id: machine.ID,
|
||||
MachineKey: machine.MachineKey,
|
||||
|
||||
NodeKey: machine.NodeKey,
|
||||
DiscoKey: machine.DiscoKey,
|
||||
IpAddresses: machine.IPAddresses.ToStringSlice(),
|
||||
Name: machine.Hostname,
|
||||
GivenName: machine.GivenName,
|
||||
User: machine.User.Proto(),
|
||||
ForcedTags: machine.ForcedTags,
|
||||
Online: machine.IsOnline(),
|
||||
|
||||
// TODO(kradalby): Implement register method enum converter
|
||||
// RegisterMethod: ,
|
||||
|
||||
CreatedAt: timestamppb.New(machine.CreatedAt),
|
||||
}
|
||||
|
||||
if machine.AuthKey != nil {
|
||||
machineProto.PreAuthKey = machine.AuthKey.Proto()
|
||||
}
|
||||
|
||||
if machine.LastSeen != nil {
|
||||
machineProto.LastSeen = timestamppb.New(*machine.LastSeen)
|
||||
}
|
||||
|
||||
if machine.LastSuccessfulUpdate != nil {
|
||||
machineProto.LastSuccessfulUpdate = timestamppb.New(
|
||||
*machine.LastSuccessfulUpdate,
|
||||
)
|
||||
}
|
||||
|
||||
if machine.Expiry != nil {
|
||||
machineProto.Expiry = timestamppb.New(*machine.Expiry)
|
||||
}
|
||||
|
||||
return machineProto
|
||||
}
|
||||
|
||||
// GetHostInfo returns a Hostinfo struct for the machine.
|
||||
func (machine *Machine) GetHostInfo() tailcfg.Hostinfo {
|
||||
return tailcfg.Hostinfo(machine.HostInfo)
|
||||
}
|
||||
|
||||
func (machine Machine) String() string {
|
||||
return machine.Hostname
|
||||
}
|
||||
|
||||
func (machines Machines) String() string {
|
||||
temp := make([]string, len(machines))
|
||||
|
||||
for index, machine := range machines {
|
||||
temp[index] = machine.Hostname
|
||||
}
|
||||
|
||||
return fmt.Sprintf("[ %s ](%d)", strings.Join(temp, ", "), len(temp))
|
||||
}
|
||||
|
||||
// TODO(kradalby): Remove when we have generics...
|
||||
func (machines MachinesP) String() string {
|
||||
temp := make([]string, len(machines))
|
||||
|
||||
for index, machine := range machines {
|
||||
temp[index] = machine.Hostname
|
||||
}
|
||||
|
||||
return fmt.Sprintf("[ %s ](%d)", strings.Join(temp, ", "), len(temp))
|
||||
}
|
1
hscontrol/types/machine_test.go
Normal file
1
hscontrol/types/machine_test.go
Normal file
@ -0,0 +1 @@
|
||||
package types
|
58
hscontrol/types/preauth_key.go
Normal file
58
hscontrol/types/preauth_key.go
Normal file
@ -0,0 +1,58 @@
|
||||
package types
|
||||
|
||||
import (
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
|
||||
"github.com/juanfont/headscale/hscontrol/util"
|
||||
"google.golang.org/protobuf/types/known/timestamppb"
|
||||
)
|
||||
|
||||
// PreAuthKey describes a pre-authorization key usable in a particular user.
|
||||
type PreAuthKey struct {
|
||||
ID uint64 `gorm:"primary_key"`
|
||||
Key string
|
||||
UserID uint
|
||||
User User
|
||||
Reusable bool
|
||||
Ephemeral bool `gorm:"default:false"`
|
||||
Used bool `gorm:"default:false"`
|
||||
ACLTags []PreAuthKeyACLTag
|
||||
|
||||
CreatedAt *time.Time
|
||||
Expiration *time.Time
|
||||
}
|
||||
|
||||
// PreAuthKeyACLTag describes an autmatic tag applied to a node when registered with the associated PreAuthKey.
|
||||
type PreAuthKeyACLTag struct {
|
||||
ID uint64 `gorm:"primary_key"`
|
||||
PreAuthKeyID uint64
|
||||
Tag string
|
||||
}
|
||||
|
||||
func (key *PreAuthKey) Proto() *v1.PreAuthKey {
|
||||
protoKey := v1.PreAuthKey{
|
||||
User: key.User.Name,
|
||||
Id: strconv.FormatUint(key.ID, util.Base10),
|
||||
Key: key.Key,
|
||||
Ephemeral: key.Ephemeral,
|
||||
Reusable: key.Reusable,
|
||||
Used: key.Used,
|
||||
AclTags: make([]string, len(key.ACLTags)),
|
||||
}
|
||||
|
||||
if key.Expiration != nil {
|
||||
protoKey.Expiration = timestamppb.New(*key.Expiration)
|
||||
}
|
||||
|
||||
if key.CreatedAt != nil {
|
||||
protoKey.CreatedAt = timestamppb.New(*key.CreatedAt)
|
||||
}
|
||||
|
||||
for idx := range key.ACLTags {
|
||||
protoKey.AclTags[idx] = key.ACLTags[idx].Tag
|
||||
}
|
||||
|
||||
return &protoKey
|
||||
}
|
71
hscontrol/types/routes.go
Normal file
71
hscontrol/types/routes.go
Normal file
@ -0,0 +1,71 @@
|
||||
package types
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/netip"
|
||||
|
||||
v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
|
||||
"google.golang.org/protobuf/types/known/timestamppb"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
var (
|
||||
ExitRouteV4 = netip.MustParsePrefix("0.0.0.0/0")
|
||||
ExitRouteV6 = netip.MustParsePrefix("::/0")
|
||||
)
|
||||
|
||||
type Route struct {
|
||||
gorm.Model
|
||||
|
||||
MachineID uint64
|
||||
Machine Machine
|
||||
Prefix IPPrefix
|
||||
|
||||
Advertised bool
|
||||
Enabled bool
|
||||
IsPrimary bool
|
||||
}
|
||||
|
||||
type Routes []Route
|
||||
|
||||
func (r *Route) String() string {
|
||||
return fmt.Sprintf("%s:%s", r.Machine, netip.Prefix(r.Prefix).String())
|
||||
}
|
||||
|
||||
func (r *Route) IsExitRoute() bool {
|
||||
return netip.Prefix(r.Prefix) == ExitRouteV4 || netip.Prefix(r.Prefix) == ExitRouteV6
|
||||
}
|
||||
|
||||
func (rs Routes) Prefixes() []netip.Prefix {
|
||||
prefixes := make([]netip.Prefix, len(rs))
|
||||
for i, r := range rs {
|
||||
prefixes[i] = netip.Prefix(r.Prefix)
|
||||
}
|
||||
|
||||
return prefixes
|
||||
}
|
||||
|
||||
func (rs Routes) Proto() []*v1.Route {
|
||||
protoRoutes := []*v1.Route{}
|
||||
|
||||
for _, route := range rs {
|
||||
protoRoute := v1.Route{
|
||||
Id: uint64(route.ID),
|
||||
Machine: route.Machine.Proto(),
|
||||
Prefix: netip.Prefix(route.Prefix).String(),
|
||||
Advertised: route.Advertised,
|
||||
Enabled: route.Enabled,
|
||||
IsPrimary: route.IsPrimary,
|
||||
CreatedAt: timestamppb.New(route.CreatedAt),
|
||||
UpdatedAt: timestamppb.New(route.UpdatedAt),
|
||||
}
|
||||
|
||||
if route.DeletedAt.Valid {
|
||||
protoRoute.DeletedAt = timestamppb.New(route.DeletedAt.Time)
|
||||
}
|
||||
|
||||
protoRoutes = append(protoRoutes, &protoRoute)
|
||||
}
|
||||
|
||||
return protoRoutes
|
||||
}
|
55
hscontrol/types/users.go
Normal file
55
hscontrol/types/users.go
Normal file
@ -0,0 +1,55 @@
|
||||
package types
|
||||
|
||||
import (
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
|
||||
"github.com/juanfont/headscale/hscontrol/util"
|
||||
"google.golang.org/protobuf/types/known/timestamppb"
|
||||
"gorm.io/gorm"
|
||||
"tailscale.com/tailcfg"
|
||||
)
|
||||
|
||||
// User is the way Headscale implements the concept of users in Tailscale
|
||||
//
|
||||
// At the end of the day, users in Tailscale are some kind of 'bubbles' or users
|
||||
// that contain our machines.
|
||||
type User struct {
|
||||
gorm.Model
|
||||
Name string `gorm:"unique"`
|
||||
}
|
||||
|
||||
func (n *User) TailscaleUser() *tailcfg.User {
|
||||
user := tailcfg.User{
|
||||
ID: tailcfg.UserID(n.ID),
|
||||
LoginName: n.Name,
|
||||
DisplayName: n.Name,
|
||||
ProfilePicURL: "",
|
||||
Domain: "headscale.net",
|
||||
Logins: []tailcfg.LoginID{},
|
||||
Created: time.Time{},
|
||||
}
|
||||
|
||||
return &user
|
||||
}
|
||||
|
||||
func (n *User) TailscaleLogin() *tailcfg.Login {
|
||||
login := tailcfg.Login{
|
||||
ID: tailcfg.LoginID(n.ID),
|
||||
LoginName: n.Name,
|
||||
DisplayName: n.Name,
|
||||
ProfilePicURL: "",
|
||||
Domain: "headscale.net",
|
||||
}
|
||||
|
||||
return &login
|
||||
}
|
||||
|
||||
func (n *User) Proto() *v1.User {
|
||||
return &v1.User{
|
||||
Id: strconv.FormatUint(uint64(n.ID), util.Base10),
|
||||
Name: n.Name,
|
||||
CreatedAt: timestamppb.New(n.CreatedAt),
|
||||
}
|
||||
}
|
@ -1,415 +0,0 @@
|
||||
package hscontrol
|
||||
|
||||
import (
|
||||
"net/netip"
|
||||
"testing"
|
||||
|
||||
"gopkg.in/check.v1"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
func (s *Suite) TestCreateAndDestroyUser(c *check.C) {
|
||||
user, err := app.db.CreateUser("test")
|
||||
c.Assert(err, check.IsNil)
|
||||
c.Assert(user.Name, check.Equals, "test")
|
||||
|
||||
users, err := app.db.ListUsers()
|
||||
c.Assert(err, check.IsNil)
|
||||
c.Assert(len(users), check.Equals, 1)
|
||||
|
||||
err = app.db.DestroyUser("test")
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
_, err = app.db.GetUser("test")
|
||||
c.Assert(err, check.NotNil)
|
||||
}
|
||||
|
||||
func (s *Suite) TestDestroyUserErrors(c *check.C) {
|
||||
err := app.db.DestroyUser("test")
|
||||
c.Assert(err, check.Equals, ErrUserNotFound)
|
||||
|
||||
user, err := app.db.CreateUser("test")
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
pak, err := app.db.CreatePreAuthKey(user.Name, false, false, nil, nil)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
err = app.db.DestroyUser("test")
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
result := app.db.db.Preload("User").First(&pak, "key = ?", pak.Key)
|
||||
// destroying a user also deletes all associated preauthkeys
|
||||
c.Assert(result.Error, check.Equals, gorm.ErrRecordNotFound)
|
||||
|
||||
user, err = app.db.CreateUser("test")
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
pak, err = app.db.CreatePreAuthKey(user.Name, false, false, nil, nil)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
machine := Machine{
|
||||
ID: 0,
|
||||
MachineKey: "foo",
|
||||
NodeKey: "bar",
|
||||
DiscoKey: "faa",
|
||||
Hostname: "testmachine",
|
||||
UserID: user.ID,
|
||||
RegisterMethod: RegisterMethodAuthKey,
|
||||
AuthKeyID: uint(pak.ID),
|
||||
}
|
||||
app.db.db.Save(&machine)
|
||||
|
||||
err = app.db.DestroyUser("test")
|
||||
c.Assert(err, check.Equals, ErrUserStillHasNodes)
|
||||
}
|
||||
|
||||
func (s *Suite) TestRenameUser(c *check.C) {
|
||||
userTest, err := app.db.CreateUser("test")
|
||||
c.Assert(err, check.IsNil)
|
||||
c.Assert(userTest.Name, check.Equals, "test")
|
||||
|
||||
users, err := app.db.ListUsers()
|
||||
c.Assert(err, check.IsNil)
|
||||
c.Assert(len(users), check.Equals, 1)
|
||||
|
||||
err = app.db.RenameUser("test", "test-renamed")
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
_, err = app.db.GetUser("test")
|
||||
c.Assert(err, check.Equals, ErrUserNotFound)
|
||||
|
||||
_, err = app.db.GetUser("test-renamed")
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
err = app.db.RenameUser("test-does-not-exit", "test")
|
||||
c.Assert(err, check.Equals, ErrUserNotFound)
|
||||
|
||||
userTest2, err := app.db.CreateUser("test2")
|
||||
c.Assert(err, check.IsNil)
|
||||
c.Assert(userTest2.Name, check.Equals, "test2")
|
||||
|
||||
err = app.db.RenameUser("test2", "test-renamed")
|
||||
c.Assert(err, check.Equals, ErrUserExists)
|
||||
}
|
||||
|
||||
func (s *Suite) TestGetMapResponseUserProfiles(c *check.C) {
|
||||
userShared1, err := app.db.CreateUser("shared1")
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
userShared2, err := app.db.CreateUser("shared2")
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
userShared3, err := app.db.CreateUser("shared3")
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
preAuthKeyShared1, err := app.db.CreatePreAuthKey(
|
||||
userShared1.Name,
|
||||
false,
|
||||
false,
|
||||
nil,
|
||||
nil,
|
||||
)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
preAuthKeyShared2, err := app.db.CreatePreAuthKey(
|
||||
userShared2.Name,
|
||||
false,
|
||||
false,
|
||||
nil,
|
||||
nil,
|
||||
)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
preAuthKeyShared3, err := app.db.CreatePreAuthKey(
|
||||
userShared3.Name,
|
||||
false,
|
||||
false,
|
||||
nil,
|
||||
nil,
|
||||
)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
preAuthKey2Shared1, err := app.db.CreatePreAuthKey(
|
||||
userShared1.Name,
|
||||
false,
|
||||
false,
|
||||
nil,
|
||||
nil,
|
||||
)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
_, err = app.db.GetMachine(userShared1.Name, "test_get_shared_nodes_1")
|
||||
c.Assert(err, check.NotNil)
|
||||
|
||||
machineInShared1 := &Machine{
|
||||
ID: 1,
|
||||
MachineKey: "686824e749f3b7f2a5927ee6c1e422aee5292592d9179a271ed7b3e659b44a66",
|
||||
NodeKey: "686824e749f3b7f2a5927ee6c1e422aee5292592d9179a271ed7b3e659b44a66",
|
||||
DiscoKey: "686824e749f3b7f2a5927ee6c1e422aee5292592d9179a271ed7b3e659b44a66",
|
||||
Hostname: "test_get_shared_nodes_1",
|
||||
UserID: userShared1.ID,
|
||||
User: *userShared1,
|
||||
RegisterMethod: RegisterMethodAuthKey,
|
||||
IPAddresses: []netip.Addr{netip.MustParseAddr("100.64.0.1")},
|
||||
AuthKeyID: uint(preAuthKeyShared1.ID),
|
||||
}
|
||||
app.db.db.Save(machineInShared1)
|
||||
|
||||
_, err = app.db.GetMachine(userShared1.Name, machineInShared1.Hostname)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
machineInShared2 := &Machine{
|
||||
ID: 2,
|
||||
MachineKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
|
||||
NodeKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
|
||||
DiscoKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
|
||||
Hostname: "test_get_shared_nodes_2",
|
||||
UserID: userShared2.ID,
|
||||
User: *userShared2,
|
||||
RegisterMethod: RegisterMethodAuthKey,
|
||||
IPAddresses: []netip.Addr{netip.MustParseAddr("100.64.0.2")},
|
||||
AuthKeyID: uint(preAuthKeyShared2.ID),
|
||||
}
|
||||
app.db.db.Save(machineInShared2)
|
||||
|
||||
_, err = app.db.GetMachine(userShared2.Name, machineInShared2.Hostname)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
machineInShared3 := &Machine{
|
||||
ID: 3,
|
||||
MachineKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
|
||||
NodeKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
|
||||
DiscoKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
|
||||
Hostname: "test_get_shared_nodes_3",
|
||||
UserID: userShared3.ID,
|
||||
User: *userShared3,
|
||||
RegisterMethod: RegisterMethodAuthKey,
|
||||
IPAddresses: []netip.Addr{netip.MustParseAddr("100.64.0.3")},
|
||||
AuthKeyID: uint(preAuthKeyShared3.ID),
|
||||
}
|
||||
app.db.db.Save(machineInShared3)
|
||||
|
||||
_, err = app.db.GetMachine(userShared3.Name, machineInShared3.Hostname)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
machine2InShared1 := &Machine{
|
||||
ID: 4,
|
||||
MachineKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
|
||||
NodeKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
|
||||
DiscoKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
|
||||
Hostname: "test_get_shared_nodes_4",
|
||||
UserID: userShared1.ID,
|
||||
User: *userShared1,
|
||||
RegisterMethod: RegisterMethodAuthKey,
|
||||
IPAddresses: []netip.Addr{netip.MustParseAddr("100.64.0.4")},
|
||||
AuthKeyID: uint(preAuthKey2Shared1.ID),
|
||||
}
|
||||
app.db.db.Save(machine2InShared1)
|
||||
|
||||
peersOfMachine1InShared1, err := app.db.getPeers(app.aclPolicy, app.aclRules, machineInShared1)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
userProfiles := app.db.getMapResponseUserProfiles(
|
||||
*machineInShared1,
|
||||
peersOfMachine1InShared1,
|
||||
)
|
||||
|
||||
c.Assert(len(userProfiles), check.Equals, 3)
|
||||
|
||||
found := false
|
||||
for _, userProfiles := range userProfiles {
|
||||
if userProfiles.DisplayName == userShared1.Name {
|
||||
found = true
|
||||
|
||||
break
|
||||
}
|
||||
}
|
||||
c.Assert(found, check.Equals, true)
|
||||
|
||||
found = false
|
||||
for _, userProfile := range userProfiles {
|
||||
if userProfile.DisplayName == userShared2.Name {
|
||||
found = true
|
||||
|
||||
break
|
||||
}
|
||||
}
|
||||
c.Assert(found, check.Equals, true)
|
||||
}
|
||||
|
||||
func TestNormalizeToFQDNRules(t *testing.T) {
|
||||
type args struct {
|
||||
name string
|
||||
stripEmailDomain bool
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
want string
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "normalize simple name",
|
||||
args: args{
|
||||
name: "normalize-simple.name",
|
||||
stripEmailDomain: false,
|
||||
},
|
||||
want: "normalize-simple.name",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "normalize an email",
|
||||
args: args{
|
||||
name: "foo.bar@example.com",
|
||||
stripEmailDomain: false,
|
||||
},
|
||||
want: "foo.bar.example.com",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "normalize an email domain should be removed",
|
||||
args: args{
|
||||
name: "foo.bar@example.com",
|
||||
stripEmailDomain: true,
|
||||
},
|
||||
want: "foo.bar",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "strip enabled no email passed as argument",
|
||||
args: args{
|
||||
name: "not-email-and-strip-enabled",
|
||||
stripEmailDomain: true,
|
||||
},
|
||||
want: "not-email-and-strip-enabled",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "normalize complex email",
|
||||
args: args{
|
||||
name: "foo.bar+complex-email@example.com",
|
||||
stripEmailDomain: false,
|
||||
},
|
||||
want: "foo.bar-complex-email.example.com",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "user name with space",
|
||||
args: args{
|
||||
name: "name space",
|
||||
stripEmailDomain: false,
|
||||
},
|
||||
want: "name-space",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "user with quote",
|
||||
args: args{
|
||||
name: "Jamie's iPhone 5",
|
||||
stripEmailDomain: false,
|
||||
},
|
||||
want: "jamies-iphone-5",
|
||||
wantErr: false,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, err := NormalizeToFQDNRules(tt.args.name, tt.args.stripEmailDomain)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf(
|
||||
"NormalizeToFQDNRules() error = %v, wantErr %v",
|
||||
err,
|
||||
tt.wantErr,
|
||||
)
|
||||
|
||||
return
|
||||
}
|
||||
if got != tt.want {
|
||||
t.Errorf("NormalizeToFQDNRules() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCheckForFQDNRules(t *testing.T) {
|
||||
type args struct {
|
||||
name string
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "valid: user",
|
||||
args: args{name: "valid-user"},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "invalid: capitalized user",
|
||||
args: args{name: "Invalid-CapItaLIzed-user"},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "invalid: email as user",
|
||||
args: args{name: "foo.bar@example.com"},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "invalid: chars in user name",
|
||||
args: args{name: "super-user+name"},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "invalid: too long name for user",
|
||||
args: args{
|
||||
name: "super-long-useruseruser-name-that-should-be-a-little-more-than-63-chars",
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if err := CheckForFQDNRules(tt.args.name); (err != nil) != tt.wantErr {
|
||||
t.Errorf("CheckForFQDNRules() error = %v, wantErr %v", err, tt.wantErr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Suite) TestSetMachineUser(c *check.C) {
|
||||
oldUser, err := app.db.CreateUser("old")
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
newUser, err := app.db.CreateUser("new")
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
pak, err := app.db.CreatePreAuthKey(oldUser.Name, false, false, nil, nil)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
machine := Machine{
|
||||
ID: 0,
|
||||
MachineKey: "foo",
|
||||
NodeKey: "bar",
|
||||
DiscoKey: "faa",
|
||||
Hostname: "testmachine",
|
||||
UserID: oldUser.ID,
|
||||
RegisterMethod: RegisterMethodAuthKey,
|
||||
AuthKeyID: uint(pak.ID),
|
||||
}
|
||||
app.db.db.Save(&machine)
|
||||
c.Assert(machine.UserID, check.Equals, oldUser.ID)
|
||||
|
||||
err = app.db.SetMachineUser(&machine, newUser.Name)
|
||||
c.Assert(err, check.IsNil)
|
||||
c.Assert(machine.UserID, check.Equals, newUser.ID)
|
||||
c.Assert(machine.User.Name, check.Equals, newUser.Name)
|
||||
|
||||
err = app.db.SetMachineUser(&machine, "non-existing-user")
|
||||
c.Assert(err, check.Equals, ErrUserNotFound)
|
||||
|
||||
err = app.db.SetMachineUser(&machine, newUser.Name)
|
||||
c.Assert(err, check.IsNil)
|
||||
c.Assert(machine.UserID, check.Equals, newUser.ID)
|
||||
c.Assert(machine.User.Name, check.Equals, newUser.Name)
|
||||
}
|
@ -1,12 +1,94 @@
|
||||
package util
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"reflect"
|
||||
"strings"
|
||||
|
||||
"go4.org/netipx"
|
||||
)
|
||||
|
||||
// This is borrowed from, and updated to use IPSet
|
||||
// https://github.com/tailscale/tailscale/blob/71029cea2ddf82007b80f465b256d027eab0f02d/wgengine/filter/tailcfg.go#L97-L162
|
||||
// TODO(kradalby): contribute upstream and make public.
|
||||
var (
|
||||
zeroIP4 = netip.AddrFrom4([4]byte{})
|
||||
zeroIP6 = netip.AddrFrom16([16]byte{})
|
||||
)
|
||||
|
||||
// parseIPSet parses arg as one:
|
||||
//
|
||||
// - an IP address (IPv4 or IPv6)
|
||||
// - the string "*" to match everything (both IPv4 & IPv6)
|
||||
// - a CIDR (e.g. "192.168.0.0/16")
|
||||
// - a range of two IPs, inclusive, separated by hyphen ("2eff::1-2eff::0800")
|
||||
//
|
||||
// bits, if non-nil, is the legacy SrcBits CIDR length to make a IP
|
||||
// address (without a slash) treated as a CIDR of *bits length.
|
||||
// nolint
|
||||
func ParseIPSet(arg string, bits *int) (*netipx.IPSet, error) {
|
||||
var ipSet netipx.IPSetBuilder
|
||||
if arg == "*" {
|
||||
ipSet.AddPrefix(netip.PrefixFrom(zeroIP4, 0))
|
||||
ipSet.AddPrefix(netip.PrefixFrom(zeroIP6, 0))
|
||||
|
||||
return ipSet.IPSet()
|
||||
}
|
||||
if strings.Contains(arg, "/") {
|
||||
pfx, err := netip.ParsePrefix(arg)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if pfx != pfx.Masked() {
|
||||
return nil, fmt.Errorf("%v contains non-network bits set", pfx)
|
||||
}
|
||||
|
||||
ipSet.AddPrefix(pfx)
|
||||
|
||||
return ipSet.IPSet()
|
||||
}
|
||||
if strings.Count(arg, "-") == 1 {
|
||||
ip1s, ip2s, _ := strings.Cut(arg, "-")
|
||||
|
||||
ip1, err := netip.ParseAddr(ip1s)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
ip2, err := netip.ParseAddr(ip2s)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
r := netipx.IPRangeFrom(ip1, ip2)
|
||||
if !r.IsValid() {
|
||||
return nil, fmt.Errorf("invalid IP range %q", arg)
|
||||
}
|
||||
|
||||
for _, prefix := range r.Prefixes() {
|
||||
ipSet.AddPrefix(prefix)
|
||||
}
|
||||
|
||||
return ipSet.IPSet()
|
||||
}
|
||||
ip, err := netip.ParseAddr(arg)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid IP address %q", arg)
|
||||
}
|
||||
bits8 := uint8(ip.BitLen())
|
||||
if bits != nil {
|
||||
if *bits < 0 || *bits > int(bits8) {
|
||||
return nil, fmt.Errorf("invalid CIDR size %d for IP %q", *bits, arg)
|
||||
}
|
||||
bits8 = uint8(*bits)
|
||||
}
|
||||
|
||||
ipSet.AddPrefix(netip.PrefixFrom(ip, int(bits8)))
|
||||
|
||||
return ipSet.IPSet()
|
||||
}
|
||||
|
||||
func GetIPPrefixEndpoints(na netip.Prefix) (netip.Addr, netip.Addr) {
|
||||
var network, broadcast netip.Addr
|
||||
ipRange := netipx.RangeOfPrefix(na)
|
||||
|
@ -1,4 +1,4 @@
|
||||
package hscontrol
|
||||
package util
|
||||
|
||||
import (
|
||||
"net/netip"
|
||||
@ -105,7 +105,7 @@ func Test_parseIPSet(t *testing.T) {
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, err := parseIPSet(tt.args.arg, tt.args.bits)
|
||||
got, err := ParseIPSet(tt.args.arg, tt.args.bits)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("parseIPSet() error = %v, wantErr %v", err, tt.wantErr)
|
||||
|
7
hscontrol/util/const.go
Normal file
7
hscontrol/util/const.go
Normal file
@ -0,0 +1,7 @@
|
||||
package util
|
||||
|
||||
const (
|
||||
RegisterMethodAuthKey = "authkey"
|
||||
RegisterMethodOIDC = "oidc"
|
||||
RegisterMethodCLI = "cli"
|
||||
)
|
69
hscontrol/util/dns.go
Normal file
69
hscontrol/util/dns.go
Normal file
@ -0,0 +1,69 @@
|
||||
package util
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"regexp"
|
||||
"strings"
|
||||
)
|
||||
|
||||
const (
|
||||
// value related to RFC 1123 and 952.
|
||||
LabelHostnameLength = 63
|
||||
)
|
||||
|
||||
var invalidCharsInUserRegex = regexp.MustCompile("[^a-z0-9-.]+")
|
||||
|
||||
var ErrInvalidUserName = errors.New("invalid user name")
|
||||
|
||||
// NormalizeToFQDNRules will replace forbidden chars in user
|
||||
// it can also return an error if the user doesn't respect RFC 952 and 1123.
|
||||
func NormalizeToFQDNRules(name string, stripEmailDomain bool) (string, error) {
|
||||
name = strings.ToLower(name)
|
||||
name = strings.ReplaceAll(name, "'", "")
|
||||
atIdx := strings.Index(name, "@")
|
||||
if stripEmailDomain && atIdx > 0 {
|
||||
name = name[:atIdx]
|
||||
} else {
|
||||
name = strings.ReplaceAll(name, "@", ".")
|
||||
}
|
||||
name = invalidCharsInUserRegex.ReplaceAllString(name, "-")
|
||||
|
||||
for _, elt := range strings.Split(name, ".") {
|
||||
if len(elt) > LabelHostnameLength {
|
||||
return "", fmt.Errorf(
|
||||
"label %v is more than 63 chars: %w",
|
||||
elt,
|
||||
ErrInvalidUserName,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
return name, nil
|
||||
}
|
||||
|
||||
func CheckForFQDNRules(name string) error {
|
||||
if len(name) > LabelHostnameLength {
|
||||
return fmt.Errorf(
|
||||
"DNS segment must not be over 63 chars. %v doesn't comply with this rule: %w",
|
||||
name,
|
||||
ErrInvalidUserName,
|
||||
)
|
||||
}
|
||||
if strings.ToLower(name) != name {
|
||||
return fmt.Errorf(
|
||||
"DNS segment should be lowercase. %v doesn't comply with this rule: %w",
|
||||
name,
|
||||
ErrInvalidUserName,
|
||||
)
|
||||
}
|
||||
if invalidCharsInUserRegex.MatchString(name) {
|
||||
return fmt.Errorf(
|
||||
"DNS segment should only be composed of lowercase ASCII letters numbers, hyphen and dots. %v doesn't comply with theses rules: %w",
|
||||
name,
|
||||
ErrInvalidUserName,
|
||||
)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
143
hscontrol/util/dns_test.go
Normal file
143
hscontrol/util/dns_test.go
Normal file
@ -0,0 +1,143 @@
|
||||
package util
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestNormalizeToFQDNRules(t *testing.T) {
|
||||
type args struct {
|
||||
name string
|
||||
stripEmailDomain bool
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
want string
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "normalize simple name",
|
||||
args: args{
|
||||
name: "normalize-simple.name",
|
||||
stripEmailDomain: false,
|
||||
},
|
||||
want: "normalize-simple.name",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "normalize an email",
|
||||
args: args{
|
||||
name: "foo.bar@example.com",
|
||||
stripEmailDomain: false,
|
||||
},
|
||||
want: "foo.bar.example.com",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "normalize an email domain should be removed",
|
||||
args: args{
|
||||
name: "foo.bar@example.com",
|
||||
stripEmailDomain: true,
|
||||
},
|
||||
want: "foo.bar",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "strip enabled no email passed as argument",
|
||||
args: args{
|
||||
name: "not-email-and-strip-enabled",
|
||||
stripEmailDomain: true,
|
||||
},
|
||||
want: "not-email-and-strip-enabled",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "normalize complex email",
|
||||
args: args{
|
||||
name: "foo.bar+complex-email@example.com",
|
||||
stripEmailDomain: false,
|
||||
},
|
||||
want: "foo.bar-complex-email.example.com",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "user name with space",
|
||||
args: args{
|
||||
name: "name space",
|
||||
stripEmailDomain: false,
|
||||
},
|
||||
want: "name-space",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "user with quote",
|
||||
args: args{
|
||||
name: "Jamie's iPhone 5",
|
||||
stripEmailDomain: false,
|
||||
},
|
||||
want: "jamies-iphone-5",
|
||||
wantErr: false,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, err := NormalizeToFQDNRules(tt.args.name, tt.args.stripEmailDomain)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf(
|
||||
"NormalizeToFQDNRules() error = %v, wantErr %v",
|
||||
err,
|
||||
tt.wantErr,
|
||||
)
|
||||
|
||||
return
|
||||
}
|
||||
if got != tt.want {
|
||||
t.Errorf("NormalizeToFQDNRules() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCheckForFQDNRules(t *testing.T) {
|
||||
type args struct {
|
||||
name string
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "valid: user",
|
||||
args: args{name: "valid-user"},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "invalid: capitalized user",
|
||||
args: args{name: "Invalid-CapItaLIzed-user"},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "invalid: email as user",
|
||||
args: args{name: "foo.bar@example.com"},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "invalid: chars in user name",
|
||||
args: args{name: "super-user+name"},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "invalid: too long name for user",
|
||||
args: args{
|
||||
name: "super-long-useruseruser-name-that-should-be-a-little-more-than-63-chars",
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if err := CheckForFQDNRules(tt.args.name); (err != nil) != tt.wantErr {
|
||||
t.Errorf("CheckForFQDNRules() error = %v, wantErr %v", err, tt.wantErr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
@ -6,7 +6,7 @@ import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/juanfont/headscale/hscontrol"
|
||||
"github.com/juanfont/headscale/hscontrol/policy"
|
||||
"github.com/juanfont/headscale/integration/hsic"
|
||||
"github.com/juanfont/headscale/integration/tsic"
|
||||
"github.com/stretchr/testify/assert"
|
||||
@ -45,7 +45,7 @@ var veryLargeDestination = []string{
|
||||
"208.0.0.0/4:*",
|
||||
}
|
||||
|
||||
func aclScenario(t *testing.T, policy *hscontrol.ACLPolicy, clientsPerUser int) *Scenario {
|
||||
func aclScenario(t *testing.T, policy *policy.ACLPolicy, clientsPerUser int) *Scenario {
|
||||
t.Helper()
|
||||
scenario, err := NewScenario()
|
||||
assert.NoError(t, err)
|
||||
@ -92,7 +92,7 @@ func TestACLHostsInNetMapTable(t *testing.T) {
|
||||
// they can access minus one (them self).
|
||||
tests := map[string]struct {
|
||||
users map[string]int
|
||||
policy hscontrol.ACLPolicy
|
||||
policy policy.ACLPolicy
|
||||
want map[string]int
|
||||
}{
|
||||
// Test that when we have no ACL, each client netmap has
|
||||
@ -102,8 +102,8 @@ func TestACLHostsInNetMapTable(t *testing.T) {
|
||||
"user1": 2,
|
||||
"user2": 2,
|
||||
},
|
||||
policy: hscontrol.ACLPolicy{
|
||||
ACLs: []hscontrol.ACL{
|
||||
policy: policy.ACLPolicy{
|
||||
ACLs: []policy.ACL{
|
||||
{
|
||||
Action: "accept",
|
||||
Sources: []string{"*"},
|
||||
@ -123,8 +123,8 @@ func TestACLHostsInNetMapTable(t *testing.T) {
|
||||
"user1": 2,
|
||||
"user2": 2,
|
||||
},
|
||||
policy: hscontrol.ACLPolicy{
|
||||
ACLs: []hscontrol.ACL{
|
||||
policy: policy.ACLPolicy{
|
||||
ACLs: []policy.ACL{
|
||||
{
|
||||
Action: "accept",
|
||||
Sources: []string{"user1"},
|
||||
@ -149,8 +149,8 @@ func TestACLHostsInNetMapTable(t *testing.T) {
|
||||
"user1": 2,
|
||||
"user2": 2,
|
||||
},
|
||||
policy: hscontrol.ACLPolicy{
|
||||
ACLs: []hscontrol.ACL{
|
||||
policy: policy.ACLPolicy{
|
||||
ACLs: []policy.ACL{
|
||||
{
|
||||
Action: "accept",
|
||||
Sources: []string{"user1"},
|
||||
@ -186,8 +186,8 @@ func TestACLHostsInNetMapTable(t *testing.T) {
|
||||
"user1": 2,
|
||||
"user2": 2,
|
||||
},
|
||||
policy: hscontrol.ACLPolicy{
|
||||
ACLs: []hscontrol.ACL{
|
||||
policy: policy.ACLPolicy{
|
||||
ACLs: []policy.ACL{
|
||||
{
|
||||
Action: "accept",
|
||||
Sources: []string{"user1"},
|
||||
@ -214,8 +214,8 @@ func TestACLHostsInNetMapTable(t *testing.T) {
|
||||
"user1": 2,
|
||||
"user2": 2,
|
||||
},
|
||||
policy: hscontrol.ACLPolicy{
|
||||
ACLs: []hscontrol.ACL{
|
||||
policy: policy.ACLPolicy{
|
||||
ACLs: []policy.ACL{
|
||||
{
|
||||
Action: "accept",
|
||||
Sources: []string{"user1"},
|
||||
@ -282,8 +282,8 @@ func TestACLAllowUser80Dst(t *testing.T) {
|
||||
IntegrationSkip(t)
|
||||
|
||||
scenario := aclScenario(t,
|
||||
&hscontrol.ACLPolicy{
|
||||
ACLs: []hscontrol.ACL{
|
||||
&policy.ACLPolicy{
|
||||
ACLs: []policy.ACL{
|
||||
{
|
||||
Action: "accept",
|
||||
Sources: []string{"user1"},
|
||||
@ -338,11 +338,11 @@ func TestACLDenyAllPort80(t *testing.T) {
|
||||
IntegrationSkip(t)
|
||||
|
||||
scenario := aclScenario(t,
|
||||
&hscontrol.ACLPolicy{
|
||||
&policy.ACLPolicy{
|
||||
Groups: map[string][]string{
|
||||
"group:integration-acl-test": {"user1", "user2"},
|
||||
},
|
||||
ACLs: []hscontrol.ACL{
|
||||
ACLs: []policy.ACL{
|
||||
{
|
||||
Action: "accept",
|
||||
Sources: []string{"group:integration-acl-test"},
|
||||
@ -387,8 +387,8 @@ func TestACLAllowUserDst(t *testing.T) {
|
||||
IntegrationSkip(t)
|
||||
|
||||
scenario := aclScenario(t,
|
||||
&hscontrol.ACLPolicy{
|
||||
ACLs: []hscontrol.ACL{
|
||||
&policy.ACLPolicy{
|
||||
ACLs: []policy.ACL{
|
||||
{
|
||||
Action: "accept",
|
||||
Sources: []string{"user1"},
|
||||
@ -445,8 +445,8 @@ func TestACLAllowStarDst(t *testing.T) {
|
||||
IntegrationSkip(t)
|
||||
|
||||
scenario := aclScenario(t,
|
||||
&hscontrol.ACLPolicy{
|
||||
ACLs: []hscontrol.ACL{
|
||||
&policy.ACLPolicy{
|
||||
ACLs: []policy.ACL{
|
||||
{
|
||||
Action: "accept",
|
||||
Sources: []string{"user1"},
|
||||
@ -504,11 +504,11 @@ func TestACLNamedHostsCanReachBySubnet(t *testing.T) {
|
||||
IntegrationSkip(t)
|
||||
|
||||
scenario := aclScenario(t,
|
||||
&hscontrol.ACLPolicy{
|
||||
Hosts: hscontrol.Hosts{
|
||||
&policy.ACLPolicy{
|
||||
Hosts: policy.Hosts{
|
||||
"all": netip.MustParsePrefix("100.64.0.0/24"),
|
||||
},
|
||||
ACLs: []hscontrol.ACL{
|
||||
ACLs: []policy.ACL{
|
||||
// Everyone can curl test3
|
||||
{
|
||||
Action: "accept",
|
||||
@ -603,16 +603,16 @@ func TestACLNamedHostsCanReach(t *testing.T) {
|
||||
IntegrationSkip(t)
|
||||
|
||||
tests := map[string]struct {
|
||||
policy hscontrol.ACLPolicy
|
||||
policy policy.ACLPolicy
|
||||
}{
|
||||
"ipv4": {
|
||||
policy: hscontrol.ACLPolicy{
|
||||
Hosts: hscontrol.Hosts{
|
||||
policy: policy.ACLPolicy{
|
||||
Hosts: policy.Hosts{
|
||||
"test1": netip.MustParsePrefix("100.64.0.1/32"),
|
||||
"test2": netip.MustParsePrefix("100.64.0.2/32"),
|
||||
"test3": netip.MustParsePrefix("100.64.0.3/32"),
|
||||
},
|
||||
ACLs: []hscontrol.ACL{
|
||||
ACLs: []policy.ACL{
|
||||
// Everyone can curl test3
|
||||
{
|
||||
Action: "accept",
|
||||
@ -629,13 +629,13 @@ func TestACLNamedHostsCanReach(t *testing.T) {
|
||||
},
|
||||
},
|
||||
"ipv6": {
|
||||
policy: hscontrol.ACLPolicy{
|
||||
Hosts: hscontrol.Hosts{
|
||||
policy: policy.ACLPolicy{
|
||||
Hosts: policy.Hosts{
|
||||
"test1": netip.MustParsePrefix("fd7a:115c:a1e0::1/128"),
|
||||
"test2": netip.MustParsePrefix("fd7a:115c:a1e0::2/128"),
|
||||
"test3": netip.MustParsePrefix("fd7a:115c:a1e0::3/128"),
|
||||
},
|
||||
ACLs: []hscontrol.ACL{
|
||||
ACLs: []policy.ACL{
|
||||
// Everyone can curl test3
|
||||
{
|
||||
Action: "accept",
|
||||
@ -854,11 +854,11 @@ func TestACLDevice1CanAccessDevice2(t *testing.T) {
|
||||
IntegrationSkip(t)
|
||||
|
||||
tests := map[string]struct {
|
||||
policy hscontrol.ACLPolicy
|
||||
policy policy.ACLPolicy
|
||||
}{
|
||||
"ipv4": {
|
||||
policy: hscontrol.ACLPolicy{
|
||||
ACLs: []hscontrol.ACL{
|
||||
policy: policy.ACLPolicy{
|
||||
ACLs: []policy.ACL{
|
||||
{
|
||||
Action: "accept",
|
||||
Sources: []string{"100.64.0.1"},
|
||||
@ -868,8 +868,8 @@ func TestACLDevice1CanAccessDevice2(t *testing.T) {
|
||||
},
|
||||
},
|
||||
"ipv6": {
|
||||
policy: hscontrol.ACLPolicy{
|
||||
ACLs: []hscontrol.ACL{
|
||||
policy: policy.ACLPolicy{
|
||||
ACLs: []policy.ACL{
|
||||
{
|
||||
Action: "accept",
|
||||
Sources: []string{"fd7a:115c:a1e0::1"},
|
||||
@ -879,12 +879,12 @@ func TestACLDevice1CanAccessDevice2(t *testing.T) {
|
||||
},
|
||||
},
|
||||
"hostv4cidr": {
|
||||
policy: hscontrol.ACLPolicy{
|
||||
Hosts: hscontrol.Hosts{
|
||||
policy: policy.ACLPolicy{
|
||||
Hosts: policy.Hosts{
|
||||
"test1": netip.MustParsePrefix("100.64.0.1/32"),
|
||||
"test2": netip.MustParsePrefix("100.64.0.2/32"),
|
||||
},
|
||||
ACLs: []hscontrol.ACL{
|
||||
ACLs: []policy.ACL{
|
||||
{
|
||||
Action: "accept",
|
||||
Sources: []string{"test1"},
|
||||
@ -894,12 +894,12 @@ func TestACLDevice1CanAccessDevice2(t *testing.T) {
|
||||
},
|
||||
},
|
||||
"hostv6cidr": {
|
||||
policy: hscontrol.ACLPolicy{
|
||||
Hosts: hscontrol.Hosts{
|
||||
policy: policy.ACLPolicy{
|
||||
Hosts: policy.Hosts{
|
||||
"test1": netip.MustParsePrefix("fd7a:115c:a1e0::1/128"),
|
||||
"test2": netip.MustParsePrefix("fd7a:115c:a1e0::2/128"),
|
||||
},
|
||||
ACLs: []hscontrol.ACL{
|
||||
ACLs: []policy.ACL{
|
||||
{
|
||||
Action: "accept",
|
||||
Sources: []string{"test1"},
|
||||
@ -909,12 +909,12 @@ func TestACLDevice1CanAccessDevice2(t *testing.T) {
|
||||
},
|
||||
},
|
||||
"group": {
|
||||
policy: hscontrol.ACLPolicy{
|
||||
policy: policy.ACLPolicy{
|
||||
Groups: map[string][]string{
|
||||
"group:one": {"user1"},
|
||||
"group:two": {"user2"},
|
||||
},
|
||||
ACLs: []hscontrol.ACL{
|
||||
ACLs: []policy.ACL{
|
||||
{
|
||||
Action: "accept",
|
||||
Sources: []string{"group:one"},
|
||||
|
@ -23,7 +23,7 @@ import (
|
||||
|
||||
"github.com/davecgh/go-spew/spew"
|
||||
v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
|
||||
"github.com/juanfont/headscale/hscontrol"
|
||||
"github.com/juanfont/headscale/hscontrol/policy"
|
||||
"github.com/juanfont/headscale/hscontrol/util"
|
||||
"github.com/juanfont/headscale/integration/dockertestutil"
|
||||
"github.com/juanfont/headscale/integration/integrationutil"
|
||||
@ -60,7 +60,7 @@ type HeadscaleInContainer struct {
|
||||
port int
|
||||
extraPorts []string
|
||||
hostPortBindings map[string][]string
|
||||
aclPolicy *hscontrol.ACLPolicy
|
||||
aclPolicy *policy.ACLPolicy
|
||||
env map[string]string
|
||||
tlsCert []byte
|
||||
tlsKey []byte
|
||||
@ -73,7 +73,7 @@ type Option = func(c *HeadscaleInContainer)
|
||||
|
||||
// WithACLPolicy adds a hscontrol.ACLPolicy policy to the
|
||||
// HeadscaleInContainer instance.
|
||||
func WithACLPolicy(acl *hscontrol.ACLPolicy) Option {
|
||||
func WithACLPolicy(acl *policy.ACLPolicy) Option {
|
||||
return func(hsic *HeadscaleInContainer) {
|
||||
// TODO(kradalby): Move somewhere appropriate
|
||||
hsic.env["HEADSCALE_ACL_POLICY_PATH"] = aclPolicyPath
|
||||
|
@ -6,7 +6,7 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/juanfont/headscale/hscontrol"
|
||||
"github.com/juanfont/headscale/hscontrol/policy"
|
||||
"github.com/juanfont/headscale/integration/hsic"
|
||||
"github.com/juanfont/headscale/integration/tsic"
|
||||
"github.com/stretchr/testify/assert"
|
||||
@ -57,18 +57,18 @@ func TestSSHOneUserAllToAll(t *testing.T) {
|
||||
err = scenario.CreateHeadscaleEnv(spec,
|
||||
[]tsic.Option{tsic.WithSSH()},
|
||||
hsic.WithACLPolicy(
|
||||
&hscontrol.ACLPolicy{
|
||||
&policy.ACLPolicy{
|
||||
Groups: map[string][]string{
|
||||
"group:integration-test": {"user1"},
|
||||
},
|
||||
ACLs: []hscontrol.ACL{
|
||||
ACLs: []policy.ACL{
|
||||
{
|
||||
Action: "accept",
|
||||
Sources: []string{"*"},
|
||||
Destinations: []string{"*:*"},
|
||||
},
|
||||
},
|
||||
SSHs: []hscontrol.SSH{
|
||||
SSHs: []policy.SSH{
|
||||
{
|
||||
Action: "accept",
|
||||
Sources: []string{"group:integration-test"},
|
||||
@ -134,18 +134,18 @@ func TestSSHMultipleUsersAllToAll(t *testing.T) {
|
||||
err = scenario.CreateHeadscaleEnv(spec,
|
||||
[]tsic.Option{tsic.WithSSH()},
|
||||
hsic.WithACLPolicy(
|
||||
&hscontrol.ACLPolicy{
|
||||
&policy.ACLPolicy{
|
||||
Groups: map[string][]string{
|
||||
"group:integration-test": {"user1", "user2"},
|
||||
},
|
||||
ACLs: []hscontrol.ACL{
|
||||
ACLs: []policy.ACL{
|
||||
{
|
||||
Action: "accept",
|
||||
Sources: []string{"*"},
|
||||
Destinations: []string{"*:*"},
|
||||
},
|
||||
},
|
||||
SSHs: []hscontrol.SSH{
|
||||
SSHs: []policy.SSH{
|
||||
{
|
||||
Action: "accept",
|
||||
Sources: []string{"group:integration-test"},
|
||||
@ -216,18 +216,18 @@ func TestSSHNoSSHConfigured(t *testing.T) {
|
||||
err = scenario.CreateHeadscaleEnv(spec,
|
||||
[]tsic.Option{tsic.WithSSH()},
|
||||
hsic.WithACLPolicy(
|
||||
&hscontrol.ACLPolicy{
|
||||
&policy.ACLPolicy{
|
||||
Groups: map[string][]string{
|
||||
"group:integration-test": {"user1"},
|
||||
},
|
||||
ACLs: []hscontrol.ACL{
|
||||
ACLs: []policy.ACL{
|
||||
{
|
||||
Action: "accept",
|
||||
Sources: []string{"*"},
|
||||
Destinations: []string{"*:*"},
|
||||
},
|
||||
},
|
||||
SSHs: []hscontrol.SSH{},
|
||||
SSHs: []policy.SSH{},
|
||||
},
|
||||
),
|
||||
hsic.WithTestName("sshnoneconfigured"),
|
||||
@ -286,18 +286,18 @@ func TestSSHIsBlockedInACL(t *testing.T) {
|
||||
err = scenario.CreateHeadscaleEnv(spec,
|
||||
[]tsic.Option{tsic.WithSSH()},
|
||||
hsic.WithACLPolicy(
|
||||
&hscontrol.ACLPolicy{
|
||||
&policy.ACLPolicy{
|
||||
Groups: map[string][]string{
|
||||
"group:integration-test": {"user1"},
|
||||
},
|
||||
ACLs: []hscontrol.ACL{
|
||||
ACLs: []policy.ACL{
|
||||
{
|
||||
Action: "accept",
|
||||
Sources: []string{"*"},
|
||||
Destinations: []string{"*:80"},
|
||||
},
|
||||
},
|
||||
SSHs: []hscontrol.SSH{
|
||||
SSHs: []policy.SSH{
|
||||
{
|
||||
Action: "accept",
|
||||
Sources: []string{"group:integration-test"},
|
||||
@ -364,19 +364,19 @@ func TestSSUserOnlyIsolation(t *testing.T) {
|
||||
err = scenario.CreateHeadscaleEnv(spec,
|
||||
[]tsic.Option{tsic.WithSSH()},
|
||||
hsic.WithACLPolicy(
|
||||
&hscontrol.ACLPolicy{
|
||||
&policy.ACLPolicy{
|
||||
Groups: map[string][]string{
|
||||
"group:ssh1": {"useracl1"},
|
||||
"group:ssh2": {"useracl2"},
|
||||
},
|
||||
ACLs: []hscontrol.ACL{
|
||||
ACLs: []policy.ACL{
|
||||
{
|
||||
Action: "accept",
|
||||
Sources: []string{"*"},
|
||||
Destinations: []string{"*:*"},
|
||||
},
|
||||
},
|
||||
SSHs: []hscontrol.SSH{
|
||||
SSHs: []policy.SSH{
|
||||
{
|
||||
Action: "accept",
|
||||
Sources: []string{"group:ssh1"},
|
||||
|
Loading…
x
Reference in New Issue
Block a user