Replace the timestamp-based state system

This commit replaces the timestamp-based state system with a new one
that keeps an update channel directly to each connected node. Updates
are sent to all listening clients via the polling mechanism.

It introduces a new package, notifier, which provides a
concurrency-safe manager for all the channels to the connected nodes.

Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>
Author: Kristoffer Dalby, 2023-06-21 11:29:52 +02:00 (committed by Kristoffer Dalby)
Commit: 66ff1fcd40 (parent: 056d3a81c5)
13 changed files with 216 additions and 731 deletions
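
As a rough illustration of the mechanism described in the commit message, the Go sketch below wires the new notifier package to a toy producer and consumer. The machine keys and the handler wiring are made up for the example; only the notifier API (NewNotifier, AddNode, RemoveNode, NotifyWithIgnore) comes from this commit.

package main

import (
	"fmt"
	"time"

	"github.com/juanfont/headscale/hscontrol/notifier"
)

func main() {
	n := notifier.NewNotifier()

	// Consumer side: a long-poll handler registers a buffered update channel
	// for its node and removes it when the stream ends.
	updateChan := make(chan struct{}, 8)
	n.AddNode("mkey:node-a", updateChan)
	defer n.RemoveNode("mkey:node-a")

	go func() {
		for range updateChan {
			// In headscale this is where a fresh map response would be
			// generated and written to the client.
			fmt.Println("node-a: state changed, send an updated map")
		}
	}()

	// Producer side: a state change (tags, expiry, routes, ...) notifies
	// every connected node except the one that caused the change.
	n.NotifyWithIgnore("mkey:node-b")

	time.Sleep(100 * time.Millisecond)
}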


@ -10,7 +10,6 @@ import (
"net/http" "net/http"
"os" "os"
"os/signal" "os/signal"
"sort"
"strconv" "strconv"
"strings" "strings"
"sync" "sync"
@ -26,13 +25,13 @@ import (
"github.com/juanfont/headscale/hscontrol/db" "github.com/juanfont/headscale/hscontrol/db"
"github.com/juanfont/headscale/hscontrol/derp" "github.com/juanfont/headscale/hscontrol/derp"
derpServer "github.com/juanfont/headscale/hscontrol/derp/server" derpServer "github.com/juanfont/headscale/hscontrol/derp/server"
"github.com/juanfont/headscale/hscontrol/notifier"
"github.com/juanfont/headscale/hscontrol/policy" "github.com/juanfont/headscale/hscontrol/policy"
"github.com/juanfont/headscale/hscontrol/types" "github.com/juanfont/headscale/hscontrol/types"
"github.com/juanfont/headscale/hscontrol/util" "github.com/juanfont/headscale/hscontrol/util"
"github.com/patrickmn/go-cache" "github.com/patrickmn/go-cache"
zerolog "github.com/philip-bui/grpc-zerolog" zerolog "github.com/philip-bui/grpc-zerolog"
"github.com/prometheus/client_golang/prometheus/promhttp" "github.com/prometheus/client_golang/prometheus/promhttp"
"github.com/puzpuzpuz/xsync/v2"
zl "github.com/rs/zerolog" zl "github.com/rs/zerolog"
"github.com/rs/zerolog/log" "github.com/rs/zerolog/log"
"golang.org/x/crypto/acme" "golang.org/x/crypto/acme"
@ -84,7 +83,7 @@ type Headscale struct {
ACLPolicy *policy.ACLPolicy ACLPolicy *policy.ACLPolicy
lastStateChange *xsync.MapOf[string, time.Time] nodeNotifier *notifier.Notifier
oidcProvider *oidc.Provider oidcProvider *oidc.Provider
oauth2Config *oauth2.Config oauth2Config *oauth2.Config
@ -93,9 +92,6 @@ type Headscale struct {
shutdownChan chan struct{} shutdownChan chan struct{}
pollNetMapStreamWG sync.WaitGroup pollNetMapStreamWG sync.WaitGroup
stateUpdateChan chan struct{}
cancelStateUpdateChan chan struct{}
} }
func NewHeadscale(cfg *types.Config) (*Headscale, error) { func NewHeadscale(cfg *types.Config) (*Headscale, error) {
@ -158,19 +154,14 @@ func NewHeadscale(cfg *types.Config) (*Headscale, error) {
noisePrivateKey: noisePrivateKey, noisePrivateKey: noisePrivateKey,
registrationCache: registrationCache, registrationCache: registrationCache,
pollNetMapStreamWG: sync.WaitGroup{}, pollNetMapStreamWG: sync.WaitGroup{},
lastStateChange: xsync.NewMapOf[time.Time](), nodeNotifier: notifier.NewNotifier(),
stateUpdateChan: make(chan struct{}),
cancelStateUpdateChan: make(chan struct{}),
} }
go app.watchStateChannel()
database, err := db.NewHeadscaleDatabase( database, err := db.NewHeadscaleDatabase(
cfg.DBtype, cfg.DBtype,
dbString, dbString,
app.dbDebug, app.dbDebug,
app.stateUpdateChan, app.nodeNotifier,
cfg.IPPrefixes, cfg.IPPrefixes,
cfg.BaseDomain) cfg.BaseDomain)
if err != nil { if err != nil {
@ -203,7 +194,11 @@ func NewHeadscale(cfg *types.Config) (*Headscale, error) {
if cfg.DERP.ServerEnabled { if cfg.DERP.ServerEnabled {
// TODO(kradalby): replace this key with a dedicated DERP key. // TODO(kradalby): replace this key with a dedicated DERP key.
embeddedDERPServer, err := derpServer.NewDERPServer(cfg.ServerURL, key.NodePrivate(*privateKey), &cfg.DERP) embeddedDERPServer, err := derpServer.NewDERPServer(
cfg.ServerURL,
key.NodePrivate(*privateKey),
&cfg.DERP,
)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -230,10 +225,14 @@ func (h *Headscale) expireEphemeralNodes(milliSeconds int64) {
// expireExpiredMachines expires machines that have an explicit expiry set // expireExpiredMachines expires machines that have an explicit expiry set
// after that expiry time has passed. // after that expiry time has passed.
func (h *Headscale) expireExpiredMachines(milliSeconds int64) { func (h *Headscale) expireExpiredMachines(intervalMs int64) {
ticker := time.NewTicker(time.Duration(milliSeconds) * time.Millisecond) interval := time.Duration(intervalMs) * time.Millisecond
ticker := time.NewTicker(interval)
lastCheck := time.Unix(0, 0)
for range ticker.C { for range ticker.C {
h.db.ExpireExpiredMachines(h.getLastStateChange()) lastCheck = h.db.ExpireExpiredMachines(lastCheck)
} }
} }
@ -258,7 +257,7 @@ func (h *Headscale) scheduledDERPMapUpdateWorker(cancelChan <-chan struct{}) {
h.DERPMap.Regions[region.RegionID] = &region h.DERPMap.Regions[region.RegionID] = &region
} }
h.setLastStateChangeToNow() h.nodeNotifier.NotifyAll()
} }
} }
} }
@ -722,7 +721,7 @@ func (h *Headscale) Serve() error {
Str("path", aclPath). Str("path", aclPath).
Msg("ACL policy successfully reloaded, notifying nodes of change") Msg("ACL policy successfully reloaded, notifying nodes of change")
h.setLastStateChangeToNow() h.nodeNotifier.NotifyAll()
} }
default: default:
@ -760,10 +759,6 @@ func (h *Headscale) Serve() error {
// Stop listening (and unlink the socket if unix type): // Stop listening (and unlink the socket if unix type):
socketListener.Close() socketListener.Close()
<-h.cancelStateUpdateChan
close(h.stateUpdateChan)
close(h.cancelStateUpdateChan)
// Close db connections // Close db connections
err = h.db.Close() err = h.db.Close()
if err != nil { if err != nil {
@ -859,73 +854,6 @@ func (h *Headscale) getTLSSettings() (*tls.Config, error) {
} }
} }
// TODO(kradalby): baby steps, make this more robust.
func (h *Headscale) watchStateChannel() {
for {
select {
case <-h.stateUpdateChan:
h.setLastStateChangeToNow()
case <-h.cancelStateUpdateChan:
return
}
}
}
func (h *Headscale) setLastStateChangeToNow() {
var err error
now := time.Now().UTC()
users, err := h.db.ListUsers()
if err != nil {
log.Error().
Caller().
Err(err).
Msg("failed to fetch all users, failing to update last changed state.")
}
for _, user := range users {
lastStateUpdate.WithLabelValues(user.Name, "headscale").Set(float64(now.Unix()))
if h.lastStateChange == nil {
h.lastStateChange = xsync.NewMapOf[time.Time]()
}
h.lastStateChange.Store(user.Name, now)
}
}
func (h *Headscale) getLastStateChange(users ...types.User) time.Time {
times := []time.Time{}
// getLastStateChange takes a list of users as a "filter", if no users
// are past, then use the entier list of users and look for the last update
if len(users) > 0 {
for _, user := range users {
if lastChange, ok := h.lastStateChange.Load(user.Name); ok {
times = append(times, lastChange)
}
}
} else {
h.lastStateChange.Range(func(key string, value time.Time) bool {
times = append(times, value)
return true
})
}
sort.Slice(times, func(i, j int) bool {
return times[i].After(times[j])
})
log.Trace().Msgf("Latest times %#v", times)
if len(times) == 0 {
return time.Now().UTC()
} else {
return times[0]
}
}
func notFoundHandler( func notFoundHandler(
writer http.ResponseWriter, writer http.ResponseWriter,
req *http.Request, req *http.Request,


@ -63,8 +63,6 @@ func (s *Suite) TestGetUsedIps(c *check.C) {
c.Assert(len(machine1.IPAddresses), check.Equals, 1) c.Assert(len(machine1.IPAddresses), check.Equals, 1)
c.Assert(machine1.IPAddresses[0], check.Equals, expected) c.Assert(machine1.IPAddresses[0], check.Equals, expected)
c.Assert(channelUpdates, check.Equals, int32(0))
} }
func (s *Suite) TestGetMultiIp(c *check.C) { func (s *Suite) TestGetMultiIp(c *check.C) {
@ -153,8 +151,6 @@ func (s *Suite) TestGetMultiIp(c *check.C) {
c.Assert(len(nextIP2), check.Equals, 1) c.Assert(len(nextIP2), check.Equals, 1)
c.Assert(nextIP2[0].String(), check.Equals, expectedNextIP.String()) c.Assert(nextIP2[0].String(), check.Equals, expectedNextIP.String())
c.Assert(channelUpdates, check.Equals, int32(0))
} }
func (s *Suite) TestGetAvailableIpMachineWithoutIP(c *check.C) { func (s *Suite) TestGetAvailableIpMachineWithoutIP(c *check.C) {
@ -192,6 +188,4 @@ func (s *Suite) TestGetAvailableIpMachineWithoutIP(c *check.C) {
c.Assert(len(ips2), check.Equals, 1) c.Assert(len(ips2), check.Equals, 1)
c.Assert(ips2[0].String(), check.Equals, expected.String()) c.Assert(ips2[0].String(), check.Equals, expected.String())
c.Assert(channelUpdates, check.Equals, int32(0))
} }


@ -22,8 +22,6 @@ func (*Suite) TestCreateAPIKey(c *check.C) {
keys, err := db.ListAPIKeys() keys, err := db.ListAPIKeys()
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
c.Assert(len(keys), check.Equals, 1) c.Assert(len(keys), check.Equals, 1)
c.Assert(channelUpdates, check.Equals, int32(0))
} }
func (*Suite) TestAPIKeyDoesNotExist(c *check.C) { func (*Suite) TestAPIKeyDoesNotExist(c *check.C) {
@ -41,8 +39,6 @@ func (*Suite) TestValidateAPIKeyOk(c *check.C) {
valid, err := db.ValidateAPIKey(apiKeyStr) valid, err := db.ValidateAPIKey(apiKeyStr)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
c.Assert(valid, check.Equals, true) c.Assert(valid, check.Equals, true)
c.Assert(channelUpdates, check.Equals, int32(0))
} }
func (*Suite) TestValidateAPIKeyNotOk(c *check.C) { func (*Suite) TestValidateAPIKeyNotOk(c *check.C) {
@ -71,8 +67,6 @@ func (*Suite) TestValidateAPIKeyNotOk(c *check.C) {
validWithErr, err := db.ValidateAPIKey("produceerrorkey") validWithErr, err := db.ValidateAPIKey("produceerrorkey")
c.Assert(err, check.NotNil) c.Assert(err, check.NotNil)
c.Assert(validWithErr, check.Equals, false) c.Assert(validWithErr, check.Equals, false)
c.Assert(channelUpdates, check.Equals, int32(0))
} }
func (*Suite) TestExpireAPIKey(c *check.C) { func (*Suite) TestExpireAPIKey(c *check.C) {
@ -92,6 +86,4 @@ func (*Suite) TestExpireAPIKey(c *check.C) {
notValid, err := db.ValidateAPIKey(apiKeyStr) notValid, err := db.ValidateAPIKey(apiKeyStr)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
c.Assert(notValid, check.Equals, false) c.Assert(notValid, check.Equals, false)
c.Assert(channelUpdates, check.Equals, int32(0))
} }


@ -9,6 +9,7 @@ import (
"time" "time"
"github.com/glebarez/sqlite" "github.com/glebarez/sqlite"
"github.com/juanfont/headscale/hscontrol/notifier"
"github.com/juanfont/headscale/hscontrol/types" "github.com/juanfont/headscale/hscontrol/types"
"github.com/juanfont/headscale/hscontrol/util" "github.com/juanfont/headscale/hscontrol/util"
"github.com/rs/zerolog/log" "github.com/rs/zerolog/log"
@ -36,8 +37,8 @@ type KV struct {
} }
type HSDatabase struct { type HSDatabase struct {
db *gorm.DB db *gorm.DB
notifyStateChan chan<- struct{} notifier *notifier.Notifier
ipAllocationMutex sync.Mutex ipAllocationMutex sync.Mutex
@ -50,7 +51,7 @@ type HSDatabase struct {
func NewHeadscaleDatabase( func NewHeadscaleDatabase(
dbType, connectionAddr string, dbType, connectionAddr string,
debug bool, debug bool,
notifyStateChan chan<- struct{}, notifier *notifier.Notifier,
ipPrefixes []netip.Prefix, ipPrefixes []netip.Prefix,
baseDomain string, baseDomain string,
) (*HSDatabase, error) { ) (*HSDatabase, error) {
@ -60,8 +61,8 @@ func NewHeadscaleDatabase(
} }
db := HSDatabase{ db := HSDatabase{
db: dbConn, db: dbConn,
notifyStateChan: notifyStateChan, notifier: notifier,
ipPrefixes: ipPrefixes, ipPrefixes: ipPrefixes,
baseDomain: baseDomain, baseDomain: baseDomain,
@ -297,10 +298,6 @@ func openDB(dbType, connectionAddr string, debug bool) (*gorm.DB, error) {
) )
} }
func (hsdb *HSDatabase) notifyStateChange() {
hsdb.notifyStateChan <- struct{}{}
}
// getValue returns the value for the given key in KV. // getValue returns the value for the given key in KV.
func (hsdb *HSDatabase) getValue(key string) (string, error) { func (hsdb *HSDatabase) getValue(key string) (string, error) {
var row KV var row KV


@ -218,7 +218,7 @@ func (hsdb *HSDatabase) SetTags(
} }
machine.ForcedTags = newTags machine.ForcedTags = newTags
hsdb.notifyStateChange() hsdb.notifier.NotifyWithIgnore(machine.MachineKey)
if err := hsdb.db.Save(machine).Error; err != nil { if err := hsdb.db.Save(machine).Error; err != nil {
return fmt.Errorf("failed to update tags for machine in the database: %w", err) return fmt.Errorf("failed to update tags for machine in the database: %w", err)
@ -232,7 +232,7 @@ func (hsdb *HSDatabase) ExpireMachine(machine *types.Machine) error {
now := time.Now() now := time.Now()
machine.Expiry = &now machine.Expiry = &now
hsdb.notifyStateChange() hsdb.notifier.NotifyWithIgnore(machine.MachineKey)
if err := hsdb.db.Save(machine).Error; err != nil { if err := hsdb.db.Save(machine).Error; err != nil {
return fmt.Errorf("failed to expire machine in the database: %w", err) return fmt.Errorf("failed to expire machine in the database: %w", err)
@ -259,7 +259,7 @@ func (hsdb *HSDatabase) RenameMachine(machine *types.Machine, newName string) er
} }
machine.GivenName = newName machine.GivenName = newName
hsdb.notifyStateChange() hsdb.notifier.NotifyWithIgnore(machine.MachineKey)
if err := hsdb.db.Save(machine).Error; err != nil { if err := hsdb.db.Save(machine).Error; err != nil {
return fmt.Errorf("failed to rename machine in the database: %w", err) return fmt.Errorf("failed to rename machine in the database: %w", err)
@ -275,7 +275,7 @@ func (hsdb *HSDatabase) RefreshMachine(machine *types.Machine, expiry time.Time)
machine.LastSuccessfulUpdate = &now machine.LastSuccessfulUpdate = &now
machine.Expiry = &expiry machine.Expiry = &expiry
hsdb.notifyStateChange() hsdb.notifier.NotifyWithIgnore(machine.MachineKey)
if err := hsdb.db.Save(machine).Error; err != nil { if err := hsdb.db.Save(machine).Error; err != nil {
return fmt.Errorf( return fmt.Errorf(
@ -323,32 +323,6 @@ func (hsdb *HSDatabase) HardDeleteMachine(machine *types.Machine) error {
return nil return nil
} }
func (hsdb *HSDatabase) IsOutdated(machine *types.Machine, lastChange time.Time) bool {
if err := hsdb.UpdateMachineFromDatabase(machine); err != nil {
// It does not seem meaningful to propagate this error as the end result
// will have to be that the machine has to be considered outdated.
return true
}
// Get the last update from all headscale users to compare with our nodes
// last update.
// TODO(kradalby): Only request updates from users where we can talk to nodes
// This would mostly be for a bit of performance, and can be calculated based on
// ACLs.
lastUpdate := machine.CreatedAt
if machine.LastSuccessfulUpdate != nil {
lastUpdate = *machine.LastSuccessfulUpdate
}
log.Trace().
Caller().
Str("machine", machine.Hostname).
Time("last_successful_update", lastChange).
Time("last_state_change", lastUpdate).
Msgf("Checking if %s is missing updates", machine.Hostname)
return lastUpdate.Before(lastChange)
}
func (hsdb *HSDatabase) RegisterMachineFromAuthCallback( func (hsdb *HSDatabase) RegisterMachineFromAuthCallback(
cache *cache.Cache, cache *cache.Cache,
nodeKeyStr string, nodeKeyStr string,
@ -626,7 +600,7 @@ func (hsdb *HSDatabase) enableRoutes(machine *types.Machine, routeStrs ...string
} }
} }
hsdb.notifyStateChange() hsdb.notifier.NotifyWithIgnore(machine.MachineKey)
return nil return nil
} }
@ -723,17 +697,22 @@ func (hsdb *HSDatabase) ExpireEphemeralMachines(inactivityThreshhold time.Durati
} }
if expiredFound { if expiredFound {
hsdb.notifyStateChange() hsdb.notifier.NotifyAll()
} }
} }
} }
func (hsdb *HSDatabase) ExpireExpiredMachines(lastChange time.Time) { func (hsdb *HSDatabase) ExpireExpiredMachines(lastCheck time.Time) time.Time {
// use the time of the start of the function to ensure we
// dont miss some machines by returning it _after_ we have
// checked everything.
started := time.Now()
users, err := hsdb.ListUsers() users, err := hsdb.ListUsers()
if err != nil { if err != nil {
log.Error().Err(err).Msg("Error listing users") log.Error().Err(err).Msg("Error listing users")
return return time.Unix(0, 0)
} }
for _, user := range users { for _, user := range users {
@ -744,13 +723,13 @@ func (hsdb *HSDatabase) ExpireExpiredMachines(lastChange time.Time) {
Str("user", user.Name). Str("user", user.Name).
Msg("Error listing machines in user") Msg("Error listing machines in user")
return return time.Unix(0, 0)
} }
expiredFound := false expiredFound := false
for index, machine := range machines { for index, machine := range machines {
if machine.IsExpired() && if machine.IsExpired() &&
machine.Expiry.After(lastChange) { machine.Expiry.After(lastCheck) {
expiredFound = true expiredFound = true
err := hsdb.ExpireMachine(&machines[index]) err := hsdb.ExpireMachine(&machines[index])
@ -770,7 +749,9 @@ func (hsdb *HSDatabase) ExpireExpiredMachines(lastChange time.Time) {
} }
if expiredFound { if expiredFound {
hsdb.notifyStateChange() hsdb.notifier.NotifyAll()
} }
} }
return started
} }
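
The comment in this hunk ("use the time of the start of the function...") describes a small hand-off: ExpireExpiredMachines now returns the time its scan started, and the caller feeds that back in as lastCheck on the next tick, so machines that expire while a scan is running are caught by the following scan. A stripped-down sketch of that pattern, with the database call replaced by an illustrative stand-in function:

package main

import (
	"fmt"
	"time"
)

// expireExpired stands in for db.ExpireExpiredMachines: it expires every
// machine whose expiry lies between lastCheck and now, and returns the time
// the scan started.
func expireExpired(lastCheck time.Time, expiries []time.Time) time.Time {
	started := time.Now()
	for _, expiry := range expiries {
		if expiry.Before(time.Now()) && expiry.After(lastCheck) {
			fmt.Println("expiring machine with expiry", expiry)
		}
	}
	return started
}

func main() {
	// One machine that expired a minute ago.
	expiries := []time.Time{time.Now().Add(-time.Minute)}

	lastCheck := time.Unix(0, 0)
	ticker := time.NewTicker(time.Second)
	defer ticker.Stop()

	for i := 0; i < 2; i++ {
		<-ticker.C
		// The first pass expires the machine; the second pass sees nothing
		// new, because only expiries after the previous scan's start count.
		lastCheck = expireExpired(lastCheck, expiries)
	}
}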


@ -39,8 +39,6 @@ func (s *Suite) TestGetMachine(c *check.C) {
_, err = db.GetMachine("test", "testmachine") _, err = db.GetMachine("test", "testmachine")
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
c.Assert(channelUpdates, check.Equals, int32(0))
} }
func (s *Suite) TestGetMachineByID(c *check.C) { func (s *Suite) TestGetMachineByID(c *check.C) {
@ -67,8 +65,6 @@ func (s *Suite) TestGetMachineByID(c *check.C) {
_, err = db.GetMachineByID(0) _, err = db.GetMachineByID(0)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
c.Assert(channelUpdates, check.Equals, int32(0))
} }
func (s *Suite) TestGetMachineByNodeKey(c *check.C) { func (s *Suite) TestGetMachineByNodeKey(c *check.C) {
@ -98,8 +94,6 @@ func (s *Suite) TestGetMachineByNodeKey(c *check.C) {
_, err = db.GetMachineByNodeKey(nodeKey.Public()) _, err = db.GetMachineByNodeKey(nodeKey.Public())
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
c.Assert(channelUpdates, check.Equals, int32(0))
} }
func (s *Suite) TestGetMachineByAnyNodeKey(c *check.C) { func (s *Suite) TestGetMachineByAnyNodeKey(c *check.C) {
@ -131,8 +125,6 @@ func (s *Suite) TestGetMachineByAnyNodeKey(c *check.C) {
_, err = db.GetMachineByAnyKey(machineKey.Public(), nodeKey.Public(), oldNodeKey.Public()) _, err = db.GetMachineByAnyKey(machineKey.Public(), nodeKey.Public(), oldNodeKey.Public())
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
c.Assert(channelUpdates, check.Equals, int32(0))
} }
func (s *Suite) TestDeleteMachine(c *check.C) { func (s *Suite) TestDeleteMachine(c *check.C) {
@ -155,8 +147,6 @@ func (s *Suite) TestDeleteMachine(c *check.C) {
_, err = db.GetMachine(user.Name, "testmachine") _, err = db.GetMachine(user.Name, "testmachine")
c.Assert(err, check.NotNil) c.Assert(err, check.NotNil)
c.Assert(channelUpdates, check.Equals, int32(0))
} }
func (s *Suite) TestHardDeleteMachine(c *check.C) { func (s *Suite) TestHardDeleteMachine(c *check.C) {
@ -179,8 +169,6 @@ func (s *Suite) TestHardDeleteMachine(c *check.C) {
_, err = db.GetMachine(user.Name, "testmachine3") _, err = db.GetMachine(user.Name, "testmachine3")
c.Assert(err, check.NotNil) c.Assert(err, check.NotNil)
c.Assert(channelUpdates, check.Equals, int32(0))
} }
func (s *Suite) TestListPeers(c *check.C) { func (s *Suite) TestListPeers(c *check.C) {
@ -217,8 +205,6 @@ func (s *Suite) TestListPeers(c *check.C) {
c.Assert(peersOfMachine0[0].Hostname, check.Equals, "testmachine2") c.Assert(peersOfMachine0[0].Hostname, check.Equals, "testmachine2")
c.Assert(peersOfMachine0[5].Hostname, check.Equals, "testmachine7") c.Assert(peersOfMachine0[5].Hostname, check.Equals, "testmachine7")
c.Assert(peersOfMachine0[8].Hostname, check.Equals, "testmachine10") c.Assert(peersOfMachine0[8].Hostname, check.Equals, "testmachine10")
c.Assert(channelUpdates, check.Equals, int32(0))
} }
func (s *Suite) TestGetACLFilteredPeers(c *check.C) { func (s *Suite) TestGetACLFilteredPeers(c *check.C) {
@ -312,8 +298,6 @@ func (s *Suite) TestGetACLFilteredPeers(c *check.C) {
c.Assert(peersOfAdminMachine[0].Hostname, check.Equals, "testmachine2") c.Assert(peersOfAdminMachine[0].Hostname, check.Equals, "testmachine2")
c.Assert(peersOfAdminMachine[2].Hostname, check.Equals, "testmachine4") c.Assert(peersOfAdminMachine[2].Hostname, check.Equals, "testmachine4")
c.Assert(peersOfAdminMachine[5].Hostname, check.Equals, "testmachine7") c.Assert(peersOfAdminMachine[5].Hostname, check.Equals, "testmachine7")
c.Assert(channelUpdates, check.Equals, int32(0))
} }
func (s *Suite) TestExpireMachine(c *check.C) { func (s *Suite) TestExpireMachine(c *check.C) {
@ -349,8 +333,6 @@ func (s *Suite) TestExpireMachine(c *check.C) {
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
c.Assert(machineFromDB.IsExpired(), check.Equals, true) c.Assert(machineFromDB.IsExpired(), check.Equals, true)
c.Assert(channelUpdates, check.Equals, int32(1))
} }
func (s *Suite) TestSerdeAddressStrignSlice(c *check.C) { func (s *Suite) TestSerdeAddressStrignSlice(c *check.C) {
@ -372,8 +354,6 @@ func (s *Suite) TestSerdeAddressStrignSlice(c *check.C) {
for i := range deserialized { for i := range deserialized {
c.Assert(deserialized[i], check.Equals, input[i]) c.Assert(deserialized[i], check.Equals, input[i])
} }
c.Assert(channelUpdates, check.Equals, int32(0))
} }
func (s *Suite) TestGenerateGivenName(c *check.C) { func (s *Suite) TestGenerateGivenName(c *check.C) {
@ -418,8 +398,6 @@ func (s *Suite) TestGenerateGivenName(c *check.C) {
comment = check.Commentf("Unique users, unique machines, same hostname, conflict") comment = check.Commentf("Unique users, unique machines, same hostname, conflict")
c.Assert(err, check.IsNil, comment) c.Assert(err, check.IsNil, comment)
c.Assert(givenName, check.Matches, fmt.Sprintf("^hostname-1-[a-z0-9]{%d}$", MachineGivenNameHashLength), comment) c.Assert(givenName, check.Matches, fmt.Sprintf("^hostname-1-[a-z0-9]{%d}$", MachineGivenNameHashLength), comment)
c.Assert(channelUpdates, check.Equals, int32(0))
} }
func (s *Suite) TestSetTags(c *check.C) { func (s *Suite) TestSetTags(c *check.C) {
@ -463,8 +441,6 @@ func (s *Suite) TestSetTags(c *check.C) {
check.DeepEquals, check.DeepEquals,
types.StringList([]string{"tag:bar", "tag:test", "tag:unknown"}), types.StringList([]string{"tag:bar", "tag:test", "tag:unknown"}),
) )
c.Assert(channelUpdates, check.Equals, int32(2))
} }
func TestHeadscale_generateGivenName(t *testing.T) { func TestHeadscale_generateGivenName(t *testing.T) {
@ -655,6 +631,4 @@ func (s *Suite) TestAutoApproveRoutes(c *check.C) {
enabledRoutes, err := db.GetEnabledRoutes(machine0ByID) enabledRoutes, err := db.GetEnabledRoutes(machine0ByID)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
c.Assert(enabledRoutes, check.HasLen, 4) c.Assert(enabledRoutes, check.HasLen, 4)
c.Assert(channelUpdates, check.Equals, int32(4))
} }


@ -161,8 +161,6 @@ func (*Suite) TestEphemeralKey(c *check.C) {
// The machine record should have been deleted // The machine record should have been deleted
_, err = db.GetMachine("test7", "testest") _, err = db.GetMachine("test7", "testest")
c.Assert(err, check.NotNil) c.Assert(err, check.NotNil)
c.Assert(channelUpdates, check.Equals, int32(1))
} }
func (*Suite) TestExpirePreauthKey(c *check.C) { func (*Suite) TestExpirePreauthKey(c *check.C) {


@ -374,7 +374,7 @@ func (hsdb *HSDatabase) HandlePrimarySubnetFailover() error {
} }
if routesChanged { if routesChanged {
hsdb.notifyStateChange() hsdb.notifier.NotifyAll()
} }
return nil return nil


@ -52,8 +52,6 @@ func (s *Suite) TestGetRoutes(c *check.C) {
err = db.enableRoutes(&machine, "10.0.0.0/24") err = db.enableRoutes(&machine, "10.0.0.0/24")
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
c.Assert(channelUpdates, check.Equals, int32(0))
} }
func (s *Suite) TestGetEnableRoutes(c *check.C) { func (s *Suite) TestGetEnableRoutes(c *check.C) {
@ -129,8 +127,6 @@ func (s *Suite) TestGetEnableRoutes(c *check.C) {
enabledRoutesWithAdditionalRoute, err := db.GetEnabledRoutes(&machine) enabledRoutesWithAdditionalRoute, err := db.GetEnabledRoutes(&machine)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
c.Assert(len(enabledRoutesWithAdditionalRoute), check.Equals, 2) c.Assert(len(enabledRoutesWithAdditionalRoute), check.Equals, 2)
c.Assert(channelUpdates, check.Equals, int32(3))
} }
func (s *Suite) TestIsUniquePrefix(c *check.C) { func (s *Suite) TestIsUniquePrefix(c *check.C) {
@ -215,8 +211,6 @@ func (s *Suite) TestIsUniquePrefix(c *check.C) {
routes, err = db.GetMachinePrimaryRoutes(&machine2) routes, err = db.GetMachinePrimaryRoutes(&machine2)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
c.Assert(len(routes), check.Equals, 0) c.Assert(len(routes), check.Equals, 0)
c.Assert(channelUpdates, check.Equals, int32(3))
} }
func (s *Suite) TestSubnetFailover(c *check.C) { func (s *Suite) TestSubnetFailover(c *check.C) {
@ -359,8 +353,6 @@ func (s *Suite) TestSubnetFailover(c *check.C) {
routes, err = db.GetMachinePrimaryRoutes(&machine2) routes, err = db.GetMachinePrimaryRoutes(&machine2)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
c.Assert(len(routes), check.Equals, 2) c.Assert(len(routes), check.Equals, 2)
c.Assert(channelUpdates, check.Equals, int32(6))
} }
func (s *Suite) TestDeleteRoutes(c *check.C) { func (s *Suite) TestDeleteRoutes(c *check.C) {
@ -420,6 +412,4 @@ func (s *Suite) TestDeleteRoutes(c *check.C) {
enabledRoutes1, err := db.GetEnabledRoutes(&machine1) enabledRoutes1, err := db.GetEnabledRoutes(&machine1)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
c.Assert(len(enabledRoutes1), check.Equals, 1) c.Assert(len(enabledRoutes1), check.Equals, 1)
c.Assert(channelUpdates, check.Equals, int32(2))
} }


@ -3,9 +3,9 @@ package db
import ( import (
"net/netip" "net/netip"
"os" "os"
"sync/atomic"
"testing" "testing"
"github.com/juanfont/headscale/hscontrol/notifier"
"gopkg.in/check.v1" "gopkg.in/check.v1"
) )
@ -20,14 +20,9 @@ type Suite struct{}
var ( var (
tmpDir string tmpDir string
db *HSDatabase db *HSDatabase
// channelUpdates counts the number of times
// either of the channels was notified.
channelUpdates int32
) )
func (s *Suite) SetUpTest(c *check.C) { func (s *Suite) SetUpTest(c *check.C) {
atomic.StoreInt32(&channelUpdates, 0)
s.ResetDB(c) s.ResetDB(c)
} }
@ -35,13 +30,6 @@ func (s *Suite) TearDownTest(c *check.C) {
os.RemoveAll(tmpDir) os.RemoveAll(tmpDir)
} }
func notificationSink(c <-chan struct{}) {
for {
<-c
atomic.AddInt32(&channelUpdates, 1)
}
}
func (s *Suite) ResetDB(c *check.C) { func (s *Suite) ResetDB(c *check.C) {
if len(tmpDir) != 0 { if len(tmpDir) != 0 {
os.RemoveAll(tmpDir) os.RemoveAll(tmpDir)
@ -52,15 +40,11 @@ func (s *Suite) ResetDB(c *check.C) {
c.Fatal(err) c.Fatal(err)
} }
sink := make(chan struct{})
go notificationSink(sink)
db, err = NewHeadscaleDatabase( db, err = NewHeadscaleDatabase(
"sqlite3", "sqlite3",
tmpDir+"/headscale_test.db", tmpDir+"/headscale_test.db",
false, false,
sink, notifier.NewNotifier(),
[]netip.Prefix{ []netip.Prefix{
netip.MustParsePrefix("10.27.0.0/23"), netip.MustParsePrefix("10.27.0.0/23"),
}, },


@ -10,12 +10,6 @@ const prometheusNamespace = "headscale"
var ( var (
// This is a high cardinality metric (user x machines), we might want to make this // This is a high cardinality metric (user x machines), we might want to make this
// configurable/opt-in in the future. // configurable/opt-in in the future.
lastStateUpdate = promauto.NewGaugeVec(prometheus.GaugeOpts{
Namespace: prometheusNamespace,
Name: "last_update_seconds",
Help: "Time stamp in unix time when a machine or headscale was updated",
}, []string{"user", "machine"})
machineRegistrations = promauto.NewCounterVec(prometheus.CounterOpts{ machineRegistrations = promauto.NewCounterVec(prometheus.CounterOpts{
Namespace: prometheusNamespace, Namespace: prometheusNamespace,
Name: "machine_registrations_total", Name: "machine_registrations_total",
@ -33,9 +27,4 @@ var (
Help: "The number of calls/messages issued on a specific nodes update channel", Help: "The number of calls/messages issued on a specific nodes update channel",
}, []string{"user", "machine", "status"}) }, []string{"user", "machine", "status"})
// TODO(kradalby): This is very debugging, we might want to remove it. // TODO(kradalby): This is very debugging, we might want to remove it.
updateRequestsReceivedOnChannel = promauto.NewCounterVec(prometheus.CounterOpts{
Namespace: prometheusNamespace,
Name: "update_request_received_on_channel_total",
Help: "The number of update requests received on an update channel",
}, []string{"user", "machine"})
) )


@ -0,0 +1,55 @@
package notifier
import (
"sync"
"github.com/juanfont/headscale/hscontrol/util"
)
type Notifier struct {
l sync.RWMutex
nodes map[string]chan<- struct{}
}
func NewNotifier() *Notifier {
return &Notifier{}
}
func (n *Notifier) AddNode(machineKey string, c chan<- struct{}) {
n.l.Lock()
defer n.l.Unlock()
if n.nodes == nil {
n.nodes = make(map[string]chan<- struct{})
}
n.nodes[machineKey] = c
}
func (n *Notifier) RemoveNode(machineKey string) {
n.l.Lock()
defer n.l.Unlock()
if n.nodes == nil {
return
}
delete(n.nodes, machineKey)
}
func (n *Notifier) NotifyAll() {
n.NotifyWithIgnore()
}
func (n *Notifier) NotifyWithIgnore(ignore ...string) {
n.l.RLock()
defer n.l.RUnlock()
for key, c := range n.nodes {
if util.IsStringInSlice(ignore, key) {
continue
}
c <- struct{}{}
}
}
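
For context, here is a simplified sketch of how a long-poll stream consumes one of these channels, loosely mirroring the pollNetMapStream changes in the next file. The mapper and database calls are stubbed out as closures, and the ticker interval is arbitrary; only the notifier calls and the buffered channel come from this commit.

package example

import (
	"context"
	"net/http"
	"time"

	"github.com/juanfont/headscale/hscontrol/notifier"
)

// streamUpdates registers an update channel for one node and writes a fresh
// map response whenever the notifier signals a state change.
func streamUpdates(
	ctx context.Context,
	w http.ResponseWriter,
	machineKey string,
	n *notifier.Notifier,
	buildMap func() ([]byte, error),       // stand-in for the mapper's map response
	buildKeepAlive func() ([]byte, error), // stand-in for the keep-alive response
) {
	updateChan := make(chan struct{}, 8)
	n.AddNode(machineKey, updateChan)
	defer n.RemoveNode(machineKey)

	keepAliveTicker := time.NewTicker(60 * time.Second)
	defer keepAliveTicker.Stop()

	for {
		select {
		case <-keepAliveTicker.C:
			data, err := buildKeepAlive()
			if err != nil {
				return
			}
			if _, err := w.Write(data); err != nil {
				return
			}

		case <-updateChan:
			// Another node changed state: push an updated map to this client.
			data, err := buildMap()
			if err != nil {
				return
			}
			if _, err := w.Write(data); err != nil {
				return
			}

		case <-ctx.Done():
			// The client closed the connection; stop streaming.
			return
		}

		if f, ok := w.(http.Flusher); ok {
			f.Flush()
		}
	}
}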


@ -21,6 +21,38 @@ type contextKey string
const machineNameContextKey = contextKey("machineName") const machineNameContextKey = contextKey("machineName")
type UpdateNode func()
func logPollFunc(
mapRequest tailcfg.MapRequest,
machine *types.Machine,
isNoise bool,
) (func(string), func(error, string)) {
return func(msg string) {
log.Info().
Caller().
Bool("noise", isNoise).
Bool("readOnly", mapRequest.ReadOnly).
Bool("omitPeers", mapRequest.OmitPeers).
Bool("stream", mapRequest.Stream).
Str("node_key", machine.NodeKey).
Str("machine", machine.Hostname).
Msg(msg)
},
func(err error, msg string) {
log.Error().
Caller().
Bool("noise", isNoise).
Bool("readOnly", mapRequest.ReadOnly).
Bool("omitPeers", mapRequest.OmitPeers).
Bool("stream", mapRequest.Stream).
Str("node_key", machine.NodeKey).
Str("machine", machine.Hostname).
Err(err).
Msg(msg)
}
}
// handlePoll is the common code for the legacy and Noise protocols to // handlePoll is the common code for the legacy and Noise protocols to
// managed the poll loop. // managed the poll loop.
func (h *Headscale) handlePoll( func (h *Headscale) handlePoll(
@ -30,6 +62,10 @@ func (h *Headscale) handlePoll(
mapRequest tailcfg.MapRequest, mapRequest tailcfg.MapRequest,
isNoise bool, isNoise bool,
) { ) {
logInfo, logErr := logPollFunc(mapRequest, machine, isNoise)
// TODO(kradalby): This is a stepping stone, mapper should be initiated once
// per client or something similar
mapp := mapper.NewMapper( mapp := mapper.NewMapper(
h.db, h.db,
h.privateKey2019, h.privateKey2019,
@ -48,11 +84,7 @@ func (h *Headscale) handlePoll(
err := h.db.ProcessMachineRoutes(machine) err := h.db.ProcessMachineRoutes(machine)
if err != nil { if err != nil {
log.Error(). logErr(err, "Error processing machine routes")
Caller().
Err(err).
Str("machine", machine.Hostname).
Msg("Error processing machine routes")
} }
// update ACLRules with peer informations (to update server tags if necessary) // update ACLRules with peer informations (to update server tags if necessary)
@ -60,12 +92,7 @@ func (h *Headscale) handlePoll(
// update routes with peer information // update routes with peer information
err = h.db.EnableAutoApprovedRoutes(h.ACLPolicy, machine) err = h.db.EnableAutoApprovedRoutes(h.ACLPolicy, machine)
if err != nil { if err != nil {
log.Error(). logErr(err, "Error running auto approved routes")
Caller().
Bool("noise", isNoise).
Str("machine", machine.Hostname).
Err(err).
Msg("Error running auto approved routes")
} }
} }
@ -83,13 +110,7 @@ func (h *Headscale) handlePoll(
} }
if err := h.db.MachineSave(machine); err != nil { if err := h.db.MachineSave(machine); err != nil {
log.Error(). logErr(err, "Failed to persist/update machine in the database")
Str("handler", "PollNetMap").
Bool("noise", isNoise).
Str("node_key", machine.NodeKey).
Str("machine", machine.Hostname).
Err(err).
Msg("Failed to persist/update machine in the database")
http.Error(writer, "", http.StatusInternalServerError) http.Error(writer, "", http.StatusInternalServerError)
return return
@ -97,13 +118,7 @@ func (h *Headscale) handlePoll(
mapResp, err := mapp.CreateMapResponse(mapRequest, machine, h.ACLPolicy) mapResp, err := mapp.CreateMapResponse(mapRequest, machine, h.ACLPolicy)
if err != nil { if err != nil {
log.Error(). logErr(err, "Failed to create MapResponse")
Str("handler", "PollNetMap").
Bool("noise", isNoise).
Str("node_key", machine.NodeKey).
Str("machine", machine.Hostname).
Err(err).
Msg("Failed to get Map response")
http.Error(writer, "", http.StatusInternalServerError) http.Error(writer, "", http.StatusInternalServerError)
return return
@ -114,30 +129,16 @@ func (h *Headscale) handlePoll(
// empty endpoints to peers) // empty endpoints to peers)
// Details on the protocol can be found in https://github.com/tailscale/tailscale/blob/main/tailcfg/tailcfg.go#L696 // Details on the protocol can be found in https://github.com/tailscale/tailscale/blob/main/tailcfg/tailcfg.go#L696
log.Debug(). logInfo("Client map request processed")
Str("handler", "PollNetMap").
Bool("noise", isNoise).
Str("machine", machine.Hostname).
Bool("readOnly", mapRequest.ReadOnly).
Bool("omitPeers", mapRequest.OmitPeers).
Bool("stream", mapRequest.Stream).
Msg("Client map request processed")
if mapRequest.ReadOnly { if mapRequest.ReadOnly {
log.Info(). logInfo("Client is starting up. Probably interested in a DERP map")
Str("handler", "PollNetMap").
Bool("noise", isNoise).
Str("machine", machine.Hostname).
Msg("Client is starting up. Probably interested in a DERP map")
writer.Header().Set("Content-Type", "application/json; charset=utf-8") writer.Header().Set("Content-Type", "application/json; charset=utf-8")
writer.WriteHeader(http.StatusOK) writer.WriteHeader(http.StatusOK)
_, err := writer.Write(mapResp) _, err := writer.Write(mapResp)
if err != nil { if err != nil {
log.Error(). logErr(err, "Failed to write response")
Caller().
Err(err).
Msg("Failed to write response")
} }
if f, ok := writer.(http.Flusher); ok { if f, ok := writer.(http.Flusher); ok {
@ -147,48 +148,22 @@ func (h *Headscale) handlePoll(
return return
} }
// There has been an update to _any_ of the nodes that the other nodes would
// need to know about
h.setLastStateChangeToNow()
// The request is not ReadOnly, so we need to set up channels for updating
// peers via longpoll
// Only create update channel if it has not been created
log.Trace().
Caller().
Bool("noise", isNoise).
Str("machine", machine.Hostname).
Msg("Loading or creating update channel")
const chanSize = 8
updateChan := make(chan struct{}, chanSize)
pollDataChan := make(chan []byte, chanSize)
defer closeChanWithLog(pollDataChan, machine.Hostname, "pollDataChan")
keepAliveChan := make(chan []byte)
if mapRequest.OmitPeers && !mapRequest.Stream { if mapRequest.OmitPeers && !mapRequest.Stream {
log.Info(). logInfo("Client sent endpoint update and is ok with a response without peer list")
Str("handler", "PollNetMap").
Bool("noise", isNoise).
Str("machine", machine.Hostname).
Msg("Client sent endpoint update and is ok with a response without peer list")
writer.Header().Set("Content-Type", "application/json; charset=utf-8") writer.Header().Set("Content-Type", "application/json; charset=utf-8")
writer.WriteHeader(http.StatusOK) writer.WriteHeader(http.StatusOK)
_, err := writer.Write(mapResp) _, err := writer.Write(mapResp)
if err != nil { if err != nil {
log.Error(). logErr(err, "Failed to write response")
Caller().
Err(err).
Msg("Failed to write response")
} }
// It sounds like we should update the nodes when we have received a endpoint update // It sounds like we should update the nodes when we have received a endpoint update
// even tho the comments in the tailscale code dont explicitly say so. // even tho the comments in the tailscale code dont explicitly say so.
updateRequestsFromNode.WithLabelValues(machine.User.Name, machine.Hostname, "endpoint-update"). updateRequestsFromNode.WithLabelValues(machine.User.Name, machine.Hostname, "endpoint-update").
Inc() Inc()
updateChan <- struct{}{}
// Tell all the other nodes about the new endpoint, but dont update ourselves.
h.nodeNotifier.NotifyWithIgnore(machine.MachineKey)
return return
} else if mapRequest.OmitPeers && mapRequest.Stream { } else if mapRequest.OmitPeers && mapRequest.Stream {
@ -202,43 +177,32 @@ func (h *Headscale) handlePoll(
return return
} }
log.Info(). logInfo("Sending initial map")
Str("handler", "PollNetMap").
Bool("noise", isNoise).
Str("machine", machine.Hostname).
Msg("Client is ready to access the tailnet")
log.Info().
Str("handler", "PollNetMap").
Bool("noise", isNoise).
Str("machine", machine.Hostname).
Msg("Sending initial map")
pollDataChan <- mapResp
log.Info(). // Send the client an update to make sure we send an initial mapresponse
Str("handler", "PollNetMap"). _, err = writer.Write(mapResp)
Bool("noise", isNoise). if err != nil {
Str("machine", machine.Hostname). logErr(err, "Could not write the map response")
Msg("Notifying peers")
updateRequestsFromNode.WithLabelValues(machine.User.Name, machine.Hostname, "full-update"). return
Inc() }
updateChan <- struct{}{}
if flusher, ok := writer.(http.Flusher); ok {
flusher.Flush()
} else {
return
}
h.pollNetMapStream( h.pollNetMapStream(
writer, writer,
ctx, ctx,
machine, machine,
mapp,
mapRequest, mapRequest,
pollDataChan,
keepAliveChan,
updateChan,
isNoise, isNoise,
) )
log.Trace(). logInfo("Finished stream, closing PollNetMap session")
Str("handler", "PollNetMap").
Bool("noise", isNoise).
Str("machine", machine.Hostname).
Msg("Finished stream, closing PollNetMap session")
} }
// pollNetMapStream stream logic for /machine/map, // pollNetMapStream stream logic for /machine/map,
@ -247,23 +211,16 @@ func (h *Headscale) pollNetMapStream(
writer http.ResponseWriter, writer http.ResponseWriter,
ctxReq context.Context, ctxReq context.Context,
machine *types.Machine, machine *types.Machine,
mapp *mapper.Mapper,
mapRequest tailcfg.MapRequest, mapRequest tailcfg.MapRequest,
pollDataChan chan []byte,
keepAliveChan chan []byte,
updateChan chan struct{},
isNoise bool, isNoise bool,
) { ) {
// TODO(kradalby): This is a stepping stone, mapper should be initiated once logInfo, logErr := logPollFunc(mapRequest, machine, isNoise)
// per client or something similar
mapp := mapper.NewMapper(h.db, keepAliveTicker := time.NewTicker(keepAliveInterval)
h.privateKey2019,
isNoise, const chanSize = 8
h.DERPMap, updateChan := make(chan struct{}, chanSize)
h.cfg.BaseDomain,
h.cfg.DNSConfig,
h.cfg.LogTail.Enabled,
h.cfg.RandomizeClientPort,
)
h.pollNetMapStreamWG.Add(1) h.pollNetMapStreamWG.Add(1)
defer h.pollNetMapStreamWG.Done() defer h.pollNetMapStreamWG.Done()
@ -273,447 +230,93 @@ func (h *Headscale) pollNetMapStream(
ctx, cancel := context.WithCancel(ctx) ctx, cancel := context.WithCancel(ctx)
defer cancel() defer cancel()
go h.scheduledPollWorker( // Register the node's update channel
ctx, h.nodeNotifier.AddNode(machine.MachineKey, updateChan)
updateChan, defer h.nodeNotifier.RemoveNode(machine.MachineKey)
keepAliveChan, defer closeChanWithLog(updateChan, machine.Hostname, "updateChan")
mapRequest,
machine,
isNoise,
)
log.Trace().
Str("handler", "pollNetMapStream").
Bool("noise", isNoise).
Str("machine", machine.Hostname).
Msg("Waiting for data to stream...")
log.Trace().
Str("handler", "pollNetMapStream").
Bool("noise", isNoise).
Str("machine", machine.Hostname).
Msgf("pollData is %#v, keepAliveChan is %#v, updateChan is %#v", pollDataChan, keepAliveChan, updateChan)
for { for {
select { select {
case data := <-pollDataChan: case <-keepAliveTicker.C:
log.Trace(). data, err := mapp.CreateKeepAliveResponse(mapRequest, machine)
Str("handler", "PollNetMapStream").
Bool("noise", isNoise).
Str("machine", machine.Hostname).
Str("channel", "pollData").
Int("bytes", len(data)).
Msg("Sending data received via pollData channel")
_, err := writer.Write(data)
if err != nil { if err != nil {
log.Error(). logErr(err, "Error generating the keep alive msg")
Str("handler", "PollNetMapStream").
Bool("noise", isNoise).
Str("machine", machine.Hostname).
Str("channel", "pollData").
Err(err).
Msg("Cannot write data")
return return
} }
_, err = writer.Write(data)
if err != nil {
logErr(err, "Cannot write keep alive message")
flusher, ok := writer.(http.Flusher) return
if !ok { }
log.Error(). if flusher, ok := writer.(http.Flusher); ok {
Caller().
Str("handler", "PollNetMapStream").
Bool("noise", isNoise).
Str("machine", machine.Hostname).
Str("channel", "pollData").
Msg("Cannot cast writer to http.Flusher")
} else {
flusher.Flush() flusher.Flush()
} } else {
log.Trace().
Str("handler", "PollNetMapStream").
Bool("noise", isNoise).
Str("machine", machine.Hostname).
Str("channel", "pollData").
Int("bytes", len(data)).
Msg("Data from pollData channel written successfully")
// TODO(kradalby): Abstract away all the database calls, this can cause race conditions
// when an outdated machine object is kept alive, e.g. db is update from
// command line, but then overwritten.
err = h.db.UpdateMachineFromDatabase(machine)
if err != nil {
log.Error().
Str("handler", "PollNetMapStream").
Bool("noise", isNoise).
Str("machine", machine.Hostname).
Str("channel", "pollData").
Err(err).
Msg("Cannot update machine from database")
// client has been removed from database
// since the stream opened, terminate connection.
return return
} }
now := time.Now().UTC()
machine.LastSeen = &now
lastStateUpdate.WithLabelValues(machine.User.Name, machine.Hostname).
Set(float64(now.Unix()))
machine.LastSuccessfulUpdate = &now
err = h.db.TouchMachine(machine) err = h.db.TouchMachine(machine)
if err != nil { if err != nil {
log.Error(). logErr(err, "Cannot update machine LastSeen")
Str("handler", "PollNetMapStream").
Bool("noise", isNoise).
Str("machine", machine.Hostname).
Str("channel", "pollData").
Err(err).
Msg("Cannot update machine LastSuccessfulUpdate")
return return
} }
log.Trace().
Str("handler", "PollNetMapStream").
Bool("noise", isNoise).
Str("machine", machine.Hostname).
Str("channel", "pollData").
Int("bytes", len(data)).
Msg("Machine entry in database updated successfully after sending data")
case data := <-keepAliveChan:
log.Trace().
Str("handler", "PollNetMapStream").
Str("machine", machine.Hostname).
Str("channel", "keepAlive").
Int("bytes", len(data)).
Msg("Sending keep alive message")
_, err := writer.Write(data)
if err != nil {
log.Error().
Str("handler", "PollNetMapStream").
Bool("noise", isNoise).
Str("machine", machine.Hostname).
Str("channel", "keepAlive").
Err(err).
Msg("Cannot write keep alive message")
return
}
flusher, ok := writer.(http.Flusher)
if !ok {
log.Error().
Caller().
Str("handler", "PollNetMapStream").
Bool("noise", isNoise).
Str("machine", machine.Hostname).
Str("channel", "keepAlive").
Msg("Cannot cast writer to http.Flusher")
} else {
flusher.Flush()
}
log.Trace().
Str("handler", "PollNetMapStream").
Bool("noise", isNoise).
Str("machine", machine.Hostname).
Str("channel", "keepAlive").
Int("bytes", len(data)).
Msg("Keep alive sent successfully")
// TODO(kradalby): Abstract away all the database calls, this can cause race conditions
// when an outdated machine object is kept alive, e.g. db is update from
// command line, but then overwritten.
err = h.db.UpdateMachineFromDatabase(machine)
if err != nil {
log.Error().
Str("handler", "PollNetMapStream").
Bool("noise", isNoise).
Str("machine", machine.Hostname).
Str("channel", "keepAlive").
Err(err).
Msg("Cannot update machine from database")
// client has been removed from database
// since the stream opened, terminate connection.
return
}
now := time.Now().UTC()
machine.LastSeen = &now
err = h.db.TouchMachine(machine)
if err != nil {
log.Error().
Str("handler", "PollNetMapStream").
Bool("noise", isNoise).
Str("machine", machine.Hostname).
Str("channel", "keepAlive").
Err(err).
Msg("Cannot update machine LastSeen")
return
}
log.Trace().
Str("handler", "PollNetMapStream").
Bool("noise", isNoise).
Str("machine", machine.Hostname).
Str("channel", "keepAlive").
Int("bytes", len(data)).
Msg("Machine updated successfully after sending keep alive")
case <-updateChan: case <-updateChan:
log.Trace(). data, err := mapp.CreateMapResponse(mapRequest, machine, h.ACLPolicy)
Str("handler", "PollNetMapStream"). if err != nil {
Bool("noise", isNoise). logErr(err, "Could not get the map update")
Str("machine", machine.Hostname).
Str("channel", "update").
Msg("Received a request for update")
updateRequestsReceivedOnChannel.WithLabelValues(machine.User.Name, machine.Hostname).
Inc()
if h.db.IsOutdated(machine, h.getLastStateChange()) { return
var lastUpdate time.Time }
if machine.LastSuccessfulUpdate != nil {
lastUpdate = *machine.LastSuccessfulUpdate
}
log.Debug().
Str("handler", "PollNetMapStream").
Bool("noise", isNoise).
Str("machine", machine.Hostname).
Time("last_successful_update", lastUpdate).
Time("last_state_change", h.getLastStateChange(machine.User)).
Msgf("There has been updates since the last successful update to %s", machine.Hostname)
data, err := mapp.CreateMapResponse(mapRequest, machine, h.ACLPolicy)
if err != nil {
log.Error().
Str("handler", "PollNetMapStream").
Bool("noise", isNoise).
Str("machine", machine.Hostname).
Str("channel", "update").
Err(err).
Msg("Could not get the map update")
return _, err = writer.Write(data)
} if err != nil {
_, err = writer.Write(data) logErr(err, "Could not write the map response")
if err != nil {
log.Error().
Str("handler", "PollNetMapStream").
Bool("noise", isNoise).
Str("machine", machine.Hostname).
Str("channel", "update").
Err(err).
Msg("Could not write the map response")
updateRequestsSentToNode.WithLabelValues(machine.User.Name, machine.Hostname, "failed").
Inc()
return updateRequestsSentToNode.WithLabelValues(machine.User.Name, machine.Hostname, "failed").
}
flusher, ok := writer.(http.Flusher)
if !ok {
log.Error().
Caller().
Str("handler", "PollNetMapStream").
Bool("noise", isNoise).
Str("machine", machine.Hostname).
Str("channel", "update").
Msg("Cannot cast writer to http.Flusher")
} else {
flusher.Flush()
}
log.Trace().
Str("handler", "PollNetMapStream").
Bool("noise", isNoise).
Str("machine", machine.Hostname).
Str("channel", "update").
Msg("Updated Map has been sent")
updateRequestsSentToNode.WithLabelValues(machine.User.Name, machine.Hostname, "success").
Inc() Inc()
// Keep track of the last successful update, return
// we sometimes end in a state were the update }
// is not picked up by a client and we use this
// to determine if we should "force" an update.
// TODO(kradalby): Abstract away all the database calls, this can cause race conditions
// when an outdated machine object is kept alive, e.g. db is update from
// command line, but then overwritten.
err = h.db.UpdateMachineFromDatabase(machine)
if err != nil {
log.Error().
Str("handler", "PollNetMapStream").
Bool("noise", isNoise).
Str("machine", machine.Hostname).
Str("channel", "update").
Err(err).
Msg("Cannot update machine from database")
// client has been removed from database if flusher, ok := writer.(http.Flusher); ok {
// since the stream opened, terminate connection. flusher.Flush()
return
}
now := time.Now().UTC()
lastStateUpdate.WithLabelValues(machine.User.Name, machine.Hostname).
Set(float64(now.Unix()))
machine.LastSuccessfulUpdate = &now
err = h.db.TouchMachine(machine)
if err != nil {
log.Error().
Str("handler", "PollNetMapStream").
Bool("noise", isNoise).
Str("machine", machine.Hostname).
Str("channel", "update").
Err(err).
Msg("Cannot update machine LastSuccessfulUpdate")
return
}
} else { } else {
var lastUpdate time.Time return
if machine.LastSuccessfulUpdate != nil { }
lastUpdate = *machine.LastSuccessfulUpdate
} // Keep track of the last successful update,
log.Trace(). // we sometimes end in a state were the update
Str("handler", "PollNetMapStream"). // is not picked up by a client and we use this
Bool("noise", isNoise). // to determine if we should "force" an update.
Str("machine", machine.Hostname). err = h.db.TouchMachine(machine)
Time("last_successful_update", lastUpdate). if err != nil {
Time("last_state_change", h.getLastStateChange(machine.User)). logErr(err, "Cannot update machine LastSuccessfulUpdate")
Msgf("%s is up to date", machine.Hostname)
return
} }
case <-ctx.Done(): case <-ctx.Done():
log.Info(). logInfo("The client has closed the connection")
Str("handler", "PollNetMapStream").
Str("machine", machine.Hostname).
Msg("The client has closed the connection")
// TODO: Abstract away all the database calls, this can cause race conditions
// when an outdated machine object is kept alive, e.g. db is update from
// command line, but then overwritten.
err := h.db.UpdateMachineFromDatabase(machine)
if err != nil {
log.Error().
Str("handler", "PollNetMapStream").
Bool("noise", isNoise).
Str("machine", machine.Hostname).
Str("channel", "Done").
Err(err).
Msg("Cannot update machine from database")
// client has been removed from database err := h.db.TouchMachine(machine)
// since the stream opened, terminate connection.
return
}
now := time.Now().UTC()
machine.LastSeen = &now
err = h.db.TouchMachine(machine)
if err != nil { if err != nil {
log.Error(). logErr(err, "Cannot update machine LastSeen")
Str("handler", "PollNetMapStream").
Bool("noise", isNoise).
Str("machine", machine.Hostname).
Str("channel", "Done").
Err(err).
Msg("Cannot update machine LastSeen")
} }
// The connection has been closed, so we can stop polling. // The connection has been closed, so we can stop polling.
return return
case <-h.shutdownChan: case <-h.shutdownChan:
log.Info(). logInfo("The long-poll handler is shutting down")
Str("handler", "PollNetMapStream").
Bool("noise", isNoise).
Str("machine", machine.Hostname).
Msg("The long-poll handler is shutting down")
return return
} }
} }
} }
func (h *Headscale) scheduledPollWorker(
ctx context.Context,
updateChan chan struct{},
keepAliveChan chan []byte,
mapRequest tailcfg.MapRequest,
machine *types.Machine,
isNoise bool,
) {
// TODO(kradalby): This is a stepping stone, mapper should be initiated once
// per client or something similar
mapp := mapper.NewMapper(h.db,
h.privateKey2019,
isNoise,
h.DERPMap,
h.cfg.BaseDomain,
h.cfg.DNSConfig,
h.cfg.LogTail.Enabled,
h.cfg.RandomizeClientPort,
)
keepAliveTicker := time.NewTicker(keepAliveInterval)
updateCheckerTicker := time.NewTicker(h.cfg.NodeUpdateCheckInterval)
defer closeChanWithLog(
updateChan,
fmt.Sprint(ctx.Value(machineNameContextKey)),
"updateChan",
)
defer closeChanWithLog(
keepAliveChan,
fmt.Sprint(ctx.Value(machineNameContextKey)),
"keepAliveChan",
)
for {
select {
case <-ctx.Done():
return
case <-keepAliveTicker.C:
data, err := mapp.CreateKeepAliveResponse(mapRequest, machine)
if err != nil {
log.Error().
Str("func", "keepAlive").
Bool("noise", isNoise).
Err(err).
Msg("Error generating the keep alive msg")
return
}
log.Debug().
Str("func", "keepAlive").
Str("machine", machine.Hostname).
Bool("noise", isNoise).
Msg("Sending keepalive")
select {
case keepAliveChan <- data:
case <-ctx.Done():
return
}
case <-updateCheckerTicker.C:
log.Debug().
Str("func", "scheduledPollWorker").
Str("machine", machine.Hostname).
Bool("noise", isNoise).
Msg("Sending update request")
updateRequestsFromNode.WithLabelValues(machine.User.Name, machine.Hostname, "scheduled-update").
Inc()
select {
case updateChan <- struct{}{}:
case <-ctx.Done():
return
}
}
}
}
func closeChanWithLog[C chan []byte | chan struct{}](channel C, machine, name string) { func closeChanWithLog[C chan []byte | chan struct{}](channel C, machine, name string) {
log.Trace(). log.Trace().
Str("handler", "PollNetMap"). Str("handler", "PollNetMap").