mirror of
https://github.com/juanfont/headscale.git
synced 2024-11-27 12:05:26 +00:00
Replace the timestamp based state system
This commit replaces the timestamp based state system with a new one that has update channels directly to the connected nodes. It will send an update to all listening clients via the polling mechanism. It introduces a new package notifier, which has a concurrency safe manager for all our channels to the connected nodes. Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>
This commit is contained in:
parent
056d3a81c5
commit
66ff1fcd40
108
hscontrol/app.go
108
hscontrol/app.go
@ -10,7 +10,6 @@ import (
|
|||||||
"net/http"
|
"net/http"
|
||||||
"os"
|
"os"
|
||||||
"os/signal"
|
"os/signal"
|
||||||
"sort"
|
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
@ -26,13 +25,13 @@ import (
|
|||||||
"github.com/juanfont/headscale/hscontrol/db"
|
"github.com/juanfont/headscale/hscontrol/db"
|
||||||
"github.com/juanfont/headscale/hscontrol/derp"
|
"github.com/juanfont/headscale/hscontrol/derp"
|
||||||
derpServer "github.com/juanfont/headscale/hscontrol/derp/server"
|
derpServer "github.com/juanfont/headscale/hscontrol/derp/server"
|
||||||
|
"github.com/juanfont/headscale/hscontrol/notifier"
|
||||||
"github.com/juanfont/headscale/hscontrol/policy"
|
"github.com/juanfont/headscale/hscontrol/policy"
|
||||||
"github.com/juanfont/headscale/hscontrol/types"
|
"github.com/juanfont/headscale/hscontrol/types"
|
||||||
"github.com/juanfont/headscale/hscontrol/util"
|
"github.com/juanfont/headscale/hscontrol/util"
|
||||||
"github.com/patrickmn/go-cache"
|
"github.com/patrickmn/go-cache"
|
||||||
zerolog "github.com/philip-bui/grpc-zerolog"
|
zerolog "github.com/philip-bui/grpc-zerolog"
|
||||||
"github.com/prometheus/client_golang/prometheus/promhttp"
|
"github.com/prometheus/client_golang/prometheus/promhttp"
|
||||||
"github.com/puzpuzpuz/xsync/v2"
|
|
||||||
zl "github.com/rs/zerolog"
|
zl "github.com/rs/zerolog"
|
||||||
"github.com/rs/zerolog/log"
|
"github.com/rs/zerolog/log"
|
||||||
"golang.org/x/crypto/acme"
|
"golang.org/x/crypto/acme"
|
||||||
@ -84,7 +83,7 @@ type Headscale struct {
|
|||||||
|
|
||||||
ACLPolicy *policy.ACLPolicy
|
ACLPolicy *policy.ACLPolicy
|
||||||
|
|
||||||
lastStateChange *xsync.MapOf[string, time.Time]
|
nodeNotifier *notifier.Notifier
|
||||||
|
|
||||||
oidcProvider *oidc.Provider
|
oidcProvider *oidc.Provider
|
||||||
oauth2Config *oauth2.Config
|
oauth2Config *oauth2.Config
|
||||||
@ -93,9 +92,6 @@ type Headscale struct {
|
|||||||
|
|
||||||
shutdownChan chan struct{}
|
shutdownChan chan struct{}
|
||||||
pollNetMapStreamWG sync.WaitGroup
|
pollNetMapStreamWG sync.WaitGroup
|
||||||
|
|
||||||
stateUpdateChan chan struct{}
|
|
||||||
cancelStateUpdateChan chan struct{}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewHeadscale(cfg *types.Config) (*Headscale, error) {
|
func NewHeadscale(cfg *types.Config) (*Headscale, error) {
|
||||||
@ -158,19 +154,14 @@ func NewHeadscale(cfg *types.Config) (*Headscale, error) {
|
|||||||
noisePrivateKey: noisePrivateKey,
|
noisePrivateKey: noisePrivateKey,
|
||||||
registrationCache: registrationCache,
|
registrationCache: registrationCache,
|
||||||
pollNetMapStreamWG: sync.WaitGroup{},
|
pollNetMapStreamWG: sync.WaitGroup{},
|
||||||
lastStateChange: xsync.NewMapOf[time.Time](),
|
nodeNotifier: notifier.NewNotifier(),
|
||||||
|
|
||||||
stateUpdateChan: make(chan struct{}),
|
|
||||||
cancelStateUpdateChan: make(chan struct{}),
|
|
||||||
}
|
}
|
||||||
|
|
||||||
go app.watchStateChannel()
|
|
||||||
|
|
||||||
database, err := db.NewHeadscaleDatabase(
|
database, err := db.NewHeadscaleDatabase(
|
||||||
cfg.DBtype,
|
cfg.DBtype,
|
||||||
dbString,
|
dbString,
|
||||||
app.dbDebug,
|
app.dbDebug,
|
||||||
app.stateUpdateChan,
|
app.nodeNotifier,
|
||||||
cfg.IPPrefixes,
|
cfg.IPPrefixes,
|
||||||
cfg.BaseDomain)
|
cfg.BaseDomain)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -203,7 +194,11 @@ func NewHeadscale(cfg *types.Config) (*Headscale, error) {
|
|||||||
|
|
||||||
if cfg.DERP.ServerEnabled {
|
if cfg.DERP.ServerEnabled {
|
||||||
// TODO(kradalby): replace this key with a dedicated DERP key.
|
// TODO(kradalby): replace this key with a dedicated DERP key.
|
||||||
embeddedDERPServer, err := derpServer.NewDERPServer(cfg.ServerURL, key.NodePrivate(*privateKey), &cfg.DERP)
|
embeddedDERPServer, err := derpServer.NewDERPServer(
|
||||||
|
cfg.ServerURL,
|
||||||
|
key.NodePrivate(*privateKey),
|
||||||
|
&cfg.DERP,
|
||||||
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@ -230,10 +225,14 @@ func (h *Headscale) expireEphemeralNodes(milliSeconds int64) {
|
|||||||
|
|
||||||
// expireExpiredMachines expires machines that have an explicit expiry set
|
// expireExpiredMachines expires machines that have an explicit expiry set
|
||||||
// after that expiry time has passed.
|
// after that expiry time has passed.
|
||||||
func (h *Headscale) expireExpiredMachines(milliSeconds int64) {
|
func (h *Headscale) expireExpiredMachines(intervalMs int64) {
|
||||||
ticker := time.NewTicker(time.Duration(milliSeconds) * time.Millisecond)
|
interval := time.Duration(intervalMs) * time.Millisecond
|
||||||
|
ticker := time.NewTicker(interval)
|
||||||
|
|
||||||
|
lastCheck := time.Unix(0, 0)
|
||||||
|
|
||||||
for range ticker.C {
|
for range ticker.C {
|
||||||
h.db.ExpireExpiredMachines(h.getLastStateChange())
|
lastCheck = h.db.ExpireExpiredMachines(lastCheck)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -258,7 +257,7 @@ func (h *Headscale) scheduledDERPMapUpdateWorker(cancelChan <-chan struct{}) {
|
|||||||
h.DERPMap.Regions[region.RegionID] = ®ion
|
h.DERPMap.Regions[region.RegionID] = ®ion
|
||||||
}
|
}
|
||||||
|
|
||||||
h.setLastStateChangeToNow()
|
h.nodeNotifier.NotifyAll()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -722,7 +721,7 @@ func (h *Headscale) Serve() error {
|
|||||||
Str("path", aclPath).
|
Str("path", aclPath).
|
||||||
Msg("ACL policy successfully reloaded, notifying nodes of change")
|
Msg("ACL policy successfully reloaded, notifying nodes of change")
|
||||||
|
|
||||||
h.setLastStateChangeToNow()
|
h.nodeNotifier.NotifyAll()
|
||||||
}
|
}
|
||||||
|
|
||||||
default:
|
default:
|
||||||
@ -760,10 +759,6 @@ func (h *Headscale) Serve() error {
|
|||||||
// Stop listening (and unlink the socket if unix type):
|
// Stop listening (and unlink the socket if unix type):
|
||||||
socketListener.Close()
|
socketListener.Close()
|
||||||
|
|
||||||
<-h.cancelStateUpdateChan
|
|
||||||
close(h.stateUpdateChan)
|
|
||||||
close(h.cancelStateUpdateChan)
|
|
||||||
|
|
||||||
// Close db connections
|
// Close db connections
|
||||||
err = h.db.Close()
|
err = h.db.Close()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -859,73 +854,6 @@ func (h *Headscale) getTLSSettings() (*tls.Config, error) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO(kradalby): baby steps, make this more robust.
|
|
||||||
func (h *Headscale) watchStateChannel() {
|
|
||||||
for {
|
|
||||||
select {
|
|
||||||
case <-h.stateUpdateChan:
|
|
||||||
h.setLastStateChangeToNow()
|
|
||||||
|
|
||||||
case <-h.cancelStateUpdateChan:
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (h *Headscale) setLastStateChangeToNow() {
|
|
||||||
var err error
|
|
||||||
|
|
||||||
now := time.Now().UTC()
|
|
||||||
|
|
||||||
users, err := h.db.ListUsers()
|
|
||||||
if err != nil {
|
|
||||||
log.Error().
|
|
||||||
Caller().
|
|
||||||
Err(err).
|
|
||||||
Msg("failed to fetch all users, failing to update last changed state.")
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, user := range users {
|
|
||||||
lastStateUpdate.WithLabelValues(user.Name, "headscale").Set(float64(now.Unix()))
|
|
||||||
if h.lastStateChange == nil {
|
|
||||||
h.lastStateChange = xsync.NewMapOf[time.Time]()
|
|
||||||
}
|
|
||||||
h.lastStateChange.Store(user.Name, now)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (h *Headscale) getLastStateChange(users ...types.User) time.Time {
|
|
||||||
times := []time.Time{}
|
|
||||||
|
|
||||||
// getLastStateChange takes a list of users as a "filter", if no users
|
|
||||||
// are past, then use the entier list of users and look for the last update
|
|
||||||
if len(users) > 0 {
|
|
||||||
for _, user := range users {
|
|
||||||
if lastChange, ok := h.lastStateChange.Load(user.Name); ok {
|
|
||||||
times = append(times, lastChange)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
h.lastStateChange.Range(func(key string, value time.Time) bool {
|
|
||||||
times = append(times, value)
|
|
||||||
|
|
||||||
return true
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
sort.Slice(times, func(i, j int) bool {
|
|
||||||
return times[i].After(times[j])
|
|
||||||
})
|
|
||||||
|
|
||||||
log.Trace().Msgf("Latest times %#v", times)
|
|
||||||
|
|
||||||
if len(times) == 0 {
|
|
||||||
return time.Now().UTC()
|
|
||||||
} else {
|
|
||||||
return times[0]
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func notFoundHandler(
|
func notFoundHandler(
|
||||||
writer http.ResponseWriter,
|
writer http.ResponseWriter,
|
||||||
req *http.Request,
|
req *http.Request,
|
||||||
|
@ -63,8 +63,6 @@ func (s *Suite) TestGetUsedIps(c *check.C) {
|
|||||||
|
|
||||||
c.Assert(len(machine1.IPAddresses), check.Equals, 1)
|
c.Assert(len(machine1.IPAddresses), check.Equals, 1)
|
||||||
c.Assert(machine1.IPAddresses[0], check.Equals, expected)
|
c.Assert(machine1.IPAddresses[0], check.Equals, expected)
|
||||||
|
|
||||||
c.Assert(channelUpdates, check.Equals, int32(0))
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Suite) TestGetMultiIp(c *check.C) {
|
func (s *Suite) TestGetMultiIp(c *check.C) {
|
||||||
@ -153,8 +151,6 @@ func (s *Suite) TestGetMultiIp(c *check.C) {
|
|||||||
|
|
||||||
c.Assert(len(nextIP2), check.Equals, 1)
|
c.Assert(len(nextIP2), check.Equals, 1)
|
||||||
c.Assert(nextIP2[0].String(), check.Equals, expectedNextIP.String())
|
c.Assert(nextIP2[0].String(), check.Equals, expectedNextIP.String())
|
||||||
|
|
||||||
c.Assert(channelUpdates, check.Equals, int32(0))
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Suite) TestGetAvailableIpMachineWithoutIP(c *check.C) {
|
func (s *Suite) TestGetAvailableIpMachineWithoutIP(c *check.C) {
|
||||||
@ -192,6 +188,4 @@ func (s *Suite) TestGetAvailableIpMachineWithoutIP(c *check.C) {
|
|||||||
|
|
||||||
c.Assert(len(ips2), check.Equals, 1)
|
c.Assert(len(ips2), check.Equals, 1)
|
||||||
c.Assert(ips2[0].String(), check.Equals, expected.String())
|
c.Assert(ips2[0].String(), check.Equals, expected.String())
|
||||||
|
|
||||||
c.Assert(channelUpdates, check.Equals, int32(0))
|
|
||||||
}
|
}
|
||||||
|
@ -22,8 +22,6 @@ func (*Suite) TestCreateAPIKey(c *check.C) {
|
|||||||
keys, err := db.ListAPIKeys()
|
keys, err := db.ListAPIKeys()
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
c.Assert(len(keys), check.Equals, 1)
|
c.Assert(len(keys), check.Equals, 1)
|
||||||
|
|
||||||
c.Assert(channelUpdates, check.Equals, int32(0))
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (*Suite) TestAPIKeyDoesNotExist(c *check.C) {
|
func (*Suite) TestAPIKeyDoesNotExist(c *check.C) {
|
||||||
@ -41,8 +39,6 @@ func (*Suite) TestValidateAPIKeyOk(c *check.C) {
|
|||||||
valid, err := db.ValidateAPIKey(apiKeyStr)
|
valid, err := db.ValidateAPIKey(apiKeyStr)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
c.Assert(valid, check.Equals, true)
|
c.Assert(valid, check.Equals, true)
|
||||||
|
|
||||||
c.Assert(channelUpdates, check.Equals, int32(0))
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (*Suite) TestValidateAPIKeyNotOk(c *check.C) {
|
func (*Suite) TestValidateAPIKeyNotOk(c *check.C) {
|
||||||
@ -71,8 +67,6 @@ func (*Suite) TestValidateAPIKeyNotOk(c *check.C) {
|
|||||||
validWithErr, err := db.ValidateAPIKey("produceerrorkey")
|
validWithErr, err := db.ValidateAPIKey("produceerrorkey")
|
||||||
c.Assert(err, check.NotNil)
|
c.Assert(err, check.NotNil)
|
||||||
c.Assert(validWithErr, check.Equals, false)
|
c.Assert(validWithErr, check.Equals, false)
|
||||||
|
|
||||||
c.Assert(channelUpdates, check.Equals, int32(0))
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (*Suite) TestExpireAPIKey(c *check.C) {
|
func (*Suite) TestExpireAPIKey(c *check.C) {
|
||||||
@ -92,6 +86,4 @@ func (*Suite) TestExpireAPIKey(c *check.C) {
|
|||||||
notValid, err := db.ValidateAPIKey(apiKeyStr)
|
notValid, err := db.ValidateAPIKey(apiKeyStr)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
c.Assert(notValid, check.Equals, false)
|
c.Assert(notValid, check.Equals, false)
|
||||||
|
|
||||||
c.Assert(channelUpdates, check.Equals, int32(0))
|
|
||||||
}
|
}
|
||||||
|
@ -9,6 +9,7 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/glebarez/sqlite"
|
"github.com/glebarez/sqlite"
|
||||||
|
"github.com/juanfont/headscale/hscontrol/notifier"
|
||||||
"github.com/juanfont/headscale/hscontrol/types"
|
"github.com/juanfont/headscale/hscontrol/types"
|
||||||
"github.com/juanfont/headscale/hscontrol/util"
|
"github.com/juanfont/headscale/hscontrol/util"
|
||||||
"github.com/rs/zerolog/log"
|
"github.com/rs/zerolog/log"
|
||||||
@ -37,7 +38,7 @@ type KV struct {
|
|||||||
|
|
||||||
type HSDatabase struct {
|
type HSDatabase struct {
|
||||||
db *gorm.DB
|
db *gorm.DB
|
||||||
notifyStateChan chan<- struct{}
|
notifier *notifier.Notifier
|
||||||
|
|
||||||
ipAllocationMutex sync.Mutex
|
ipAllocationMutex sync.Mutex
|
||||||
|
|
||||||
@ -50,7 +51,7 @@ type HSDatabase struct {
|
|||||||
func NewHeadscaleDatabase(
|
func NewHeadscaleDatabase(
|
||||||
dbType, connectionAddr string,
|
dbType, connectionAddr string,
|
||||||
debug bool,
|
debug bool,
|
||||||
notifyStateChan chan<- struct{},
|
notifier *notifier.Notifier,
|
||||||
ipPrefixes []netip.Prefix,
|
ipPrefixes []netip.Prefix,
|
||||||
baseDomain string,
|
baseDomain string,
|
||||||
) (*HSDatabase, error) {
|
) (*HSDatabase, error) {
|
||||||
@ -61,7 +62,7 @@ func NewHeadscaleDatabase(
|
|||||||
|
|
||||||
db := HSDatabase{
|
db := HSDatabase{
|
||||||
db: dbConn,
|
db: dbConn,
|
||||||
notifyStateChan: notifyStateChan,
|
notifier: notifier,
|
||||||
|
|
||||||
ipPrefixes: ipPrefixes,
|
ipPrefixes: ipPrefixes,
|
||||||
baseDomain: baseDomain,
|
baseDomain: baseDomain,
|
||||||
@ -297,10 +298,6 @@ func openDB(dbType, connectionAddr string, debug bool) (*gorm.DB, error) {
|
|||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (hsdb *HSDatabase) notifyStateChange() {
|
|
||||||
hsdb.notifyStateChan <- struct{}{}
|
|
||||||
}
|
|
||||||
|
|
||||||
// getValue returns the value for the given key in KV.
|
// getValue returns the value for the given key in KV.
|
||||||
func (hsdb *HSDatabase) getValue(key string) (string, error) {
|
func (hsdb *HSDatabase) getValue(key string) (string, error) {
|
||||||
var row KV
|
var row KV
|
||||||
|
@ -218,7 +218,7 @@ func (hsdb *HSDatabase) SetTags(
|
|||||||
}
|
}
|
||||||
machine.ForcedTags = newTags
|
machine.ForcedTags = newTags
|
||||||
|
|
||||||
hsdb.notifyStateChange()
|
hsdb.notifier.NotifyWithIgnore(machine.MachineKey)
|
||||||
|
|
||||||
if err := hsdb.db.Save(machine).Error; err != nil {
|
if err := hsdb.db.Save(machine).Error; err != nil {
|
||||||
return fmt.Errorf("failed to update tags for machine in the database: %w", err)
|
return fmt.Errorf("failed to update tags for machine in the database: %w", err)
|
||||||
@ -232,7 +232,7 @@ func (hsdb *HSDatabase) ExpireMachine(machine *types.Machine) error {
|
|||||||
now := time.Now()
|
now := time.Now()
|
||||||
machine.Expiry = &now
|
machine.Expiry = &now
|
||||||
|
|
||||||
hsdb.notifyStateChange()
|
hsdb.notifier.NotifyWithIgnore(machine.MachineKey)
|
||||||
|
|
||||||
if err := hsdb.db.Save(machine).Error; err != nil {
|
if err := hsdb.db.Save(machine).Error; err != nil {
|
||||||
return fmt.Errorf("failed to expire machine in the database: %w", err)
|
return fmt.Errorf("failed to expire machine in the database: %w", err)
|
||||||
@ -259,7 +259,7 @@ func (hsdb *HSDatabase) RenameMachine(machine *types.Machine, newName string) er
|
|||||||
}
|
}
|
||||||
machine.GivenName = newName
|
machine.GivenName = newName
|
||||||
|
|
||||||
hsdb.notifyStateChange()
|
hsdb.notifier.NotifyWithIgnore(machine.MachineKey)
|
||||||
|
|
||||||
if err := hsdb.db.Save(machine).Error; err != nil {
|
if err := hsdb.db.Save(machine).Error; err != nil {
|
||||||
return fmt.Errorf("failed to rename machine in the database: %w", err)
|
return fmt.Errorf("failed to rename machine in the database: %w", err)
|
||||||
@ -275,7 +275,7 @@ func (hsdb *HSDatabase) RefreshMachine(machine *types.Machine, expiry time.Time)
|
|||||||
machine.LastSuccessfulUpdate = &now
|
machine.LastSuccessfulUpdate = &now
|
||||||
machine.Expiry = &expiry
|
machine.Expiry = &expiry
|
||||||
|
|
||||||
hsdb.notifyStateChange()
|
hsdb.notifier.NotifyWithIgnore(machine.MachineKey)
|
||||||
|
|
||||||
if err := hsdb.db.Save(machine).Error; err != nil {
|
if err := hsdb.db.Save(machine).Error; err != nil {
|
||||||
return fmt.Errorf(
|
return fmt.Errorf(
|
||||||
@ -323,32 +323,6 @@ func (hsdb *HSDatabase) HardDeleteMachine(machine *types.Machine) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (hsdb *HSDatabase) IsOutdated(machine *types.Machine, lastChange time.Time) bool {
|
|
||||||
if err := hsdb.UpdateMachineFromDatabase(machine); err != nil {
|
|
||||||
// It does not seem meaningful to propagate this error as the end result
|
|
||||||
// will have to be that the machine has to be considered outdated.
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
// Get the last update from all headscale users to compare with our nodes
|
|
||||||
// last update.
|
|
||||||
// TODO(kradalby): Only request updates from users where we can talk to nodes
|
|
||||||
// This would mostly be for a bit of performance, and can be calculated based on
|
|
||||||
// ACLs.
|
|
||||||
lastUpdate := machine.CreatedAt
|
|
||||||
if machine.LastSuccessfulUpdate != nil {
|
|
||||||
lastUpdate = *machine.LastSuccessfulUpdate
|
|
||||||
}
|
|
||||||
log.Trace().
|
|
||||||
Caller().
|
|
||||||
Str("machine", machine.Hostname).
|
|
||||||
Time("last_successful_update", lastChange).
|
|
||||||
Time("last_state_change", lastUpdate).
|
|
||||||
Msgf("Checking if %s is missing updates", machine.Hostname)
|
|
||||||
|
|
||||||
return lastUpdate.Before(lastChange)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (hsdb *HSDatabase) RegisterMachineFromAuthCallback(
|
func (hsdb *HSDatabase) RegisterMachineFromAuthCallback(
|
||||||
cache *cache.Cache,
|
cache *cache.Cache,
|
||||||
nodeKeyStr string,
|
nodeKeyStr string,
|
||||||
@ -626,7 +600,7 @@ func (hsdb *HSDatabase) enableRoutes(machine *types.Machine, routeStrs ...string
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
hsdb.notifyStateChange()
|
hsdb.notifier.NotifyWithIgnore(machine.MachineKey)
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@ -723,17 +697,22 @@ func (hsdb *HSDatabase) ExpireEphemeralMachines(inactivityThreshhold time.Durati
|
|||||||
}
|
}
|
||||||
|
|
||||||
if expiredFound {
|
if expiredFound {
|
||||||
hsdb.notifyStateChange()
|
hsdb.notifier.NotifyAll()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (hsdb *HSDatabase) ExpireExpiredMachines(lastChange time.Time) {
|
func (hsdb *HSDatabase) ExpireExpiredMachines(lastCheck time.Time) time.Time {
|
||||||
|
// use the time of the start of the function to ensure we
|
||||||
|
// dont miss some machines by returning it _after_ we have
|
||||||
|
// checked everything.
|
||||||
|
started := time.Now()
|
||||||
|
|
||||||
users, err := hsdb.ListUsers()
|
users, err := hsdb.ListUsers()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error().Err(err).Msg("Error listing users")
|
log.Error().Err(err).Msg("Error listing users")
|
||||||
|
|
||||||
return
|
return time.Unix(0, 0)
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, user := range users {
|
for _, user := range users {
|
||||||
@ -744,13 +723,13 @@ func (hsdb *HSDatabase) ExpireExpiredMachines(lastChange time.Time) {
|
|||||||
Str("user", user.Name).
|
Str("user", user.Name).
|
||||||
Msg("Error listing machines in user")
|
Msg("Error listing machines in user")
|
||||||
|
|
||||||
return
|
return time.Unix(0, 0)
|
||||||
}
|
}
|
||||||
|
|
||||||
expiredFound := false
|
expiredFound := false
|
||||||
for index, machine := range machines {
|
for index, machine := range machines {
|
||||||
if machine.IsExpired() &&
|
if machine.IsExpired() &&
|
||||||
machine.Expiry.After(lastChange) {
|
machine.Expiry.After(lastCheck) {
|
||||||
expiredFound = true
|
expiredFound = true
|
||||||
|
|
||||||
err := hsdb.ExpireMachine(&machines[index])
|
err := hsdb.ExpireMachine(&machines[index])
|
||||||
@ -770,7 +749,9 @@ func (hsdb *HSDatabase) ExpireExpiredMachines(lastChange time.Time) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if expiredFound {
|
if expiredFound {
|
||||||
hsdb.notifyStateChange()
|
hsdb.notifier.NotifyAll()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
return started
|
||||||
}
|
}
|
||||||
|
@ -39,8 +39,6 @@ func (s *Suite) TestGetMachine(c *check.C) {
|
|||||||
|
|
||||||
_, err = db.GetMachine("test", "testmachine")
|
_, err = db.GetMachine("test", "testmachine")
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
c.Assert(channelUpdates, check.Equals, int32(0))
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Suite) TestGetMachineByID(c *check.C) {
|
func (s *Suite) TestGetMachineByID(c *check.C) {
|
||||||
@ -67,8 +65,6 @@ func (s *Suite) TestGetMachineByID(c *check.C) {
|
|||||||
|
|
||||||
_, err = db.GetMachineByID(0)
|
_, err = db.GetMachineByID(0)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
c.Assert(channelUpdates, check.Equals, int32(0))
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Suite) TestGetMachineByNodeKey(c *check.C) {
|
func (s *Suite) TestGetMachineByNodeKey(c *check.C) {
|
||||||
@ -98,8 +94,6 @@ func (s *Suite) TestGetMachineByNodeKey(c *check.C) {
|
|||||||
|
|
||||||
_, err = db.GetMachineByNodeKey(nodeKey.Public())
|
_, err = db.GetMachineByNodeKey(nodeKey.Public())
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
c.Assert(channelUpdates, check.Equals, int32(0))
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Suite) TestGetMachineByAnyNodeKey(c *check.C) {
|
func (s *Suite) TestGetMachineByAnyNodeKey(c *check.C) {
|
||||||
@ -131,8 +125,6 @@ func (s *Suite) TestGetMachineByAnyNodeKey(c *check.C) {
|
|||||||
|
|
||||||
_, err = db.GetMachineByAnyKey(machineKey.Public(), nodeKey.Public(), oldNodeKey.Public())
|
_, err = db.GetMachineByAnyKey(machineKey.Public(), nodeKey.Public(), oldNodeKey.Public())
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
c.Assert(channelUpdates, check.Equals, int32(0))
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Suite) TestDeleteMachine(c *check.C) {
|
func (s *Suite) TestDeleteMachine(c *check.C) {
|
||||||
@ -155,8 +147,6 @@ func (s *Suite) TestDeleteMachine(c *check.C) {
|
|||||||
|
|
||||||
_, err = db.GetMachine(user.Name, "testmachine")
|
_, err = db.GetMachine(user.Name, "testmachine")
|
||||||
c.Assert(err, check.NotNil)
|
c.Assert(err, check.NotNil)
|
||||||
|
|
||||||
c.Assert(channelUpdates, check.Equals, int32(0))
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Suite) TestHardDeleteMachine(c *check.C) {
|
func (s *Suite) TestHardDeleteMachine(c *check.C) {
|
||||||
@ -179,8 +169,6 @@ func (s *Suite) TestHardDeleteMachine(c *check.C) {
|
|||||||
|
|
||||||
_, err = db.GetMachine(user.Name, "testmachine3")
|
_, err = db.GetMachine(user.Name, "testmachine3")
|
||||||
c.Assert(err, check.NotNil)
|
c.Assert(err, check.NotNil)
|
||||||
|
|
||||||
c.Assert(channelUpdates, check.Equals, int32(0))
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Suite) TestListPeers(c *check.C) {
|
func (s *Suite) TestListPeers(c *check.C) {
|
||||||
@ -217,8 +205,6 @@ func (s *Suite) TestListPeers(c *check.C) {
|
|||||||
c.Assert(peersOfMachine0[0].Hostname, check.Equals, "testmachine2")
|
c.Assert(peersOfMachine0[0].Hostname, check.Equals, "testmachine2")
|
||||||
c.Assert(peersOfMachine0[5].Hostname, check.Equals, "testmachine7")
|
c.Assert(peersOfMachine0[5].Hostname, check.Equals, "testmachine7")
|
||||||
c.Assert(peersOfMachine0[8].Hostname, check.Equals, "testmachine10")
|
c.Assert(peersOfMachine0[8].Hostname, check.Equals, "testmachine10")
|
||||||
|
|
||||||
c.Assert(channelUpdates, check.Equals, int32(0))
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Suite) TestGetACLFilteredPeers(c *check.C) {
|
func (s *Suite) TestGetACLFilteredPeers(c *check.C) {
|
||||||
@ -312,8 +298,6 @@ func (s *Suite) TestGetACLFilteredPeers(c *check.C) {
|
|||||||
c.Assert(peersOfAdminMachine[0].Hostname, check.Equals, "testmachine2")
|
c.Assert(peersOfAdminMachine[0].Hostname, check.Equals, "testmachine2")
|
||||||
c.Assert(peersOfAdminMachine[2].Hostname, check.Equals, "testmachine4")
|
c.Assert(peersOfAdminMachine[2].Hostname, check.Equals, "testmachine4")
|
||||||
c.Assert(peersOfAdminMachine[5].Hostname, check.Equals, "testmachine7")
|
c.Assert(peersOfAdminMachine[5].Hostname, check.Equals, "testmachine7")
|
||||||
|
|
||||||
c.Assert(channelUpdates, check.Equals, int32(0))
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Suite) TestExpireMachine(c *check.C) {
|
func (s *Suite) TestExpireMachine(c *check.C) {
|
||||||
@ -349,8 +333,6 @@ func (s *Suite) TestExpireMachine(c *check.C) {
|
|||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
c.Assert(machineFromDB.IsExpired(), check.Equals, true)
|
c.Assert(machineFromDB.IsExpired(), check.Equals, true)
|
||||||
|
|
||||||
c.Assert(channelUpdates, check.Equals, int32(1))
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Suite) TestSerdeAddressStrignSlice(c *check.C) {
|
func (s *Suite) TestSerdeAddressStrignSlice(c *check.C) {
|
||||||
@ -372,8 +354,6 @@ func (s *Suite) TestSerdeAddressStrignSlice(c *check.C) {
|
|||||||
for i := range deserialized {
|
for i := range deserialized {
|
||||||
c.Assert(deserialized[i], check.Equals, input[i])
|
c.Assert(deserialized[i], check.Equals, input[i])
|
||||||
}
|
}
|
||||||
|
|
||||||
c.Assert(channelUpdates, check.Equals, int32(0))
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Suite) TestGenerateGivenName(c *check.C) {
|
func (s *Suite) TestGenerateGivenName(c *check.C) {
|
||||||
@ -418,8 +398,6 @@ func (s *Suite) TestGenerateGivenName(c *check.C) {
|
|||||||
comment = check.Commentf("Unique users, unique machines, same hostname, conflict")
|
comment = check.Commentf("Unique users, unique machines, same hostname, conflict")
|
||||||
c.Assert(err, check.IsNil, comment)
|
c.Assert(err, check.IsNil, comment)
|
||||||
c.Assert(givenName, check.Matches, fmt.Sprintf("^hostname-1-[a-z0-9]{%d}$", MachineGivenNameHashLength), comment)
|
c.Assert(givenName, check.Matches, fmt.Sprintf("^hostname-1-[a-z0-9]{%d}$", MachineGivenNameHashLength), comment)
|
||||||
|
|
||||||
c.Assert(channelUpdates, check.Equals, int32(0))
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Suite) TestSetTags(c *check.C) {
|
func (s *Suite) TestSetTags(c *check.C) {
|
||||||
@ -463,8 +441,6 @@ func (s *Suite) TestSetTags(c *check.C) {
|
|||||||
check.DeepEquals,
|
check.DeepEquals,
|
||||||
types.StringList([]string{"tag:bar", "tag:test", "tag:unknown"}),
|
types.StringList([]string{"tag:bar", "tag:test", "tag:unknown"}),
|
||||||
)
|
)
|
||||||
|
|
||||||
c.Assert(channelUpdates, check.Equals, int32(2))
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestHeadscale_generateGivenName(t *testing.T) {
|
func TestHeadscale_generateGivenName(t *testing.T) {
|
||||||
@ -655,6 +631,4 @@ func (s *Suite) TestAutoApproveRoutes(c *check.C) {
|
|||||||
enabledRoutes, err := db.GetEnabledRoutes(machine0ByID)
|
enabledRoutes, err := db.GetEnabledRoutes(machine0ByID)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
c.Assert(enabledRoutes, check.HasLen, 4)
|
c.Assert(enabledRoutes, check.HasLen, 4)
|
||||||
|
|
||||||
c.Assert(channelUpdates, check.Equals, int32(4))
|
|
||||||
}
|
}
|
||||||
|
@ -161,8 +161,6 @@ func (*Suite) TestEphemeralKey(c *check.C) {
|
|||||||
// The machine record should have been deleted
|
// The machine record should have been deleted
|
||||||
_, err = db.GetMachine("test7", "testest")
|
_, err = db.GetMachine("test7", "testest")
|
||||||
c.Assert(err, check.NotNil)
|
c.Assert(err, check.NotNil)
|
||||||
|
|
||||||
c.Assert(channelUpdates, check.Equals, int32(1))
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (*Suite) TestExpirePreauthKey(c *check.C) {
|
func (*Suite) TestExpirePreauthKey(c *check.C) {
|
||||||
|
@ -374,7 +374,7 @@ func (hsdb *HSDatabase) HandlePrimarySubnetFailover() error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if routesChanged {
|
if routesChanged {
|
||||||
hsdb.notifyStateChange()
|
hsdb.notifier.NotifyAll()
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
|
@ -52,8 +52,6 @@ func (s *Suite) TestGetRoutes(c *check.C) {
|
|||||||
|
|
||||||
err = db.enableRoutes(&machine, "10.0.0.0/24")
|
err = db.enableRoutes(&machine, "10.0.0.0/24")
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
c.Assert(channelUpdates, check.Equals, int32(0))
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Suite) TestGetEnableRoutes(c *check.C) {
|
func (s *Suite) TestGetEnableRoutes(c *check.C) {
|
||||||
@ -129,8 +127,6 @@ func (s *Suite) TestGetEnableRoutes(c *check.C) {
|
|||||||
enabledRoutesWithAdditionalRoute, err := db.GetEnabledRoutes(&machine)
|
enabledRoutesWithAdditionalRoute, err := db.GetEnabledRoutes(&machine)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
c.Assert(len(enabledRoutesWithAdditionalRoute), check.Equals, 2)
|
c.Assert(len(enabledRoutesWithAdditionalRoute), check.Equals, 2)
|
||||||
|
|
||||||
c.Assert(channelUpdates, check.Equals, int32(3))
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Suite) TestIsUniquePrefix(c *check.C) {
|
func (s *Suite) TestIsUniquePrefix(c *check.C) {
|
||||||
@ -215,8 +211,6 @@ func (s *Suite) TestIsUniquePrefix(c *check.C) {
|
|||||||
routes, err = db.GetMachinePrimaryRoutes(&machine2)
|
routes, err = db.GetMachinePrimaryRoutes(&machine2)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
c.Assert(len(routes), check.Equals, 0)
|
c.Assert(len(routes), check.Equals, 0)
|
||||||
|
|
||||||
c.Assert(channelUpdates, check.Equals, int32(3))
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Suite) TestSubnetFailover(c *check.C) {
|
func (s *Suite) TestSubnetFailover(c *check.C) {
|
||||||
@ -359,8 +353,6 @@ func (s *Suite) TestSubnetFailover(c *check.C) {
|
|||||||
routes, err = db.GetMachinePrimaryRoutes(&machine2)
|
routes, err = db.GetMachinePrimaryRoutes(&machine2)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
c.Assert(len(routes), check.Equals, 2)
|
c.Assert(len(routes), check.Equals, 2)
|
||||||
|
|
||||||
c.Assert(channelUpdates, check.Equals, int32(6))
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Suite) TestDeleteRoutes(c *check.C) {
|
func (s *Suite) TestDeleteRoutes(c *check.C) {
|
||||||
@ -420,6 +412,4 @@ func (s *Suite) TestDeleteRoutes(c *check.C) {
|
|||||||
enabledRoutes1, err := db.GetEnabledRoutes(&machine1)
|
enabledRoutes1, err := db.GetEnabledRoutes(&machine1)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
c.Assert(len(enabledRoutes1), check.Equals, 1)
|
c.Assert(len(enabledRoutes1), check.Equals, 1)
|
||||||
|
|
||||||
c.Assert(channelUpdates, check.Equals, int32(2))
|
|
||||||
}
|
}
|
||||||
|
@ -3,9 +3,9 @@ package db
|
|||||||
import (
|
import (
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"os"
|
"os"
|
||||||
"sync/atomic"
|
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
"github.com/juanfont/headscale/hscontrol/notifier"
|
||||||
"gopkg.in/check.v1"
|
"gopkg.in/check.v1"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -20,14 +20,9 @@ type Suite struct{}
|
|||||||
var (
|
var (
|
||||||
tmpDir string
|
tmpDir string
|
||||||
db *HSDatabase
|
db *HSDatabase
|
||||||
|
|
||||||
// channelUpdates counts the number of times
|
|
||||||
// either of the channels was notified.
|
|
||||||
channelUpdates int32
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func (s *Suite) SetUpTest(c *check.C) {
|
func (s *Suite) SetUpTest(c *check.C) {
|
||||||
atomic.StoreInt32(&channelUpdates, 0)
|
|
||||||
s.ResetDB(c)
|
s.ResetDB(c)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -35,13 +30,6 @@ func (s *Suite) TearDownTest(c *check.C) {
|
|||||||
os.RemoveAll(tmpDir)
|
os.RemoveAll(tmpDir)
|
||||||
}
|
}
|
||||||
|
|
||||||
func notificationSink(c <-chan struct{}) {
|
|
||||||
for {
|
|
||||||
<-c
|
|
||||||
atomic.AddInt32(&channelUpdates, 1)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *Suite) ResetDB(c *check.C) {
|
func (s *Suite) ResetDB(c *check.C) {
|
||||||
if len(tmpDir) != 0 {
|
if len(tmpDir) != 0 {
|
||||||
os.RemoveAll(tmpDir)
|
os.RemoveAll(tmpDir)
|
||||||
@ -52,15 +40,11 @@ func (s *Suite) ResetDB(c *check.C) {
|
|||||||
c.Fatal(err)
|
c.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
sink := make(chan struct{})
|
|
||||||
|
|
||||||
go notificationSink(sink)
|
|
||||||
|
|
||||||
db, err = NewHeadscaleDatabase(
|
db, err = NewHeadscaleDatabase(
|
||||||
"sqlite3",
|
"sqlite3",
|
||||||
tmpDir+"/headscale_test.db",
|
tmpDir+"/headscale_test.db",
|
||||||
false,
|
false,
|
||||||
sink,
|
notifier.NewNotifier(),
|
||||||
[]netip.Prefix{
|
[]netip.Prefix{
|
||||||
netip.MustParsePrefix("10.27.0.0/23"),
|
netip.MustParsePrefix("10.27.0.0/23"),
|
||||||
},
|
},
|
||||||
|
@ -10,12 +10,6 @@ const prometheusNamespace = "headscale"
|
|||||||
var (
|
var (
|
||||||
// This is a high cardinality metric (user x machines), we might want to make this
|
// This is a high cardinality metric (user x machines), we might want to make this
|
||||||
// configurable/opt-in in the future.
|
// configurable/opt-in in the future.
|
||||||
lastStateUpdate = promauto.NewGaugeVec(prometheus.GaugeOpts{
|
|
||||||
Namespace: prometheusNamespace,
|
|
||||||
Name: "last_update_seconds",
|
|
||||||
Help: "Time stamp in unix time when a machine or headscale was updated",
|
|
||||||
}, []string{"user", "machine"})
|
|
||||||
|
|
||||||
machineRegistrations = promauto.NewCounterVec(prometheus.CounterOpts{
|
machineRegistrations = promauto.NewCounterVec(prometheus.CounterOpts{
|
||||||
Namespace: prometheusNamespace,
|
Namespace: prometheusNamespace,
|
||||||
Name: "machine_registrations_total",
|
Name: "machine_registrations_total",
|
||||||
@ -33,9 +27,4 @@ var (
|
|||||||
Help: "The number of calls/messages issued on a specific nodes update channel",
|
Help: "The number of calls/messages issued on a specific nodes update channel",
|
||||||
}, []string{"user", "machine", "status"})
|
}, []string{"user", "machine", "status"})
|
||||||
// TODO(kradalby): This is very debugging, we might want to remove it.
|
// TODO(kradalby): This is very debugging, we might want to remove it.
|
||||||
updateRequestsReceivedOnChannel = promauto.NewCounterVec(prometheus.CounterOpts{
|
|
||||||
Namespace: prometheusNamespace,
|
|
||||||
Name: "update_request_received_on_channel_total",
|
|
||||||
Help: "The number of update requests received on an update channel",
|
|
||||||
}, []string{"user", "machine"})
|
|
||||||
)
|
)
|
||||||
|
55
hscontrol/notifier/notifier.go
Normal file
55
hscontrol/notifier/notifier.go
Normal file
@ -0,0 +1,55 @@
|
|||||||
|
package notifier
|
||||||
|
|
||||||
|
import (
|
||||||
|
"sync"
|
||||||
|
|
||||||
|
"github.com/juanfont/headscale/hscontrol/util"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Notifier struct {
|
||||||
|
l sync.RWMutex
|
||||||
|
nodes map[string]chan<- struct{}
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewNotifier() *Notifier {
|
||||||
|
return &Notifier{}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (n *Notifier) AddNode(machineKey string, c chan<- struct{}) {
|
||||||
|
n.l.Lock()
|
||||||
|
defer n.l.Unlock()
|
||||||
|
|
||||||
|
if n.nodes == nil {
|
||||||
|
n.nodes = make(map[string]chan<- struct{})
|
||||||
|
}
|
||||||
|
|
||||||
|
n.nodes[machineKey] = c
|
||||||
|
}
|
||||||
|
|
||||||
|
func (n *Notifier) RemoveNode(machineKey string) {
|
||||||
|
n.l.Lock()
|
||||||
|
defer n.l.Unlock()
|
||||||
|
|
||||||
|
if n.nodes == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
delete(n.nodes, machineKey)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (n *Notifier) NotifyAll() {
|
||||||
|
n.NotifyWithIgnore()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (n *Notifier) NotifyWithIgnore(ignore ...string) {
|
||||||
|
n.l.RLock()
|
||||||
|
defer n.l.RUnlock()
|
||||||
|
|
||||||
|
for key, c := range n.nodes {
|
||||||
|
if util.IsStringInSlice(ignore, key) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
c <- struct{}{}
|
||||||
|
}
|
||||||
|
}
|
@ -21,6 +21,38 @@ type contextKey string
|
|||||||
|
|
||||||
const machineNameContextKey = contextKey("machineName")
|
const machineNameContextKey = contextKey("machineName")
|
||||||
|
|
||||||
|
type UpdateNode func()
|
||||||
|
|
||||||
|
func logPollFunc(
|
||||||
|
mapRequest tailcfg.MapRequest,
|
||||||
|
machine *types.Machine,
|
||||||
|
isNoise bool,
|
||||||
|
) (func(string), func(error, string)) {
|
||||||
|
return func(msg string) {
|
||||||
|
log.Info().
|
||||||
|
Caller().
|
||||||
|
Bool("noise", isNoise).
|
||||||
|
Bool("readOnly", mapRequest.ReadOnly).
|
||||||
|
Bool("omitPeers", mapRequest.OmitPeers).
|
||||||
|
Bool("stream", mapRequest.Stream).
|
||||||
|
Str("node_key", machine.NodeKey).
|
||||||
|
Str("machine", machine.Hostname).
|
||||||
|
Msg(msg)
|
||||||
|
},
|
||||||
|
func(err error, msg string) {
|
||||||
|
log.Error().
|
||||||
|
Caller().
|
||||||
|
Bool("noise", isNoise).
|
||||||
|
Bool("readOnly", mapRequest.ReadOnly).
|
||||||
|
Bool("omitPeers", mapRequest.OmitPeers).
|
||||||
|
Bool("stream", mapRequest.Stream).
|
||||||
|
Str("node_key", machine.NodeKey).
|
||||||
|
Str("machine", machine.Hostname).
|
||||||
|
Err(err).
|
||||||
|
Msg(msg)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// handlePoll is the common code for the legacy and Noise protocols to
|
// handlePoll is the common code for the legacy and Noise protocols to
|
||||||
// managed the poll loop.
|
// managed the poll loop.
|
||||||
func (h *Headscale) handlePoll(
|
func (h *Headscale) handlePoll(
|
||||||
@ -30,6 +62,10 @@ func (h *Headscale) handlePoll(
|
|||||||
mapRequest tailcfg.MapRequest,
|
mapRequest tailcfg.MapRequest,
|
||||||
isNoise bool,
|
isNoise bool,
|
||||||
) {
|
) {
|
||||||
|
logInfo, logErr := logPollFunc(mapRequest, machine, isNoise)
|
||||||
|
|
||||||
|
// TODO(kradalby): This is a stepping stone, mapper should be initiated once
|
||||||
|
// per client or something similar
|
||||||
mapp := mapper.NewMapper(
|
mapp := mapper.NewMapper(
|
||||||
h.db,
|
h.db,
|
||||||
h.privateKey2019,
|
h.privateKey2019,
|
||||||
@ -48,11 +84,7 @@ func (h *Headscale) handlePoll(
|
|||||||
|
|
||||||
err := h.db.ProcessMachineRoutes(machine)
|
err := h.db.ProcessMachineRoutes(machine)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error().
|
logErr(err, "Error processing machine routes")
|
||||||
Caller().
|
|
||||||
Err(err).
|
|
||||||
Str("machine", machine.Hostname).
|
|
||||||
Msg("Error processing machine routes")
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// update ACLRules with peer informations (to update server tags if necessary)
|
// update ACLRules with peer informations (to update server tags if necessary)
|
||||||
@ -60,12 +92,7 @@ func (h *Headscale) handlePoll(
|
|||||||
// update routes with peer information
|
// update routes with peer information
|
||||||
err = h.db.EnableAutoApprovedRoutes(h.ACLPolicy, machine)
|
err = h.db.EnableAutoApprovedRoutes(h.ACLPolicy, machine)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error().
|
logErr(err, "Error running auto approved routes")
|
||||||
Caller().
|
|
||||||
Bool("noise", isNoise).
|
|
||||||
Str("machine", machine.Hostname).
|
|
||||||
Err(err).
|
|
||||||
Msg("Error running auto approved routes")
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -83,13 +110,7 @@ func (h *Headscale) handlePoll(
|
|||||||
}
|
}
|
||||||
|
|
||||||
if err := h.db.MachineSave(machine); err != nil {
|
if err := h.db.MachineSave(machine); err != nil {
|
||||||
log.Error().
|
logErr(err, "Failed to persist/update machine in the database")
|
||||||
Str("handler", "PollNetMap").
|
|
||||||
Bool("noise", isNoise).
|
|
||||||
Str("node_key", machine.NodeKey).
|
|
||||||
Str("machine", machine.Hostname).
|
|
||||||
Err(err).
|
|
||||||
Msg("Failed to persist/update machine in the database")
|
|
||||||
http.Error(writer, "", http.StatusInternalServerError)
|
http.Error(writer, "", http.StatusInternalServerError)
|
||||||
|
|
||||||
return
|
return
|
||||||
@ -97,13 +118,7 @@ func (h *Headscale) handlePoll(
|
|||||||
|
|
||||||
mapResp, err := mapp.CreateMapResponse(mapRequest, machine, h.ACLPolicy)
|
mapResp, err := mapp.CreateMapResponse(mapRequest, machine, h.ACLPolicy)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error().
|
logErr(err, "Failed to create MapResponse")
|
||||||
Str("handler", "PollNetMap").
|
|
||||||
Bool("noise", isNoise).
|
|
||||||
Str("node_key", machine.NodeKey).
|
|
||||||
Str("machine", machine.Hostname).
|
|
||||||
Err(err).
|
|
||||||
Msg("Failed to get Map response")
|
|
||||||
http.Error(writer, "", http.StatusInternalServerError)
|
http.Error(writer, "", http.StatusInternalServerError)
|
||||||
|
|
||||||
return
|
return
|
||||||
@ -114,30 +129,16 @@ func (h *Headscale) handlePoll(
|
|||||||
// empty endpoints to peers)
|
// empty endpoints to peers)
|
||||||
|
|
||||||
// Details on the protocol can be found in https://github.com/tailscale/tailscale/blob/main/tailcfg/tailcfg.go#L696
|
// Details on the protocol can be found in https://github.com/tailscale/tailscale/blob/main/tailcfg/tailcfg.go#L696
|
||||||
log.Debug().
|
logInfo("Client map request processed")
|
||||||
Str("handler", "PollNetMap").
|
|
||||||
Bool("noise", isNoise).
|
|
||||||
Str("machine", machine.Hostname).
|
|
||||||
Bool("readOnly", mapRequest.ReadOnly).
|
|
||||||
Bool("omitPeers", mapRequest.OmitPeers).
|
|
||||||
Bool("stream", mapRequest.Stream).
|
|
||||||
Msg("Client map request processed")
|
|
||||||
|
|
||||||
if mapRequest.ReadOnly {
|
if mapRequest.ReadOnly {
|
||||||
log.Info().
|
logInfo("Client is starting up. Probably interested in a DERP map")
|
||||||
Str("handler", "PollNetMap").
|
|
||||||
Bool("noise", isNoise).
|
|
||||||
Str("machine", machine.Hostname).
|
|
||||||
Msg("Client is starting up. Probably interested in a DERP map")
|
|
||||||
|
|
||||||
writer.Header().Set("Content-Type", "application/json; charset=utf-8")
|
writer.Header().Set("Content-Type", "application/json; charset=utf-8")
|
||||||
writer.WriteHeader(http.StatusOK)
|
writer.WriteHeader(http.StatusOK)
|
||||||
_, err := writer.Write(mapResp)
|
_, err := writer.Write(mapResp)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error().
|
logErr(err, "Failed to write response")
|
||||||
Caller().
|
|
||||||
Err(err).
|
|
||||||
Msg("Failed to write response")
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if f, ok := writer.(http.Flusher); ok {
|
if f, ok := writer.(http.Flusher); ok {
|
||||||
@ -147,48 +148,22 @@ func (h *Headscale) handlePoll(
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// There has been an update to _any_ of the nodes that the other nodes would
|
|
||||||
// need to know about
|
|
||||||
h.setLastStateChangeToNow()
|
|
||||||
|
|
||||||
// The request is not ReadOnly, so we need to set up channels for updating
|
|
||||||
// peers via longpoll
|
|
||||||
|
|
||||||
// Only create update channel if it has not been created
|
|
||||||
log.Trace().
|
|
||||||
Caller().
|
|
||||||
Bool("noise", isNoise).
|
|
||||||
Str("machine", machine.Hostname).
|
|
||||||
Msg("Loading or creating update channel")
|
|
||||||
|
|
||||||
const chanSize = 8
|
|
||||||
updateChan := make(chan struct{}, chanSize)
|
|
||||||
|
|
||||||
pollDataChan := make(chan []byte, chanSize)
|
|
||||||
defer closeChanWithLog(pollDataChan, machine.Hostname, "pollDataChan")
|
|
||||||
|
|
||||||
keepAliveChan := make(chan []byte)
|
|
||||||
|
|
||||||
if mapRequest.OmitPeers && !mapRequest.Stream {
|
if mapRequest.OmitPeers && !mapRequest.Stream {
|
||||||
log.Info().
|
logInfo("Client sent endpoint update and is ok with a response without peer list")
|
||||||
Str("handler", "PollNetMap").
|
|
||||||
Bool("noise", isNoise).
|
|
||||||
Str("machine", machine.Hostname).
|
|
||||||
Msg("Client sent endpoint update and is ok with a response without peer list")
|
|
||||||
writer.Header().Set("Content-Type", "application/json; charset=utf-8")
|
writer.Header().Set("Content-Type", "application/json; charset=utf-8")
|
||||||
writer.WriteHeader(http.StatusOK)
|
writer.WriteHeader(http.StatusOK)
|
||||||
_, err := writer.Write(mapResp)
|
_, err := writer.Write(mapResp)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error().
|
logErr(err, "Failed to write response")
|
||||||
Caller().
|
|
||||||
Err(err).
|
|
||||||
Msg("Failed to write response")
|
|
||||||
}
|
}
|
||||||
// It sounds like we should update the nodes when we have received a endpoint update
|
// It sounds like we should update the nodes when we have received a endpoint update
|
||||||
// even tho the comments in the tailscale code dont explicitly say so.
|
// even tho the comments in the tailscale code dont explicitly say so.
|
||||||
updateRequestsFromNode.WithLabelValues(machine.User.Name, machine.Hostname, "endpoint-update").
|
updateRequestsFromNode.WithLabelValues(machine.User.Name, machine.Hostname, "endpoint-update").
|
||||||
Inc()
|
Inc()
|
||||||
updateChan <- struct{}{}
|
|
||||||
|
// Tell all the other nodes about the new endpoint, but dont update ourselves.
|
||||||
|
h.nodeNotifier.NotifyWithIgnore(machine.MachineKey)
|
||||||
|
|
||||||
return
|
return
|
||||||
} else if mapRequest.OmitPeers && mapRequest.Stream {
|
} else if mapRequest.OmitPeers && mapRequest.Stream {
|
||||||
@ -202,43 +177,32 @@ func (h *Headscale) handlePoll(
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Info().
|
logInfo("Sending initial map")
|
||||||
Str("handler", "PollNetMap").
|
|
||||||
Bool("noise", isNoise).
|
|
||||||
Str("machine", machine.Hostname).
|
|
||||||
Msg("Client is ready to access the tailnet")
|
|
||||||
log.Info().
|
|
||||||
Str("handler", "PollNetMap").
|
|
||||||
Bool("noise", isNoise).
|
|
||||||
Str("machine", machine.Hostname).
|
|
||||||
Msg("Sending initial map")
|
|
||||||
pollDataChan <- mapResp
|
|
||||||
|
|
||||||
log.Info().
|
// Send the client an update to make sure we send an initial mapresponse
|
||||||
Str("handler", "PollNetMap").
|
_, err = writer.Write(mapResp)
|
||||||
Bool("noise", isNoise).
|
if err != nil {
|
||||||
Str("machine", machine.Hostname).
|
logErr(err, "Could not write the map response")
|
||||||
Msg("Notifying peers")
|
|
||||||
updateRequestsFromNode.WithLabelValues(machine.User.Name, machine.Hostname, "full-update").
|
return
|
||||||
Inc()
|
}
|
||||||
updateChan <- struct{}{}
|
|
||||||
|
if flusher, ok := writer.(http.Flusher); ok {
|
||||||
|
flusher.Flush()
|
||||||
|
} else {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
h.pollNetMapStream(
|
h.pollNetMapStream(
|
||||||
writer,
|
writer,
|
||||||
ctx,
|
ctx,
|
||||||
machine,
|
machine,
|
||||||
|
mapp,
|
||||||
mapRequest,
|
mapRequest,
|
||||||
pollDataChan,
|
|
||||||
keepAliveChan,
|
|
||||||
updateChan,
|
|
||||||
isNoise,
|
isNoise,
|
||||||
)
|
)
|
||||||
|
|
||||||
log.Trace().
|
logInfo("Finished stream, closing PollNetMap session")
|
||||||
Str("handler", "PollNetMap").
|
|
||||||
Bool("noise", isNoise).
|
|
||||||
Str("machine", machine.Hostname).
|
|
||||||
Msg("Finished stream, closing PollNetMap session")
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// pollNetMapStream stream logic for /machine/map,
|
// pollNetMapStream stream logic for /machine/map,
|
||||||
@ -247,23 +211,16 @@ func (h *Headscale) pollNetMapStream(
|
|||||||
writer http.ResponseWriter,
|
writer http.ResponseWriter,
|
||||||
ctxReq context.Context,
|
ctxReq context.Context,
|
||||||
machine *types.Machine,
|
machine *types.Machine,
|
||||||
|
mapp *mapper.Mapper,
|
||||||
mapRequest tailcfg.MapRequest,
|
mapRequest tailcfg.MapRequest,
|
||||||
pollDataChan chan []byte,
|
|
||||||
keepAliveChan chan []byte,
|
|
||||||
updateChan chan struct{},
|
|
||||||
isNoise bool,
|
isNoise bool,
|
||||||
) {
|
) {
|
||||||
// TODO(kradalby): This is a stepping stone, mapper should be initiated once
|
logInfo, logErr := logPollFunc(mapRequest, machine, isNoise)
|
||||||
// per client or something similar
|
|
||||||
mapp := mapper.NewMapper(h.db,
|
keepAliveTicker := time.NewTicker(keepAliveInterval)
|
||||||
h.privateKey2019,
|
|
||||||
isNoise,
|
const chanSize = 8
|
||||||
h.DERPMap,
|
updateChan := make(chan struct{}, chanSize)
|
||||||
h.cfg.BaseDomain,
|
|
||||||
h.cfg.DNSConfig,
|
|
||||||
h.cfg.LogTail.Enabled,
|
|
||||||
h.cfg.RandomizeClientPort,
|
|
||||||
)
|
|
||||||
|
|
||||||
h.pollNetMapStreamWG.Add(1)
|
h.pollNetMapStreamWG.Add(1)
|
||||||
defer h.pollNetMapStreamWG.Done()
|
defer h.pollNetMapStreamWG.Done()
|
||||||
@ -273,447 +230,93 @@ func (h *Headscale) pollNetMapStream(
|
|||||||
ctx, cancel := context.WithCancel(ctx)
|
ctx, cancel := context.WithCancel(ctx)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
go h.scheduledPollWorker(
|
// Register the node's update channel
|
||||||
ctx,
|
h.nodeNotifier.AddNode(machine.MachineKey, updateChan)
|
||||||
updateChan,
|
defer h.nodeNotifier.RemoveNode(machine.MachineKey)
|
||||||
keepAliveChan,
|
defer closeChanWithLog(updateChan, machine.Hostname, "updateChan")
|
||||||
mapRequest,
|
|
||||||
machine,
|
|
||||||
isNoise,
|
|
||||||
)
|
|
||||||
|
|
||||||
log.Trace().
|
|
||||||
Str("handler", "pollNetMapStream").
|
|
||||||
Bool("noise", isNoise).
|
|
||||||
Str("machine", machine.Hostname).
|
|
||||||
Msg("Waiting for data to stream...")
|
|
||||||
|
|
||||||
log.Trace().
|
|
||||||
Str("handler", "pollNetMapStream").
|
|
||||||
Bool("noise", isNoise).
|
|
||||||
Str("machine", machine.Hostname).
|
|
||||||
Msgf("pollData is %#v, keepAliveChan is %#v, updateChan is %#v", pollDataChan, keepAliveChan, updateChan)
|
|
||||||
|
|
||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
case data := <-pollDataChan:
|
case <-keepAliveTicker.C:
|
||||||
log.Trace().
|
data, err := mapp.CreateKeepAliveResponse(mapRequest, machine)
|
||||||
Str("handler", "PollNetMapStream").
|
|
||||||
Bool("noise", isNoise).
|
|
||||||
Str("machine", machine.Hostname).
|
|
||||||
Str("channel", "pollData").
|
|
||||||
Int("bytes", len(data)).
|
|
||||||
Msg("Sending data received via pollData channel")
|
|
||||||
_, err := writer.Write(data)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error().
|
logErr(err, "Error generating the keep alive msg")
|
||||||
Str("handler", "PollNetMapStream").
|
|
||||||
Bool("noise", isNoise).
|
|
||||||
Str("machine", machine.Hostname).
|
|
||||||
Str("channel", "pollData").
|
|
||||||
Err(err).
|
|
||||||
Msg("Cannot write data")
|
|
||||||
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
flusher, ok := writer.(http.Flusher)
|
|
||||||
if !ok {
|
|
||||||
log.Error().
|
|
||||||
Caller().
|
|
||||||
Str("handler", "PollNetMapStream").
|
|
||||||
Bool("noise", isNoise).
|
|
||||||
Str("machine", machine.Hostname).
|
|
||||||
Str("channel", "pollData").
|
|
||||||
Msg("Cannot cast writer to http.Flusher")
|
|
||||||
} else {
|
|
||||||
flusher.Flush()
|
|
||||||
}
|
|
||||||
|
|
||||||
log.Trace().
|
|
||||||
Str("handler", "PollNetMapStream").
|
|
||||||
Bool("noise", isNoise).
|
|
||||||
Str("machine", machine.Hostname).
|
|
||||||
Str("channel", "pollData").
|
|
||||||
Int("bytes", len(data)).
|
|
||||||
Msg("Data from pollData channel written successfully")
|
|
||||||
// TODO(kradalby): Abstract away all the database calls, this can cause race conditions
|
|
||||||
// when an outdated machine object is kept alive, e.g. db is update from
|
|
||||||
// command line, but then overwritten.
|
|
||||||
err = h.db.UpdateMachineFromDatabase(machine)
|
|
||||||
if err != nil {
|
|
||||||
log.Error().
|
|
||||||
Str("handler", "PollNetMapStream").
|
|
||||||
Bool("noise", isNoise).
|
|
||||||
Str("machine", machine.Hostname).
|
|
||||||
Str("channel", "pollData").
|
|
||||||
Err(err).
|
|
||||||
Msg("Cannot update machine from database")
|
|
||||||
|
|
||||||
// client has been removed from database
|
|
||||||
// since the stream opened, terminate connection.
|
|
||||||
return
|
|
||||||
}
|
|
||||||
now := time.Now().UTC()
|
|
||||||
machine.LastSeen = &now
|
|
||||||
|
|
||||||
lastStateUpdate.WithLabelValues(machine.User.Name, machine.Hostname).
|
|
||||||
Set(float64(now.Unix()))
|
|
||||||
machine.LastSuccessfulUpdate = &now
|
|
||||||
|
|
||||||
err = h.db.TouchMachine(machine)
|
|
||||||
if err != nil {
|
|
||||||
log.Error().
|
|
||||||
Str("handler", "PollNetMapStream").
|
|
||||||
Bool("noise", isNoise).
|
|
||||||
Str("machine", machine.Hostname).
|
|
||||||
Str("channel", "pollData").
|
|
||||||
Err(err).
|
|
||||||
Msg("Cannot update machine LastSuccessfulUpdate")
|
|
||||||
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
log.Trace().
|
|
||||||
Str("handler", "PollNetMapStream").
|
|
||||||
Bool("noise", isNoise).
|
|
||||||
Str("machine", machine.Hostname).
|
|
||||||
Str("channel", "pollData").
|
|
||||||
Int("bytes", len(data)).
|
|
||||||
Msg("Machine entry in database updated successfully after sending data")
|
|
||||||
|
|
||||||
case data := <-keepAliveChan:
|
|
||||||
log.Trace().
|
|
||||||
Str("handler", "PollNetMapStream").
|
|
||||||
Str("machine", machine.Hostname).
|
|
||||||
Str("channel", "keepAlive").
|
|
||||||
Int("bytes", len(data)).
|
|
||||||
Msg("Sending keep alive message")
|
|
||||||
_, err := writer.Write(data)
|
|
||||||
if err != nil {
|
|
||||||
log.Error().
|
|
||||||
Str("handler", "PollNetMapStream").
|
|
||||||
Bool("noise", isNoise).
|
|
||||||
Str("machine", machine.Hostname).
|
|
||||||
Str("channel", "keepAlive").
|
|
||||||
Err(err).
|
|
||||||
Msg("Cannot write keep alive message")
|
|
||||||
|
|
||||||
return
|
|
||||||
}
|
|
||||||
flusher, ok := writer.(http.Flusher)
|
|
||||||
if !ok {
|
|
||||||
log.Error().
|
|
||||||
Caller().
|
|
||||||
Str("handler", "PollNetMapStream").
|
|
||||||
Bool("noise", isNoise).
|
|
||||||
Str("machine", machine.Hostname).
|
|
||||||
Str("channel", "keepAlive").
|
|
||||||
Msg("Cannot cast writer to http.Flusher")
|
|
||||||
} else {
|
|
||||||
flusher.Flush()
|
|
||||||
}
|
|
||||||
|
|
||||||
log.Trace().
|
|
||||||
Str("handler", "PollNetMapStream").
|
|
||||||
Bool("noise", isNoise).
|
|
||||||
Str("machine", machine.Hostname).
|
|
||||||
Str("channel", "keepAlive").
|
|
||||||
Int("bytes", len(data)).
|
|
||||||
Msg("Keep alive sent successfully")
|
|
||||||
// TODO(kradalby): Abstract away all the database calls, this can cause race conditions
|
|
||||||
// when an outdated machine object is kept alive, e.g. db is update from
|
|
||||||
// command line, but then overwritten.
|
|
||||||
err = h.db.UpdateMachineFromDatabase(machine)
|
|
||||||
if err != nil {
|
|
||||||
log.Error().
|
|
||||||
Str("handler", "PollNetMapStream").
|
|
||||||
Bool("noise", isNoise).
|
|
||||||
Str("machine", machine.Hostname).
|
|
||||||
Str("channel", "keepAlive").
|
|
||||||
Err(err).
|
|
||||||
Msg("Cannot update machine from database")
|
|
||||||
|
|
||||||
// client has been removed from database
|
|
||||||
// since the stream opened, terminate connection.
|
|
||||||
return
|
|
||||||
}
|
|
||||||
now := time.Now().UTC()
|
|
||||||
machine.LastSeen = &now
|
|
||||||
err = h.db.TouchMachine(machine)
|
|
||||||
if err != nil {
|
|
||||||
log.Error().
|
|
||||||
Str("handler", "PollNetMapStream").
|
|
||||||
Bool("noise", isNoise).
|
|
||||||
Str("machine", machine.Hostname).
|
|
||||||
Str("channel", "keepAlive").
|
|
||||||
Err(err).
|
|
||||||
Msg("Cannot update machine LastSeen")
|
|
||||||
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
log.Trace().
|
|
||||||
Str("handler", "PollNetMapStream").
|
|
||||||
Bool("noise", isNoise).
|
|
||||||
Str("machine", machine.Hostname).
|
|
||||||
Str("channel", "keepAlive").
|
|
||||||
Int("bytes", len(data)).
|
|
||||||
Msg("Machine updated successfully after sending keep alive")
|
|
||||||
|
|
||||||
case <-updateChan:
|
|
||||||
log.Trace().
|
|
||||||
Str("handler", "PollNetMapStream").
|
|
||||||
Bool("noise", isNoise).
|
|
||||||
Str("machine", machine.Hostname).
|
|
||||||
Str("channel", "update").
|
|
||||||
Msg("Received a request for update")
|
|
||||||
updateRequestsReceivedOnChannel.WithLabelValues(machine.User.Name, machine.Hostname).
|
|
||||||
Inc()
|
|
||||||
|
|
||||||
if h.db.IsOutdated(machine, h.getLastStateChange()) {
|
|
||||||
var lastUpdate time.Time
|
|
||||||
if machine.LastSuccessfulUpdate != nil {
|
|
||||||
lastUpdate = *machine.LastSuccessfulUpdate
|
|
||||||
}
|
|
||||||
log.Debug().
|
|
||||||
Str("handler", "PollNetMapStream").
|
|
||||||
Bool("noise", isNoise).
|
|
||||||
Str("machine", machine.Hostname).
|
|
||||||
Time("last_successful_update", lastUpdate).
|
|
||||||
Time("last_state_change", h.getLastStateChange(machine.User)).
|
|
||||||
Msgf("There has been updates since the last successful update to %s", machine.Hostname)
|
|
||||||
data, err := mapp.CreateMapResponse(mapRequest, machine, h.ACLPolicy)
|
|
||||||
if err != nil {
|
|
||||||
log.Error().
|
|
||||||
Str("handler", "PollNetMapStream").
|
|
||||||
Bool("noise", isNoise).
|
|
||||||
Str("machine", machine.Hostname).
|
|
||||||
Str("channel", "update").
|
|
||||||
Err(err).
|
|
||||||
Msg("Could not get the map update")
|
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
_, err = writer.Write(data)
|
_, err = writer.Write(data)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error().
|
logErr(err, "Cannot write keep alive message")
|
||||||
Str("handler", "PollNetMapStream").
|
|
||||||
Bool("noise", isNoise).
|
return
|
||||||
Str("machine", machine.Hostname).
|
}
|
||||||
Str("channel", "update").
|
if flusher, ok := writer.(http.Flusher); ok {
|
||||||
Err(err).
|
flusher.Flush()
|
||||||
Msg("Could not write the map response")
|
} else {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
err = h.db.TouchMachine(machine)
|
||||||
|
if err != nil {
|
||||||
|
logErr(err, "Cannot update machine LastSeen")
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
case <-updateChan:
|
||||||
|
data, err := mapp.CreateMapResponse(mapRequest, machine, h.ACLPolicy)
|
||||||
|
if err != nil {
|
||||||
|
logErr(err, "Could not get the map update")
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = writer.Write(data)
|
||||||
|
if err != nil {
|
||||||
|
logErr(err, "Could not write the map response")
|
||||||
|
|
||||||
updateRequestsSentToNode.WithLabelValues(machine.User.Name, machine.Hostname, "failed").
|
updateRequestsSentToNode.WithLabelValues(machine.User.Name, machine.Hostname, "failed").
|
||||||
Inc()
|
Inc()
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
flusher, ok := writer.(http.Flusher)
|
if flusher, ok := writer.(http.Flusher); ok {
|
||||||
if !ok {
|
|
||||||
log.Error().
|
|
||||||
Caller().
|
|
||||||
Str("handler", "PollNetMapStream").
|
|
||||||
Bool("noise", isNoise).
|
|
||||||
Str("machine", machine.Hostname).
|
|
||||||
Str("channel", "update").
|
|
||||||
Msg("Cannot cast writer to http.Flusher")
|
|
||||||
} else {
|
|
||||||
flusher.Flush()
|
flusher.Flush()
|
||||||
|
} else {
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Trace().
|
|
||||||
Str("handler", "PollNetMapStream").
|
|
||||||
Bool("noise", isNoise).
|
|
||||||
Str("machine", machine.Hostname).
|
|
||||||
Str("channel", "update").
|
|
||||||
Msg("Updated Map has been sent")
|
|
||||||
updateRequestsSentToNode.WithLabelValues(machine.User.Name, machine.Hostname, "success").
|
|
||||||
Inc()
|
|
||||||
|
|
||||||
// Keep track of the last successful update,
|
// Keep track of the last successful update,
|
||||||
// we sometimes end in a state were the update
|
// we sometimes end in a state were the update
|
||||||
// is not picked up by a client and we use this
|
// is not picked up by a client and we use this
|
||||||
// to determine if we should "force" an update.
|
// to determine if we should "force" an update.
|
||||||
// TODO(kradalby): Abstract away all the database calls, this can cause race conditions
|
|
||||||
// when an outdated machine object is kept alive, e.g. db is update from
|
|
||||||
// command line, but then overwritten.
|
|
||||||
err = h.db.UpdateMachineFromDatabase(machine)
|
|
||||||
if err != nil {
|
|
||||||
log.Error().
|
|
||||||
Str("handler", "PollNetMapStream").
|
|
||||||
Bool("noise", isNoise).
|
|
||||||
Str("machine", machine.Hostname).
|
|
||||||
Str("channel", "update").
|
|
||||||
Err(err).
|
|
||||||
Msg("Cannot update machine from database")
|
|
||||||
|
|
||||||
// client has been removed from database
|
|
||||||
// since the stream opened, terminate connection.
|
|
||||||
return
|
|
||||||
}
|
|
||||||
now := time.Now().UTC()
|
|
||||||
|
|
||||||
lastStateUpdate.WithLabelValues(machine.User.Name, machine.Hostname).
|
|
||||||
Set(float64(now.Unix()))
|
|
||||||
machine.LastSuccessfulUpdate = &now
|
|
||||||
|
|
||||||
err = h.db.TouchMachine(machine)
|
err = h.db.TouchMachine(machine)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error().
|
logErr(err, "Cannot update machine LastSuccessfulUpdate")
|
||||||
Str("handler", "PollNetMapStream").
|
|
||||||
Bool("noise", isNoise).
|
|
||||||
Str("machine", machine.Hostname).
|
|
||||||
Str("channel", "update").
|
|
||||||
Err(err).
|
|
||||||
Msg("Cannot update machine LastSuccessfulUpdate")
|
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
} else {
|
|
||||||
var lastUpdate time.Time
|
|
||||||
if machine.LastSuccessfulUpdate != nil {
|
|
||||||
lastUpdate = *machine.LastSuccessfulUpdate
|
|
||||||
}
|
|
||||||
log.Trace().
|
|
||||||
Str("handler", "PollNetMapStream").
|
|
||||||
Bool("noise", isNoise).
|
|
||||||
Str("machine", machine.Hostname).
|
|
||||||
Time("last_successful_update", lastUpdate).
|
|
||||||
Time("last_state_change", h.getLastStateChange(machine.User)).
|
|
||||||
Msgf("%s is up to date", machine.Hostname)
|
|
||||||
}
|
|
||||||
|
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
log.Info().
|
logInfo("The client has closed the connection")
|
||||||
Str("handler", "PollNetMapStream").
|
|
||||||
Str("machine", machine.Hostname).
|
|
||||||
Msg("The client has closed the connection")
|
|
||||||
// TODO: Abstract away all the database calls, this can cause race conditions
|
|
||||||
// when an outdated machine object is kept alive, e.g. db is update from
|
|
||||||
// command line, but then overwritten.
|
|
||||||
err := h.db.UpdateMachineFromDatabase(machine)
|
|
||||||
if err != nil {
|
|
||||||
log.Error().
|
|
||||||
Str("handler", "PollNetMapStream").
|
|
||||||
Bool("noise", isNoise).
|
|
||||||
Str("machine", machine.Hostname).
|
|
||||||
Str("channel", "Done").
|
|
||||||
Err(err).
|
|
||||||
Msg("Cannot update machine from database")
|
|
||||||
|
|
||||||
// client has been removed from database
|
err := h.db.TouchMachine(machine)
|
||||||
// since the stream opened, terminate connection.
|
|
||||||
return
|
|
||||||
}
|
|
||||||
now := time.Now().UTC()
|
|
||||||
machine.LastSeen = &now
|
|
||||||
err = h.db.TouchMachine(machine)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error().
|
logErr(err, "Cannot update machine LastSeen")
|
||||||
Str("handler", "PollNetMapStream").
|
|
||||||
Bool("noise", isNoise).
|
|
||||||
Str("machine", machine.Hostname).
|
|
||||||
Str("channel", "Done").
|
|
||||||
Err(err).
|
|
||||||
Msg("Cannot update machine LastSeen")
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// The connection has been closed, so we can stop polling.
|
// The connection has been closed, so we can stop polling.
|
||||||
return
|
return
|
||||||
|
|
||||||
case <-h.shutdownChan:
|
case <-h.shutdownChan:
|
||||||
log.Info().
|
logInfo("The long-poll handler is shutting down")
|
||||||
Str("handler", "PollNetMapStream").
|
|
||||||
Bool("noise", isNoise).
|
|
||||||
Str("machine", machine.Hostname).
|
|
||||||
Msg("The long-poll handler is shutting down")
|
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *Headscale) scheduledPollWorker(
|
|
||||||
ctx context.Context,
|
|
||||||
updateChan chan struct{},
|
|
||||||
keepAliveChan chan []byte,
|
|
||||||
mapRequest tailcfg.MapRequest,
|
|
||||||
machine *types.Machine,
|
|
||||||
isNoise bool,
|
|
||||||
) {
|
|
||||||
// TODO(kradalby): This is a stepping stone, mapper should be initiated once
|
|
||||||
// per client or something similar
|
|
||||||
mapp := mapper.NewMapper(h.db,
|
|
||||||
h.privateKey2019,
|
|
||||||
isNoise,
|
|
||||||
h.DERPMap,
|
|
||||||
h.cfg.BaseDomain,
|
|
||||||
h.cfg.DNSConfig,
|
|
||||||
h.cfg.LogTail.Enabled,
|
|
||||||
h.cfg.RandomizeClientPort,
|
|
||||||
)
|
|
||||||
|
|
||||||
keepAliveTicker := time.NewTicker(keepAliveInterval)
|
|
||||||
updateCheckerTicker := time.NewTicker(h.cfg.NodeUpdateCheckInterval)
|
|
||||||
|
|
||||||
defer closeChanWithLog(
|
|
||||||
updateChan,
|
|
||||||
fmt.Sprint(ctx.Value(machineNameContextKey)),
|
|
||||||
"updateChan",
|
|
||||||
)
|
|
||||||
defer closeChanWithLog(
|
|
||||||
keepAliveChan,
|
|
||||||
fmt.Sprint(ctx.Value(machineNameContextKey)),
|
|
||||||
"keepAliveChan",
|
|
||||||
)
|
|
||||||
|
|
||||||
for {
|
|
||||||
select {
|
|
||||||
case <-ctx.Done():
|
|
||||||
return
|
|
||||||
|
|
||||||
case <-keepAliveTicker.C:
|
|
||||||
data, err := mapp.CreateKeepAliveResponse(mapRequest, machine)
|
|
||||||
if err != nil {
|
|
||||||
log.Error().
|
|
||||||
Str("func", "keepAlive").
|
|
||||||
Bool("noise", isNoise).
|
|
||||||
Err(err).
|
|
||||||
Msg("Error generating the keep alive msg")
|
|
||||||
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
log.Debug().
|
|
||||||
Str("func", "keepAlive").
|
|
||||||
Str("machine", machine.Hostname).
|
|
||||||
Bool("noise", isNoise).
|
|
||||||
Msg("Sending keepalive")
|
|
||||||
select {
|
|
||||||
case keepAliveChan <- data:
|
|
||||||
case <-ctx.Done():
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
case <-updateCheckerTicker.C:
|
|
||||||
log.Debug().
|
|
||||||
Str("func", "scheduledPollWorker").
|
|
||||||
Str("machine", machine.Hostname).
|
|
||||||
Bool("noise", isNoise).
|
|
||||||
Msg("Sending update request")
|
|
||||||
updateRequestsFromNode.WithLabelValues(machine.User.Name, machine.Hostname, "scheduled-update").
|
|
||||||
Inc()
|
|
||||||
select {
|
|
||||||
case updateChan <- struct{}{}:
|
|
||||||
case <-ctx.Done():
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func closeChanWithLog[C chan []byte | chan struct{}](channel C, machine, name string) {
|
func closeChanWithLog[C chan []byte | chan struct{}](channel C, machine, name string) {
|
||||||
log.Trace().
|
log.Trace().
|
||||||
Str("handler", "PollNetMap").
|
Str("handler", "PollNetMap").
|
||||||
|
Loading…
Reference in New Issue
Block a user