mirror of
https://github.com/juanfont/headscale.git
synced 2025-08-11 12:37:42 +00:00
mapper: produce map before poll (#2628)
This commit is contained in:
130
hscontrol/app.go
130
hscontrol/app.go
@@ -28,14 +28,15 @@ import (
|
||||
derpServer "github.com/juanfont/headscale/hscontrol/derp/server"
|
||||
"github.com/juanfont/headscale/hscontrol/dns"
|
||||
"github.com/juanfont/headscale/hscontrol/mapper"
|
||||
"github.com/juanfont/headscale/hscontrol/notifier"
|
||||
"github.com/juanfont/headscale/hscontrol/state"
|
||||
"github.com/juanfont/headscale/hscontrol/types"
|
||||
"github.com/juanfont/headscale/hscontrol/types/change"
|
||||
"github.com/juanfont/headscale/hscontrol/util"
|
||||
zerolog "github.com/philip-bui/grpc-zerolog"
|
||||
"github.com/pkg/profile"
|
||||
zl "github.com/rs/zerolog"
|
||||
"github.com/rs/zerolog/log"
|
||||
"github.com/sasha-s/go-deadlock"
|
||||
"golang.org/x/crypto/acme"
|
||||
"golang.org/x/crypto/acme/autocert"
|
||||
"golang.org/x/sync/errgroup"
|
||||
@@ -64,6 +65,19 @@ var (
|
||||
)
|
||||
)
|
||||
|
||||
var (
|
||||
debugDeadlock = envknob.Bool("HEADSCALE_DEBUG_DEADLOCK")
|
||||
debugDeadlockTimeout = envknob.RegisterDuration("HEADSCALE_DEBUG_DEADLOCK_TIMEOUT")
|
||||
)
|
||||
|
||||
func init() {
|
||||
deadlock.Opts.Disable = !debugDeadlock
|
||||
if debugDeadlock {
|
||||
deadlock.Opts.DeadlockTimeout = debugDeadlockTimeout()
|
||||
deadlock.Opts.PrintAllCurrentGoroutines = true
|
||||
}
|
||||
}
|
||||
|
||||
const (
|
||||
AuthPrefix = "Bearer "
|
||||
updateInterval = 5 * time.Second
|
||||
@@ -82,9 +96,8 @@ type Headscale struct {
|
||||
|
||||
// Things that generate changes
|
||||
extraRecordMan *dns.ExtraRecordsMan
|
||||
mapper *mapper.Mapper
|
||||
nodeNotifier *notifier.Notifier
|
||||
authProvider AuthProvider
|
||||
mapBatcher mapper.Batcher
|
||||
|
||||
pollNetMapStreamWG sync.WaitGroup
|
||||
}
|
||||
@@ -118,7 +131,6 @@ func NewHeadscale(cfg *types.Config) (*Headscale, error) {
|
||||
cfg: cfg,
|
||||
noisePrivateKey: noisePrivateKey,
|
||||
pollNetMapStreamWG: sync.WaitGroup{},
|
||||
nodeNotifier: notifier.NewNotifier(cfg),
|
||||
state: s,
|
||||
}
|
||||
|
||||
@@ -136,12 +148,7 @@ func NewHeadscale(cfg *types.Config) (*Headscale, error) {
|
||||
return
|
||||
}
|
||||
|
||||
// Send policy update notifications if needed
|
||||
if policyChanged {
|
||||
ctx := types.NotifyCtx(context.Background(), "ephemeral-gc-policy", node.Hostname)
|
||||
app.nodeNotifier.NotifyAll(ctx, types.UpdateFull())
|
||||
}
|
||||
|
||||
app.Change(policyChanged)
|
||||
log.Debug().Uint64("node.id", ni.Uint64()).Msgf("deleted ephemeral node")
|
||||
})
|
||||
app.ephemeralGC = ephemeralGC
|
||||
@@ -153,10 +160,9 @@ func NewHeadscale(cfg *types.Config) (*Headscale, error) {
|
||||
defer cancel()
|
||||
oidcProvider, err := NewAuthProviderOIDC(
|
||||
ctx,
|
||||
&app,
|
||||
cfg.ServerURL,
|
||||
&cfg.OIDC,
|
||||
app.state,
|
||||
app.nodeNotifier,
|
||||
)
|
||||
if err != nil {
|
||||
if cfg.OIDC.OnlyStartIfOIDCIsAvailable {
|
||||
@@ -262,16 +268,18 @@ func (h *Headscale) scheduledTasks(ctx context.Context) {
|
||||
return
|
||||
|
||||
case <-expireTicker.C:
|
||||
var update types.StateUpdate
|
||||
var expiredNodeChanges []change.ChangeSet
|
||||
var changed bool
|
||||
|
||||
lastExpiryCheck, update, changed = h.state.ExpireExpiredNodes(lastExpiryCheck)
|
||||
lastExpiryCheck, expiredNodeChanges, changed = h.state.ExpireExpiredNodes(lastExpiryCheck)
|
||||
|
||||
if changed {
|
||||
log.Trace().Interface("nodes", update.ChangePatches).Msgf("expiring nodes")
|
||||
log.Trace().Interface("changes", expiredNodeChanges).Msgf("expiring nodes")
|
||||
|
||||
ctx := types.NotifyCtx(context.Background(), "expire-expired", "na")
|
||||
h.nodeNotifier.NotifyAll(ctx, update)
|
||||
// Send the changes directly since they're already in the new format
|
||||
for _, nodeChange := range expiredNodeChanges {
|
||||
h.Change(nodeChange)
|
||||
}
|
||||
}
|
||||
|
||||
case <-derpTickerChan:
|
||||
@@ -282,11 +290,7 @@ func (h *Headscale) scheduledTasks(ctx context.Context) {
|
||||
derpMap.Regions[region.RegionID] = ®ion
|
||||
}
|
||||
|
||||
ctx := types.NotifyCtx(context.Background(), "derpmap-update", "na")
|
||||
h.nodeNotifier.NotifyAll(ctx, types.StateUpdate{
|
||||
Type: types.StateDERPUpdated,
|
||||
DERPMap: derpMap,
|
||||
})
|
||||
h.Change(change.DERPSet)
|
||||
|
||||
case records, ok := <-extraRecordsUpdate:
|
||||
if !ok {
|
||||
@@ -294,19 +298,16 @@ func (h *Headscale) scheduledTasks(ctx context.Context) {
|
||||
}
|
||||
h.cfg.TailcfgDNSConfig.ExtraRecords = records
|
||||
|
||||
ctx := types.NotifyCtx(context.Background(), "dns-extrarecord", "all")
|
||||
// TODO(kradalby): We can probably do better than sending a full update here,
|
||||
// but for now this will ensure that all of the nodes get the new records.
|
||||
h.nodeNotifier.NotifyAll(ctx, types.UpdateFull())
|
||||
h.Change(change.ExtraRecordsSet)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (h *Headscale) grpcAuthenticationInterceptor(ctx context.Context,
|
||||
req interface{},
|
||||
req any,
|
||||
info *grpc.UnaryServerInfo,
|
||||
handler grpc.UnaryHandler,
|
||||
) (interface{}, error) {
|
||||
) (any, error) {
|
||||
// Check if the request is coming from the on-server client.
|
||||
// This is not secure, but it is to maintain maintainability
|
||||
// with the "legacy" database-based client
|
||||
@@ -484,58 +485,6 @@ func (h *Headscale) createRouter(grpcMux *grpcRuntime.ServeMux) *mux.Router {
|
||||
return router
|
||||
}
|
||||
|
||||
// // TODO(kradalby): Do a variant of this, and polman which only updates the node that has changed.
|
||||
// // Maybe we should attempt a new in memory state and not go via the DB?
|
||||
// // Maybe this should be implemented as an event bus?
|
||||
// // A bool is returned indicating if a full update was sent to all nodes
|
||||
// func usersChangedHook(db *db.HSDatabase, polMan policy.PolicyManager, notif *notifier.Notifier) error {
|
||||
// users, err := db.ListUsers()
|
||||
// if err != nil {
|
||||
// return err
|
||||
// }
|
||||
|
||||
// changed, err := polMan.SetUsers(users)
|
||||
// if err != nil {
|
||||
// return err
|
||||
// }
|
||||
|
||||
// if changed {
|
||||
// ctx := types.NotifyCtx(context.Background(), "acl-users-change", "all")
|
||||
// notif.NotifyAll(ctx, types.UpdateFull())
|
||||
// }
|
||||
|
||||
// return nil
|
||||
// }
|
||||
|
||||
// // TODO(kradalby): Do a variant of this, and polman which only updates the node that has changed.
|
||||
// // Maybe we should attempt a new in memory state and not go via the DB?
|
||||
// // Maybe this should be implemented as an event bus?
|
||||
// // A bool is returned indicating if a full update was sent to all nodes
|
||||
// func nodesChangedHook(
|
||||
// db *db.HSDatabase,
|
||||
// polMan policy.PolicyManager,
|
||||
// notif *notifier.Notifier,
|
||||
// ) (bool, error) {
|
||||
// nodes, err := db.ListNodes()
|
||||
// if err != nil {
|
||||
// return false, err
|
||||
// }
|
||||
|
||||
// filterChanged, err := polMan.SetNodes(nodes)
|
||||
// if err != nil {
|
||||
// return false, err
|
||||
// }
|
||||
|
||||
// if filterChanged {
|
||||
// ctx := types.NotifyCtx(context.Background(), "acl-nodes-change", "all")
|
||||
// notif.NotifyAll(ctx, types.UpdateFull())
|
||||
|
||||
// return true, nil
|
||||
// }
|
||||
|
||||
// return false, nil
|
||||
// }
|
||||
|
||||
// Serve launches the HTTP and gRPC server service Headscale and the API.
|
||||
func (h *Headscale) Serve() error {
|
||||
capver.CanOldCodeBeCleanedUp()
|
||||
@@ -562,8 +511,9 @@ func (h *Headscale) Serve() error {
|
||||
Str("minimum_version", capver.TailscaleVersion(capver.MinSupportedCapabilityVersion)).
|
||||
Msg("Clients with a lower minimum version will be rejected")
|
||||
|
||||
// Fetch an initial DERP Map before we start serving
|
||||
h.mapper = mapper.NewMapper(h.state, h.cfg, h.nodeNotifier)
|
||||
h.mapBatcher = mapper.NewBatcherAndMapper(h.cfg, h.state)
|
||||
h.mapBatcher.Start()
|
||||
defer h.mapBatcher.Close()
|
||||
|
||||
// TODO(kradalby): fix state part.
|
||||
if h.cfg.DERP.ServerEnabled {
|
||||
@@ -838,8 +788,12 @@ func (h *Headscale) Serve() error {
|
||||
log.Info().
|
||||
Msg("ACL policy successfully reloaded, notifying nodes of change")
|
||||
|
||||
ctx := types.NotifyCtx(context.Background(), "acl-sighup", "na")
|
||||
h.nodeNotifier.NotifyAll(ctx, types.UpdateFull())
|
||||
err = h.state.AutoApproveNodes()
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("failed to approve routes after new policy")
|
||||
}
|
||||
|
||||
h.Change(change.PolicySet)
|
||||
}
|
||||
default:
|
||||
info := func(msg string) { log.Info().Msg(msg) }
|
||||
@@ -865,7 +819,6 @@ func (h *Headscale) Serve() error {
|
||||
}
|
||||
|
||||
info("closing node notifier")
|
||||
h.nodeNotifier.Close()
|
||||
|
||||
info("waiting for netmap stream to close")
|
||||
h.pollNetMapStreamWG.Wait()
|
||||
@@ -1047,3 +1000,10 @@ func readOrCreatePrivateKey(path string) (*key.MachinePrivate, error) {
|
||||
|
||||
return &machineKey, nil
|
||||
}
|
||||
|
||||
// Change is used to send changes to nodes.
|
||||
// All change should be enqueued here and empty will be automatically
|
||||
// ignored.
|
||||
func (h *Headscale) Change(c change.ChangeSet) {
|
||||
h.mapBatcher.AddWork(c)
|
||||
}
|
||||
|
@@ -10,6 +10,8 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/juanfont/headscale/hscontrol/types"
|
||||
"github.com/juanfont/headscale/hscontrol/types/change"
|
||||
|
||||
"gorm.io/gorm"
|
||||
"tailscale.com/tailcfg"
|
||||
"tailscale.com/types/key"
|
||||
@@ -32,6 +34,21 @@ func (h *Headscale) handleRegister(
|
||||
}
|
||||
|
||||
if node != nil {
|
||||
// If an existing node is trying to register with an auth key,
|
||||
// we need to validate the auth key even for existing nodes
|
||||
if regReq.Auth != nil && regReq.Auth.AuthKey != "" {
|
||||
resp, err := h.handleRegisterWithAuthKey(regReq, machineKey)
|
||||
if err != nil {
|
||||
// Preserve HTTPError types so they can be handled properly by the HTTP layer
|
||||
var httpErr HTTPError
|
||||
if errors.As(err, &httpErr) {
|
||||
return nil, httpErr
|
||||
}
|
||||
return nil, fmt.Errorf("handling register with auth key for existing node: %w", err)
|
||||
}
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
resp, err := h.handleExistingNode(node, regReq, machineKey)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("handling existing node: %w", err)
|
||||
@@ -47,6 +64,11 @@ func (h *Headscale) handleRegister(
|
||||
if regReq.Auth != nil && regReq.Auth.AuthKey != "" {
|
||||
resp, err := h.handleRegisterWithAuthKey(regReq, machineKey)
|
||||
if err != nil {
|
||||
// Preserve HTTPError types so they can be handled properly by the HTTP layer
|
||||
var httpErr HTTPError
|
||||
if errors.As(err, &httpErr) {
|
||||
return nil, httpErr
|
||||
}
|
||||
return nil, fmt.Errorf("handling register with auth key: %w", err)
|
||||
}
|
||||
|
||||
@@ -66,11 +88,13 @@ func (h *Headscale) handleExistingNode(
|
||||
regReq tailcfg.RegisterRequest,
|
||||
machineKey key.MachinePublic,
|
||||
) (*tailcfg.RegisterResponse, error) {
|
||||
|
||||
if node.MachineKey != machineKey {
|
||||
return nil, NewHTTPError(http.StatusUnauthorized, "node exist with different machine key", nil)
|
||||
}
|
||||
|
||||
expired := node.IsExpired()
|
||||
|
||||
if !expired && !regReq.Expiry.IsZero() {
|
||||
requestExpiry := regReq.Expiry
|
||||
|
||||
@@ -82,42 +106,26 @@ func (h *Headscale) handleExistingNode(
|
||||
// If the request expiry is in the past, we consider it a logout.
|
||||
if requestExpiry.Before(time.Now()) {
|
||||
if node.IsEphemeral() {
|
||||
policyChanged, err := h.state.DeleteNode(node)
|
||||
c, err := h.state.DeleteNode(node)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("deleting ephemeral node: %w", err)
|
||||
}
|
||||
|
||||
// Send policy update notifications if needed
|
||||
if policyChanged {
|
||||
ctx := types.NotifyCtx(context.Background(), "auth-logout-ephemeral-policy", "na")
|
||||
h.nodeNotifier.NotifyAll(ctx, types.UpdateFull())
|
||||
} else {
|
||||
ctx := types.NotifyCtx(context.Background(), "logout-ephemeral", "na")
|
||||
h.nodeNotifier.NotifyAll(ctx, types.UpdatePeerRemoved(node.ID))
|
||||
}
|
||||
h.Change(c)
|
||||
|
||||
return nil, nil
|
||||
}
|
||||
}
|
||||
|
||||
n, policyChanged, err := h.state.SetNodeExpiry(node.ID, requestExpiry)
|
||||
_, c, err := h.state.SetNodeExpiry(node.ID, requestExpiry)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("setting node expiry: %w", err)
|
||||
}
|
||||
|
||||
// Send policy update notifications if needed
|
||||
if policyChanged {
|
||||
ctx := types.NotifyCtx(context.Background(), "auth-expiry-policy", "na")
|
||||
h.nodeNotifier.NotifyAll(ctx, types.UpdateFull())
|
||||
} else {
|
||||
ctx := types.NotifyCtx(context.Background(), "logout-expiry", "na")
|
||||
h.nodeNotifier.NotifyWithIgnore(ctx, types.UpdateExpire(node.ID, requestExpiry), node.ID)
|
||||
h.Change(c)
|
||||
}
|
||||
|
||||
return nodeToRegisterResponse(n), nil
|
||||
}
|
||||
|
||||
return nodeToRegisterResponse(node), nil
|
||||
return nodeToRegisterResponse(node), nil
|
||||
}
|
||||
|
||||
func nodeToRegisterResponse(node *types.Node) *tailcfg.RegisterResponse {
|
||||
@@ -168,7 +176,7 @@ func (h *Headscale) handleRegisterWithAuthKey(
|
||||
regReq tailcfg.RegisterRequest,
|
||||
machineKey key.MachinePublic,
|
||||
) (*tailcfg.RegisterResponse, error) {
|
||||
node, changed, err := h.state.HandleNodeFromPreAuthKey(
|
||||
node, changed, policyChanged, err := h.state.HandleNodeFromPreAuthKey(
|
||||
regReq,
|
||||
machineKey,
|
||||
)
|
||||
@@ -184,6 +192,12 @@ func (h *Headscale) handleRegisterWithAuthKey(
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// If node is nil, it means an ephemeral node was deleted during logout
|
||||
if node == nil {
|
||||
h.Change(changed)
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// This is a bit of a back and forth, but we have a bit of a chicken and egg
|
||||
// dependency here.
|
||||
// Because the way the policy manager works, we need to have the node
|
||||
@@ -195,23 +209,22 @@ func (h *Headscale) handleRegisterWithAuthKey(
|
||||
// ensure we send an update.
|
||||
// This works, but might be another good candidate for doing some sort of
|
||||
// eventbus.
|
||||
routesChanged := h.state.AutoApproveRoutes(node)
|
||||
// TODO(kradalby): This needs to be ran as part of the batcher maybe?
|
||||
// now since we dont update the node/pol here anymore
|
||||
routeChange := h.state.AutoApproveRoutes(node)
|
||||
if _, _, err := h.state.SaveNode(node); err != nil {
|
||||
return nil, fmt.Errorf("saving auto approved routes to node: %w", err)
|
||||
}
|
||||
|
||||
if routesChanged {
|
||||
ctx := types.NotifyCtx(context.Background(), "node updated", node.Hostname)
|
||||
h.nodeNotifier.NotifyAll(ctx, types.UpdatePeerChanged(node.ID))
|
||||
} else if changed {
|
||||
ctx := types.NotifyCtx(context.Background(), "node created", node.Hostname)
|
||||
h.nodeNotifier.NotifyAll(ctx, types.UpdateFull())
|
||||
} else {
|
||||
// Existing node re-registering without route changes
|
||||
// Still need to notify peers about the node being active again
|
||||
// Use UpdateFull to ensure all peers get complete peer maps
|
||||
ctx := types.NotifyCtx(context.Background(), "node re-registered", node.Hostname)
|
||||
h.nodeNotifier.NotifyAll(ctx, types.UpdateFull())
|
||||
if routeChange && changed.Empty() {
|
||||
changed = change.NodeAdded(node.ID)
|
||||
}
|
||||
h.Change(changed)
|
||||
|
||||
// If policy changed due to node registration, send a separate policy change
|
||||
if policyChanged {
|
||||
policyChange := change.PolicyChange()
|
||||
h.Change(policyChange)
|
||||
}
|
||||
|
||||
return &tailcfg.RegisterResponse{
|
||||
|
@@ -1,5 +1,7 @@
|
||||
package capver
|
||||
|
||||
//go:generate go run ../../tools/capver/main.go
|
||||
|
||||
import (
|
||||
"slices"
|
||||
"sort"
|
||||
@@ -10,7 +12,7 @@ import (
|
||||
"tailscale.com/util/set"
|
||||
)
|
||||
|
||||
const MinSupportedCapabilityVersion tailcfg.CapabilityVersion = 88
|
||||
const MinSupportedCapabilityVersion tailcfg.CapabilityVersion = 90
|
||||
|
||||
// CanOldCodeBeCleanedUp is intended to be called on startup to see if
|
||||
// there are old code that can ble cleaned up, entries should contain
|
||||
|
@@ -1,14 +1,10 @@
|
||||
package capver
|
||||
|
||||
// Generated DO NOT EDIT
|
||||
//Generated DO NOT EDIT
|
||||
|
||||
import "tailscale.com/tailcfg"
|
||||
|
||||
var tailscaleToCapVer = map[string]tailcfg.CapabilityVersion{
|
||||
"v1.60.0": 87,
|
||||
"v1.60.1": 87,
|
||||
"v1.62.0": 88,
|
||||
"v1.62.1": 88,
|
||||
"v1.64.0": 90,
|
||||
"v1.64.1": 90,
|
||||
"v1.64.2": 90,
|
||||
@@ -36,18 +32,21 @@ var tailscaleToCapVer = map[string]tailcfg.CapabilityVersion{
|
||||
"v1.80.3": 113,
|
||||
"v1.82.0": 115,
|
||||
"v1.82.5": 115,
|
||||
"v1.84.0": 116,
|
||||
"v1.84.1": 116,
|
||||
"v1.84.2": 116,
|
||||
}
|
||||
|
||||
|
||||
var capVerToTailscaleVer = map[tailcfg.CapabilityVersion]string{
|
||||
87: "v1.60.0",
|
||||
88: "v1.62.0",
|
||||
90: "v1.64.0",
|
||||
95: "v1.66.0",
|
||||
97: "v1.68.0",
|
||||
102: "v1.70.0",
|
||||
104: "v1.72.0",
|
||||
106: "v1.74.0",
|
||||
109: "v1.78.0",
|
||||
113: "v1.80.0",
|
||||
115: "v1.82.0",
|
||||
90: "v1.64.0",
|
||||
95: "v1.66.0",
|
||||
97: "v1.68.0",
|
||||
102: "v1.70.0",
|
||||
104: "v1.72.0",
|
||||
106: "v1.74.0",
|
||||
109: "v1.78.0",
|
||||
113: "v1.80.0",
|
||||
115: "v1.82.0",
|
||||
116: "v1.84.0",
|
||||
}
|
||||
|
@@ -13,11 +13,10 @@ func TestTailscaleLatestMajorMinor(t *testing.T) {
|
||||
stripV bool
|
||||
expected []string
|
||||
}{
|
||||
{3, false, []string{"v1.78", "v1.80", "v1.82"}},
|
||||
{2, true, []string{"1.80", "1.82"}},
|
||||
{3, false, []string{"v1.80", "v1.82", "v1.84"}},
|
||||
{2, true, []string{"1.82", "1.84"}},
|
||||
// Lazy way to see all supported versions
|
||||
{10, true, []string{
|
||||
"1.64",
|
||||
"1.66",
|
||||
"1.68",
|
||||
"1.70",
|
||||
@@ -27,6 +26,7 @@ func TestTailscaleLatestMajorMinor(t *testing.T) {
|
||||
"1.78",
|
||||
"1.80",
|
||||
"1.82",
|
||||
"1.84",
|
||||
}},
|
||||
{0, false, nil},
|
||||
}
|
||||
@@ -46,7 +46,6 @@ func TestCapVerMinimumTailscaleVersion(t *testing.T) {
|
||||
input tailcfg.CapabilityVersion
|
||||
expected string
|
||||
}{
|
||||
{88, "v1.62.0"},
|
||||
{90, "v1.64.0"},
|
||||
{95, "v1.66.0"},
|
||||
{106, "v1.74.0"},
|
||||
|
@@ -1,157 +0,0 @@
|
||||
package main
|
||||
|
||||
//go:generate go run main.go
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"net/http"
|
||||
"os"
|
||||
"regexp"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
xmaps "golang.org/x/exp/maps"
|
||||
"tailscale.com/tailcfg"
|
||||
)
|
||||
|
||||
const (
|
||||
releasesURL = "https://api.github.com/repos/tailscale/tailscale/releases"
|
||||
rawFileURL = "https://github.com/tailscale/tailscale/raw/refs/tags/%s/tailcfg/tailcfg.go"
|
||||
outputFile = "../capver_generated.go"
|
||||
)
|
||||
|
||||
type Release struct {
|
||||
Name string `json:"name"`
|
||||
}
|
||||
|
||||
func getCapabilityVersions() (map[string]tailcfg.CapabilityVersion, error) {
|
||||
// Fetch the releases
|
||||
resp, err := http.Get(releasesURL)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error fetching releases: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error reading response body: %w", err)
|
||||
}
|
||||
|
||||
var releases []Release
|
||||
err = json.Unmarshal(body, &releases)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error unmarshalling JSON: %w", err)
|
||||
}
|
||||
|
||||
// Regular expression to find the CurrentCapabilityVersion line
|
||||
re := regexp.MustCompile(`const CurrentCapabilityVersion CapabilityVersion = (\d+)`)
|
||||
|
||||
versions := make(map[string]tailcfg.CapabilityVersion)
|
||||
|
||||
for _, release := range releases {
|
||||
version := strings.TrimSpace(release.Name)
|
||||
if !strings.HasPrefix(version, "v") {
|
||||
version = "v" + version
|
||||
}
|
||||
|
||||
// Fetch the raw Go file
|
||||
rawURL := fmt.Sprintf(rawFileURL, version)
|
||||
resp, err := http.Get(rawURL)
|
||||
if err != nil {
|
||||
fmt.Printf("Error fetching raw file for version %s: %v\n", version, err)
|
||||
continue
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
fmt.Printf("Error reading raw file for version %s: %v\n", version, err)
|
||||
continue
|
||||
}
|
||||
|
||||
// Find the CurrentCapabilityVersion
|
||||
matches := re.FindStringSubmatch(string(body))
|
||||
if len(matches) > 1 {
|
||||
capabilityVersionStr := matches[1]
|
||||
capabilityVersion, _ := strconv.Atoi(capabilityVersionStr)
|
||||
versions[version] = tailcfg.CapabilityVersion(capabilityVersion)
|
||||
} else {
|
||||
fmt.Printf("Version: %s, CurrentCapabilityVersion not found\n", version)
|
||||
}
|
||||
}
|
||||
|
||||
return versions, nil
|
||||
}
|
||||
|
||||
func writeCapabilityVersionsToFile(versions map[string]tailcfg.CapabilityVersion) error {
|
||||
// Open the output file
|
||||
file, err := os.Create(outputFile)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error creating file: %w", err)
|
||||
}
|
||||
defer file.Close()
|
||||
|
||||
// Write the package declaration and variable
|
||||
file.WriteString("package capver\n\n")
|
||||
file.WriteString("//Generated DO NOT EDIT\n\n")
|
||||
file.WriteString(`import "tailscale.com/tailcfg"`)
|
||||
file.WriteString("\n\n")
|
||||
file.WriteString("var tailscaleToCapVer = map[string]tailcfg.CapabilityVersion{\n")
|
||||
|
||||
sortedVersions := xmaps.Keys(versions)
|
||||
sort.Strings(sortedVersions)
|
||||
for _, version := range sortedVersions {
|
||||
file.WriteString(fmt.Sprintf("\t\"%s\": %d,\n", version, versions[version]))
|
||||
}
|
||||
file.WriteString("}\n")
|
||||
|
||||
file.WriteString("\n\n")
|
||||
file.WriteString("var capVerToTailscaleVer = map[tailcfg.CapabilityVersion]string{\n")
|
||||
|
||||
capVarToTailscaleVer := make(map[tailcfg.CapabilityVersion]string)
|
||||
for _, v := range sortedVersions {
|
||||
cap := versions[v]
|
||||
log.Printf("cap for v: %d, %s", cap, v)
|
||||
|
||||
// If it is already set, skip and continue,
|
||||
// we only want the first tailscale vsion per
|
||||
// capability vsion.
|
||||
if _, ok := capVarToTailscaleVer[cap]; ok {
|
||||
log.Printf("Skipping %d, %s", cap, v)
|
||||
continue
|
||||
}
|
||||
log.Printf("Storing %d, %s", cap, v)
|
||||
capVarToTailscaleVer[cap] = v
|
||||
}
|
||||
|
||||
capsSorted := xmaps.Keys(capVarToTailscaleVer)
|
||||
sort.Slice(capsSorted, func(i, j int) bool {
|
||||
return capsSorted[i] < capsSorted[j]
|
||||
})
|
||||
for _, capVer := range capsSorted {
|
||||
file.WriteString(fmt.Sprintf("\t%d:\t\t\"%s\",\n", capVer, capVarToTailscaleVer[capVer]))
|
||||
}
|
||||
file.WriteString("}\n")
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func main() {
|
||||
versions, err := getCapabilityVersions()
|
||||
if err != nil {
|
||||
fmt.Println("Error:", err)
|
||||
return
|
||||
}
|
||||
|
||||
err = writeCapabilityVersionsToFile(versions)
|
||||
if err != nil {
|
||||
fmt.Println("Error writing to file:", err)
|
||||
return
|
||||
}
|
||||
|
||||
fmt.Println("Capability versions written to", outputFile)
|
||||
}
|
@@ -7,7 +7,6 @@ import (
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"slices"
|
||||
"sort"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
@@ -362,8 +361,8 @@ func TestSQLiteMigrationAndDataValidation(t *testing.T) {
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(expectedKeys, keys, cmp.Comparer(func(a, b []string) bool {
|
||||
sort.Sort(sort.StringSlice(a))
|
||||
sort.Sort(sort.StringSlice(b))
|
||||
slices.Sort(a)
|
||||
slices.Sort(b)
|
||||
return slices.Equal(a, b)
|
||||
}), cmpopts.IgnoreFields(types.PreAuthKey{}, "User", "CreatedAt", "Reusable", "Ephemeral", "Used", "Expiration")); diff != "" {
|
||||
t.Errorf("TestSQLiteMigrationAndDataValidation() pre-auth key tags migration mismatch (-want +got):\n%s", diff)
|
||||
|
@@ -7,15 +7,19 @@ import (
|
||||
"net/netip"
|
||||
"slices"
|
||||
"sort"
|
||||
"strconv"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/juanfont/headscale/hscontrol/types"
|
||||
"github.com/juanfont/headscale/hscontrol/types/change"
|
||||
"github.com/juanfont/headscale/hscontrol/util"
|
||||
"github.com/rs/zerolog/log"
|
||||
"gorm.io/gorm"
|
||||
"tailscale.com/tailcfg"
|
||||
"tailscale.com/types/key"
|
||||
"tailscale.com/types/ptr"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -39,9 +43,7 @@ var (
|
||||
// If no peer IDs are given, all peers are returned.
|
||||
// If at least one peer ID is given, only these peer nodes will be returned.
|
||||
func (hsdb *HSDatabase) ListPeers(nodeID types.NodeID, peerIDs ...types.NodeID) (types.Nodes, error) {
|
||||
return Read(hsdb.DB, func(rx *gorm.DB) (types.Nodes, error) {
|
||||
return ListPeers(rx, nodeID, peerIDs...)
|
||||
})
|
||||
return ListPeers(hsdb.DB, nodeID, peerIDs...)
|
||||
}
|
||||
|
||||
// ListPeers returns peers of node, regardless of any Policy or if the node is expired.
|
||||
@@ -66,9 +68,7 @@ func ListPeers(tx *gorm.DB, nodeID types.NodeID, peerIDs ...types.NodeID) (types
|
||||
// ListNodes queries the database for either all nodes if no parameters are given
|
||||
// or for the given nodes if at least one node ID is given as parameter.
|
||||
func (hsdb *HSDatabase) ListNodes(nodeIDs ...types.NodeID) (types.Nodes, error) {
|
||||
return Read(hsdb.DB, func(rx *gorm.DB) (types.Nodes, error) {
|
||||
return ListNodes(rx, nodeIDs...)
|
||||
})
|
||||
return ListNodes(hsdb.DB, nodeIDs...)
|
||||
}
|
||||
|
||||
// ListNodes queries the database for either all nodes if no parameters are given
|
||||
@@ -120,9 +120,7 @@ func getNode(tx *gorm.DB, uid types.UserID, name string) (*types.Node, error) {
|
||||
}
|
||||
|
||||
func (hsdb *HSDatabase) GetNodeByID(id types.NodeID) (*types.Node, error) {
|
||||
return Read(hsdb.DB, func(rx *gorm.DB) (*types.Node, error) {
|
||||
return GetNodeByID(rx, id)
|
||||
})
|
||||
return GetNodeByID(hsdb.DB, id)
|
||||
}
|
||||
|
||||
// GetNodeByID finds a Node by ID and returns the Node struct.
|
||||
@@ -140,9 +138,7 @@ func GetNodeByID(tx *gorm.DB, id types.NodeID) (*types.Node, error) {
|
||||
}
|
||||
|
||||
func (hsdb *HSDatabase) GetNodeByMachineKey(machineKey key.MachinePublic) (*types.Node, error) {
|
||||
return Read(hsdb.DB, func(rx *gorm.DB) (*types.Node, error) {
|
||||
return GetNodeByMachineKey(rx, machineKey)
|
||||
})
|
||||
return GetNodeByMachineKey(hsdb.DB, machineKey)
|
||||
}
|
||||
|
||||
// GetNodeByMachineKey finds a Node by its MachineKey and returns the Node struct.
|
||||
@@ -163,9 +159,7 @@ func GetNodeByMachineKey(
|
||||
}
|
||||
|
||||
func (hsdb *HSDatabase) GetNodeByNodeKey(nodeKey key.NodePublic) (*types.Node, error) {
|
||||
return Read(hsdb.DB, func(rx *gorm.DB) (*types.Node, error) {
|
||||
return GetNodeByNodeKey(rx, nodeKey)
|
||||
})
|
||||
return GetNodeByNodeKey(hsdb.DB, nodeKey)
|
||||
}
|
||||
|
||||
// GetNodeByNodeKey finds a Node by its NodeKey and returns the Node struct.
|
||||
@@ -352,8 +346,8 @@ func (hsdb *HSDatabase) HandleNodeFromAuthPath(
|
||||
registrationMethod string,
|
||||
ipv4 *netip.Addr,
|
||||
ipv6 *netip.Addr,
|
||||
) (*types.Node, bool, error) {
|
||||
var newNode bool
|
||||
) (*types.Node, change.ChangeSet, error) {
|
||||
var nodeChange change.ChangeSet
|
||||
node, err := Write(hsdb.DB, func(tx *gorm.DB) (*types.Node, error) {
|
||||
if reg, ok := hsdb.regCache.Get(registrationID); ok {
|
||||
if node, _ := GetNodeByNodeKey(tx, reg.Node.NodeKey); node == nil {
|
||||
@@ -405,7 +399,7 @@ func (hsdb *HSDatabase) HandleNodeFromAuthPath(
|
||||
}
|
||||
close(reg.Registered)
|
||||
|
||||
newNode = true
|
||||
nodeChange = change.NodeAdded(node.ID)
|
||||
|
||||
return node, err
|
||||
} else {
|
||||
@@ -415,6 +409,8 @@ func (hsdb *HSDatabase) HandleNodeFromAuthPath(
|
||||
return nil, err
|
||||
}
|
||||
|
||||
nodeChange = change.KeyExpiry(node.ID)
|
||||
|
||||
return node, nil
|
||||
}
|
||||
}
|
||||
@@ -422,7 +418,7 @@ func (hsdb *HSDatabase) HandleNodeFromAuthPath(
|
||||
return nil, ErrNodeNotFoundRegistrationCache
|
||||
})
|
||||
|
||||
return node, newNode, err
|
||||
return node, nodeChange, err
|
||||
}
|
||||
|
||||
func (hsdb *HSDatabase) RegisterNode(node types.Node, ipv4 *netip.Addr, ipv6 *netip.Addr) (*types.Node, error) {
|
||||
@@ -448,6 +444,7 @@ func RegisterNode(tx *gorm.DB, node types.Node, ipv4 *netip.Addr, ipv6 *netip.Ad
|
||||
if oldNode != nil && oldNode.UserID == node.UserID {
|
||||
node.ID = oldNode.ID
|
||||
node.GivenName = oldNode.GivenName
|
||||
node.ApprovedRoutes = oldNode.ApprovedRoutes
|
||||
ipv4 = oldNode.IPv4
|
||||
ipv6 = oldNode.IPv6
|
||||
}
|
||||
@@ -594,17 +591,18 @@ func ensureUniqueGivenName(
|
||||
// containing the expired nodes, and a boolean indicating if any nodes were found.
|
||||
func ExpireExpiredNodes(tx *gorm.DB,
|
||||
lastCheck time.Time,
|
||||
) (time.Time, types.StateUpdate, bool) {
|
||||
) (time.Time, []change.ChangeSet, bool) {
|
||||
// use the time of the start of the function to ensure we
|
||||
// dont miss some nodes by returning it _after_ we have
|
||||
// checked everything.
|
||||
started := time.Now()
|
||||
|
||||
expired := make([]*tailcfg.PeerChange, 0)
|
||||
var updates []change.ChangeSet
|
||||
|
||||
nodes, err := ListNodes(tx)
|
||||
if err != nil {
|
||||
return time.Unix(0, 0), types.StateUpdate{}, false
|
||||
return time.Unix(0, 0), nil, false
|
||||
}
|
||||
for _, node := range nodes {
|
||||
if node.IsExpired() && node.Expiry.After(lastCheck) {
|
||||
@@ -612,14 +610,15 @@ func ExpireExpiredNodes(tx *gorm.DB,
|
||||
NodeID: tailcfg.NodeID(node.ID),
|
||||
KeyExpiry: node.Expiry,
|
||||
})
|
||||
updates = append(updates, change.KeyExpiry(node.ID))
|
||||
}
|
||||
}
|
||||
|
||||
if len(expired) > 0 {
|
||||
return started, types.UpdatePeerPatch(expired...), true
|
||||
return started, updates, true
|
||||
}
|
||||
|
||||
return started, types.StateUpdate{}, false
|
||||
return started, nil, false
|
||||
}
|
||||
|
||||
// EphemeralGarbageCollector is a garbage collector that will delete nodes after
|
||||
@@ -732,3 +731,114 @@ func (e *EphemeralGarbageCollector) Start() {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (hsdb *HSDatabase) CreateNodeForTest(user *types.User, hostname ...string) *types.Node {
|
||||
if !testing.Testing() {
|
||||
panic("CreateNodeForTest can only be called during tests")
|
||||
}
|
||||
|
||||
if user == nil {
|
||||
panic("CreateNodeForTest requires a valid user")
|
||||
}
|
||||
|
||||
nodeName := "testnode"
|
||||
if len(hostname) > 0 && hostname[0] != "" {
|
||||
nodeName = hostname[0]
|
||||
}
|
||||
|
||||
// Create a preauth key for the node
|
||||
pak, err := hsdb.CreatePreAuthKey(types.UserID(user.ID), false, false, nil, nil)
|
||||
if err != nil {
|
||||
panic(fmt.Sprintf("failed to create preauth key for test node: %v", err))
|
||||
}
|
||||
|
||||
nodeKey := key.NewNode()
|
||||
machineKey := key.NewMachine()
|
||||
discoKey := key.NewDisco()
|
||||
|
||||
node := &types.Node{
|
||||
MachineKey: machineKey.Public(),
|
||||
NodeKey: nodeKey.Public(),
|
||||
DiscoKey: discoKey.Public(),
|
||||
Hostname: nodeName,
|
||||
UserID: user.ID,
|
||||
RegisterMethod: util.RegisterMethodAuthKey,
|
||||
AuthKeyID: ptr.To(pak.ID),
|
||||
}
|
||||
|
||||
err = hsdb.DB.Save(node).Error
|
||||
if err != nil {
|
||||
panic(fmt.Sprintf("failed to create test node: %v", err))
|
||||
}
|
||||
|
||||
return node
|
||||
}
|
||||
|
||||
func (hsdb *HSDatabase) CreateRegisteredNodeForTest(user *types.User, hostname ...string) *types.Node {
|
||||
if !testing.Testing() {
|
||||
panic("CreateRegisteredNodeForTest can only be called during tests")
|
||||
}
|
||||
|
||||
node := hsdb.CreateNodeForTest(user, hostname...)
|
||||
|
||||
err := hsdb.DB.Transaction(func(tx *gorm.DB) error {
|
||||
_, err := RegisterNode(tx, *node, nil, nil)
|
||||
return err
|
||||
})
|
||||
if err != nil {
|
||||
panic(fmt.Sprintf("failed to register test node: %v", err))
|
||||
}
|
||||
|
||||
registeredNode, err := hsdb.GetNodeByID(node.ID)
|
||||
if err != nil {
|
||||
panic(fmt.Sprintf("failed to get registered test node: %v", err))
|
||||
}
|
||||
|
||||
return registeredNode
|
||||
}
|
||||
|
||||
func (hsdb *HSDatabase) CreateNodesForTest(user *types.User, count int, hostnamePrefix ...string) []*types.Node {
|
||||
if !testing.Testing() {
|
||||
panic("CreateNodesForTest can only be called during tests")
|
||||
}
|
||||
|
||||
if user == nil {
|
||||
panic("CreateNodesForTest requires a valid user")
|
||||
}
|
||||
|
||||
prefix := "testnode"
|
||||
if len(hostnamePrefix) > 0 && hostnamePrefix[0] != "" {
|
||||
prefix = hostnamePrefix[0]
|
||||
}
|
||||
|
||||
nodes := make([]*types.Node, count)
|
||||
for i := range count {
|
||||
hostname := prefix + "-" + strconv.Itoa(i)
|
||||
nodes[i] = hsdb.CreateNodeForTest(user, hostname)
|
||||
}
|
||||
|
||||
return nodes
|
||||
}
|
||||
|
||||
func (hsdb *HSDatabase) CreateRegisteredNodesForTest(user *types.User, count int, hostnamePrefix ...string) []*types.Node {
|
||||
if !testing.Testing() {
|
||||
panic("CreateRegisteredNodesForTest can only be called during tests")
|
||||
}
|
||||
|
||||
if user == nil {
|
||||
panic("CreateRegisteredNodesForTest requires a valid user")
|
||||
}
|
||||
|
||||
prefix := "testnode"
|
||||
if len(hostnamePrefix) > 0 && hostnamePrefix[0] != "" {
|
||||
prefix = hostnamePrefix[0]
|
||||
}
|
||||
|
||||
nodes := make([]*types.Node, count)
|
||||
for i := range count {
|
||||
hostname := prefix + "-" + strconv.Itoa(i)
|
||||
nodes[i] = hsdb.CreateRegisteredNodeForTest(user, hostname)
|
||||
}
|
||||
|
||||
return nodes
|
||||
}
|
||||
|
@@ -6,7 +6,6 @@ import (
|
||||
"math/big"
|
||||
"net/netip"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
@@ -26,82 +25,36 @@ import (
|
||||
)
|
||||
|
||||
func (s *Suite) TestGetNode(c *check.C) {
|
||||
user, err := db.CreateUser(types.User{Name: "test"})
|
||||
c.Assert(err, check.IsNil)
|
||||
user := db.CreateUserForTest("test")
|
||||
|
||||
pak, err := db.CreatePreAuthKey(types.UserID(user.ID), false, false, nil, nil)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
_, err = db.getNode(types.UserID(user.ID), "testnode")
|
||||
_, err := db.getNode(types.UserID(user.ID), "testnode")
|
||||
c.Assert(err, check.NotNil)
|
||||
|
||||
nodeKey := key.NewNode()
|
||||
machineKey := key.NewMachine()
|
||||
|
||||
node := &types.Node{
|
||||
ID: 0,
|
||||
MachineKey: machineKey.Public(),
|
||||
NodeKey: nodeKey.Public(),
|
||||
Hostname: "testnode",
|
||||
UserID: user.ID,
|
||||
RegisterMethod: util.RegisterMethodAuthKey,
|
||||
AuthKeyID: ptr.To(pak.ID),
|
||||
}
|
||||
trx := db.DB.Save(node)
|
||||
c.Assert(trx.Error, check.IsNil)
|
||||
node := db.CreateNodeForTest(user, "testnode")
|
||||
|
||||
_, err = db.getNode(types.UserID(user.ID), "testnode")
|
||||
c.Assert(err, check.IsNil)
|
||||
c.Assert(node.Hostname, check.Equals, "testnode")
|
||||
}
|
||||
|
||||
func (s *Suite) TestGetNodeByID(c *check.C) {
|
||||
user, err := db.CreateUser(types.User{Name: "test"})
|
||||
c.Assert(err, check.IsNil)
|
||||
user := db.CreateUserForTest("test")
|
||||
|
||||
pak, err := db.CreatePreAuthKey(types.UserID(user.ID), false, false, nil, nil)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
_, err = db.GetNodeByID(0)
|
||||
_, err := db.GetNodeByID(0)
|
||||
c.Assert(err, check.NotNil)
|
||||
|
||||
nodeKey := key.NewNode()
|
||||
machineKey := key.NewMachine()
|
||||
node := db.CreateNodeForTest(user, "testnode")
|
||||
|
||||
node := types.Node{
|
||||
ID: 0,
|
||||
MachineKey: machineKey.Public(),
|
||||
NodeKey: nodeKey.Public(),
|
||||
Hostname: "testnode",
|
||||
UserID: user.ID,
|
||||
RegisterMethod: util.RegisterMethodAuthKey,
|
||||
AuthKeyID: ptr.To(pak.ID),
|
||||
}
|
||||
trx := db.DB.Save(&node)
|
||||
c.Assert(trx.Error, check.IsNil)
|
||||
|
||||
_, err = db.GetNodeByID(0)
|
||||
retrievedNode, err := db.GetNodeByID(node.ID)
|
||||
c.Assert(err, check.IsNil)
|
||||
c.Assert(retrievedNode.Hostname, check.Equals, "testnode")
|
||||
}
|
||||
|
||||
func (s *Suite) TestHardDeleteNode(c *check.C) {
|
||||
user, err := db.CreateUser(types.User{Name: "test"})
|
||||
c.Assert(err, check.IsNil)
|
||||
user := db.CreateUserForTest("test")
|
||||
node := db.CreateNodeForTest(user, "testnode3")
|
||||
|
||||
nodeKey := key.NewNode()
|
||||
machineKey := key.NewMachine()
|
||||
|
||||
node := types.Node{
|
||||
ID: 0,
|
||||
MachineKey: machineKey.Public(),
|
||||
NodeKey: nodeKey.Public(),
|
||||
Hostname: "testnode3",
|
||||
UserID: user.ID,
|
||||
RegisterMethod: util.RegisterMethodAuthKey,
|
||||
}
|
||||
trx := db.DB.Save(&node)
|
||||
c.Assert(trx.Error, check.IsNil)
|
||||
|
||||
err = db.DeleteNode(&node)
|
||||
err := db.DeleteNode(node)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
_, err = db.getNode(types.UserID(user.ID), "testnode3")
|
||||
@@ -109,42 +62,21 @@ func (s *Suite) TestHardDeleteNode(c *check.C) {
|
||||
}
|
||||
|
||||
func (s *Suite) TestListPeers(c *check.C) {
|
||||
user, err := db.CreateUser(types.User{Name: "test"})
|
||||
c.Assert(err, check.IsNil)
|
||||
user := db.CreateUserForTest("test")
|
||||
|
||||
pak, err := db.CreatePreAuthKey(types.UserID(user.ID), false, false, nil, nil)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
_, err = db.GetNodeByID(0)
|
||||
_, err := db.GetNodeByID(0)
|
||||
c.Assert(err, check.NotNil)
|
||||
|
||||
for index := range 11 {
|
||||
nodeKey := key.NewNode()
|
||||
machineKey := key.NewMachine()
|
||||
nodes := db.CreateNodesForTest(user, 11, "testnode")
|
||||
|
||||
node := types.Node{
|
||||
ID: types.NodeID(index),
|
||||
MachineKey: machineKey.Public(),
|
||||
NodeKey: nodeKey.Public(),
|
||||
Hostname: "testnode" + strconv.Itoa(index),
|
||||
UserID: user.ID,
|
||||
RegisterMethod: util.RegisterMethodAuthKey,
|
||||
AuthKeyID: ptr.To(pak.ID),
|
||||
}
|
||||
trx := db.DB.Save(&node)
|
||||
c.Assert(trx.Error, check.IsNil)
|
||||
}
|
||||
|
||||
node0ByID, err := db.GetNodeByID(0)
|
||||
firstNode := nodes[0]
|
||||
peersOfFirstNode, err := db.ListPeers(firstNode.ID)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
peersOfNode0, err := db.ListPeers(node0ByID.ID)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
c.Assert(len(peersOfNode0), check.Equals, 9)
|
||||
c.Assert(peersOfNode0[0].Hostname, check.Equals, "testnode2")
|
||||
c.Assert(peersOfNode0[5].Hostname, check.Equals, "testnode7")
|
||||
c.Assert(peersOfNode0[8].Hostname, check.Equals, "testnode10")
|
||||
c.Assert(len(peersOfFirstNode), check.Equals, 10)
|
||||
c.Assert(peersOfFirstNode[0].Hostname, check.Equals, "testnode-1")
|
||||
c.Assert(peersOfFirstNode[5].Hostname, check.Equals, "testnode-6")
|
||||
c.Assert(peersOfFirstNode[9].Hostname, check.Equals, "testnode-10")
|
||||
}
|
||||
|
||||
func (s *Suite) TestExpireNode(c *check.C) {
|
||||
@@ -807,13 +739,13 @@ func TestListPeers(t *testing.T) {
|
||||
// No parameter means no filter, should return all peers
|
||||
nodes, err = db.ListPeers(1)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, nodes, 1)
|
||||
assert.Equal(t, 1, len(nodes))
|
||||
assert.Equal(t, "test2", nodes[0].Hostname)
|
||||
|
||||
// Empty node list should return all peers
|
||||
nodes, err = db.ListPeers(1, types.NodeIDs{}...)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, nodes, 1)
|
||||
assert.Equal(t, 1, len(nodes))
|
||||
assert.Equal(t, "test2", nodes[0].Hostname)
|
||||
|
||||
// No match in IDs should return empty list and no error
|
||||
@@ -824,13 +756,13 @@ func TestListPeers(t *testing.T) {
|
||||
// Partial match in IDs
|
||||
nodes, err = db.ListPeers(1, types.NodeIDs{2, 3}...)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, nodes, 1)
|
||||
assert.Equal(t, 1, len(nodes))
|
||||
assert.Equal(t, "test2", nodes[0].Hostname)
|
||||
|
||||
// Several matched IDs, but node ID is still filtered out
|
||||
nodes, err = db.ListPeers(1, types.NodeIDs{1, 2, 3}...)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, nodes, 1)
|
||||
assert.Equal(t, 1, len(nodes))
|
||||
assert.Equal(t, "test2", nodes[0].Hostname)
|
||||
}
|
||||
|
||||
@@ -892,14 +824,14 @@ func TestListNodes(t *testing.T) {
|
||||
// No parameter means no filter, should return all nodes
|
||||
nodes, err = db.ListNodes()
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, nodes, 2)
|
||||
assert.Equal(t, 2, len(nodes))
|
||||
assert.Equal(t, "test1", nodes[0].Hostname)
|
||||
assert.Equal(t, "test2", nodes[1].Hostname)
|
||||
|
||||
// Empty node list should return all nodes
|
||||
nodes, err = db.ListNodes(types.NodeIDs{}...)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, nodes, 2)
|
||||
assert.Equal(t, 2, len(nodes))
|
||||
assert.Equal(t, "test1", nodes[0].Hostname)
|
||||
assert.Equal(t, "test2", nodes[1].Hostname)
|
||||
|
||||
@@ -911,13 +843,13 @@ func TestListNodes(t *testing.T) {
|
||||
// Partial match in IDs
|
||||
nodes, err = db.ListNodes(types.NodeIDs{2, 3}...)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, nodes, 1)
|
||||
assert.Equal(t, 1, len(nodes))
|
||||
assert.Equal(t, "test2", nodes[0].Hostname)
|
||||
|
||||
// Several matched IDs
|
||||
nodes, err = db.ListNodes(types.NodeIDs{1, 2, 3}...)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, nodes, 2)
|
||||
assert.Equal(t, 2, len(nodes))
|
||||
assert.Equal(t, "test1", nodes[0].Hostname)
|
||||
assert.Equal(t, "test2", nodes[1].Hostname)
|
||||
}
|
||||
|
@@ -109,9 +109,7 @@ func ListPreAuthKeysByUser(tx *gorm.DB, uid types.UserID) ([]types.PreAuthKey, e
|
||||
}
|
||||
|
||||
func (hsdb *HSDatabase) GetPreAuthKey(key string) (*types.PreAuthKey, error) {
|
||||
return Read(hsdb.DB, func(rx *gorm.DB) (*types.PreAuthKey, error) {
|
||||
return GetPreAuthKey(rx, key)
|
||||
})
|
||||
return GetPreAuthKey(hsdb.DB, key)
|
||||
}
|
||||
|
||||
// GetPreAuthKey returns a PreAuthKey for a given key. The caller is responsible
|
||||
@@ -155,11 +153,8 @@ func UsePreAuthKey(tx *gorm.DB, k *types.PreAuthKey) error {
|
||||
|
||||
// MarkExpirePreAuthKey marks a PreAuthKey as expired.
|
||||
func ExpirePreAuthKey(tx *gorm.DB, k *types.PreAuthKey) error {
|
||||
if err := tx.Model(&k).Update("Expiration", time.Now()).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
now := time.Now()
|
||||
return tx.Model(&types.PreAuthKey{}).Where("id = ?", k.ID).Update("expiration", now).Error
|
||||
}
|
||||
|
||||
func generateKey() (string, error) {
|
||||
|
@@ -1,7 +1,7 @@
|
||||
package db
|
||||
|
||||
import (
|
||||
"sort"
|
||||
"slices"
|
||||
"testing"
|
||||
|
||||
"github.com/juanfont/headscale/hscontrol/types"
|
||||
@@ -57,7 +57,7 @@ func (*Suite) TestPreAuthKeyACLTags(c *check.C) {
|
||||
listedPaks, err := db.ListPreAuthKeys(types.UserID(user.ID))
|
||||
c.Assert(err, check.IsNil)
|
||||
gotTags := listedPaks[0].Proto().GetAclTags()
|
||||
sort.Sort(sort.StringSlice(gotTags))
|
||||
slices.Sort(gotTags)
|
||||
c.Assert(gotTags, check.DeepEquals, tags)
|
||||
}
|
||||
|
||||
|
@@ -3,6 +3,8 @@ package db
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"testing"
|
||||
|
||||
"github.com/juanfont/headscale/hscontrol/types"
|
||||
"github.com/juanfont/headscale/hscontrol/util"
|
||||
@@ -110,9 +112,7 @@ func RenameUser(tx *gorm.DB, uid types.UserID, newName string) error {
|
||||
}
|
||||
|
||||
func (hsdb *HSDatabase) GetUserByID(uid types.UserID) (*types.User, error) {
|
||||
return Read(hsdb.DB, func(rx *gorm.DB) (*types.User, error) {
|
||||
return GetUserByID(rx, uid)
|
||||
})
|
||||
return GetUserByID(hsdb.DB, uid)
|
||||
}
|
||||
|
||||
func GetUserByID(tx *gorm.DB, uid types.UserID) (*types.User, error) {
|
||||
@@ -146,9 +146,7 @@ func GetUserByOIDCIdentifier(tx *gorm.DB, id string) (*types.User, error) {
|
||||
}
|
||||
|
||||
func (hsdb *HSDatabase) ListUsers(where ...*types.User) ([]types.User, error) {
|
||||
return Read(hsdb.DB, func(rx *gorm.DB) ([]types.User, error) {
|
||||
return ListUsers(rx, where...)
|
||||
})
|
||||
return ListUsers(hsdb.DB, where...)
|
||||
}
|
||||
|
||||
// ListUsers gets all the existing users.
|
||||
@@ -217,3 +215,40 @@ func AssignNodeToUser(tx *gorm.DB, nodeID types.NodeID, uid types.UserID) error
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (hsdb *HSDatabase) CreateUserForTest(name ...string) *types.User {
|
||||
if !testing.Testing() {
|
||||
panic("CreateUserForTest can only be called during tests")
|
||||
}
|
||||
|
||||
userName := "testuser"
|
||||
if len(name) > 0 && name[0] != "" {
|
||||
userName = name[0]
|
||||
}
|
||||
|
||||
user, err := hsdb.CreateUser(types.User{Name: userName})
|
||||
if err != nil {
|
||||
panic(fmt.Sprintf("failed to create test user: %v", err))
|
||||
}
|
||||
|
||||
return user
|
||||
}
|
||||
|
||||
func (hsdb *HSDatabase) CreateUsersForTest(count int, namePrefix ...string) []*types.User {
|
||||
if !testing.Testing() {
|
||||
panic("CreateUsersForTest can only be called during tests")
|
||||
}
|
||||
|
||||
prefix := "testuser"
|
||||
if len(namePrefix) > 0 && namePrefix[0] != "" {
|
||||
prefix = namePrefix[0]
|
||||
}
|
||||
|
||||
users := make([]*types.User, count)
|
||||
for i := range count {
|
||||
name := prefix + "-" + strconv.Itoa(i)
|
||||
users[i] = hsdb.CreateUserForTest(name)
|
||||
}
|
||||
|
||||
return users
|
||||
}
|
||||
|
@@ -11,8 +11,7 @@ import (
|
||||
)
|
||||
|
||||
func (s *Suite) TestCreateAndDestroyUser(c *check.C) {
|
||||
user, err := db.CreateUser(types.User{Name: "test"})
|
||||
c.Assert(err, check.IsNil)
|
||||
user := db.CreateUserForTest("test")
|
||||
c.Assert(user.Name, check.Equals, "test")
|
||||
|
||||
users, err := db.ListUsers()
|
||||
@@ -30,8 +29,7 @@ func (s *Suite) TestDestroyUserErrors(c *check.C) {
|
||||
err := db.DestroyUser(9998)
|
||||
c.Assert(err, check.Equals, ErrUserNotFound)
|
||||
|
||||
user, err := db.CreateUser(types.User{Name: "test"})
|
||||
c.Assert(err, check.IsNil)
|
||||
user := db.CreateUserForTest("test")
|
||||
|
||||
pak, err := db.CreatePreAuthKey(types.UserID(user.ID), false, false, nil, nil)
|
||||
c.Assert(err, check.IsNil)
|
||||
@@ -64,8 +62,7 @@ func (s *Suite) TestDestroyUserErrors(c *check.C) {
|
||||
}
|
||||
|
||||
func (s *Suite) TestRenameUser(c *check.C) {
|
||||
userTest, err := db.CreateUser(types.User{Name: "test"})
|
||||
c.Assert(err, check.IsNil)
|
||||
userTest := db.CreateUserForTest("test")
|
||||
c.Assert(userTest.Name, check.Equals, "test")
|
||||
|
||||
users, err := db.ListUsers()
|
||||
@@ -86,8 +83,7 @@ func (s *Suite) TestRenameUser(c *check.C) {
|
||||
err = db.RenameUser(99988, "test")
|
||||
c.Assert(err, check.Equals, ErrUserNotFound)
|
||||
|
||||
userTest2, err := db.CreateUser(types.User{Name: "test2"})
|
||||
c.Assert(err, check.IsNil)
|
||||
userTest2 := db.CreateUserForTest("test2")
|
||||
c.Assert(userTest2.Name, check.Equals, "test2")
|
||||
|
||||
want := "UNIQUE constraint failed"
|
||||
@@ -98,11 +94,8 @@ func (s *Suite) TestRenameUser(c *check.C) {
|
||||
}
|
||||
|
||||
func (s *Suite) TestSetMachineUser(c *check.C) {
|
||||
oldUser, err := db.CreateUser(types.User{Name: "old"})
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
newUser, err := db.CreateUser(types.User{Name: "new"})
|
||||
c.Assert(err, check.IsNil)
|
||||
oldUser := db.CreateUserForTest("old")
|
||||
newUser := db.CreateUserForTest("new")
|
||||
|
||||
pak, err := db.CreatePreAuthKey(types.UserID(oldUser.ID), false, false, nil, nil)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
@@ -17,10 +17,6 @@ import (
|
||||
func (h *Headscale) debugHTTPServer() *http.Server {
|
||||
debugMux := http.NewServeMux()
|
||||
debug := tsweb.Debugger(debugMux)
|
||||
debug.Handle("notifier", "Connected nodes in notifier", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte(h.nodeNotifier.String()))
|
||||
}))
|
||||
debug.Handle("config", "Current configuration", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
config, err := json.MarshalIndent(h.cfg, "", " ")
|
||||
if err != nil {
|
||||
|
@@ -4,6 +4,7 @@ import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"maps"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
@@ -72,9 +73,7 @@ func mergeDERPMaps(derpMaps []*tailcfg.DERPMap) *tailcfg.DERPMap {
|
||||
}
|
||||
|
||||
for _, derpMap := range derpMaps {
|
||||
for id, region := range derpMap.Regions {
|
||||
result.Regions[id] = region
|
||||
}
|
||||
maps.Copy(result.Regions, derpMap.Regions)
|
||||
}
|
||||
|
||||
return &result
|
||||
|
@@ -1,3 +1,5 @@
|
||||
//go:generate buf generate --template ../buf.gen.yaml -o .. ../proto
|
||||
|
||||
// nolint
|
||||
package hscontrol
|
||||
|
||||
@@ -27,6 +29,7 @@ import (
|
||||
v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
|
||||
"github.com/juanfont/headscale/hscontrol/state"
|
||||
"github.com/juanfont/headscale/hscontrol/types"
|
||||
"github.com/juanfont/headscale/hscontrol/types/change"
|
||||
"github.com/juanfont/headscale/hscontrol/util"
|
||||
)
|
||||
|
||||
@@ -56,12 +59,14 @@ func (api headscaleV1APIServer) CreateUser(
|
||||
return nil, status.Errorf(codes.Internal, "failed to create user: %s", err)
|
||||
}
|
||||
|
||||
// Send policy update notifications if needed
|
||||
|
||||
c := change.UserAdded(types.UserID(user.ID))
|
||||
if policyChanged {
|
||||
ctx := types.NotifyCtx(context.Background(), "grpc-user-created", user.Name)
|
||||
api.h.nodeNotifier.NotifyAll(ctx, types.UpdateFull())
|
||||
c.Change = change.Policy
|
||||
}
|
||||
|
||||
api.h.Change(c)
|
||||
|
||||
return &v1.CreateUserResponse{User: user.Proto()}, nil
|
||||
}
|
||||
|
||||
@@ -81,8 +86,7 @@ func (api headscaleV1APIServer) RenameUser(
|
||||
|
||||
// Send policy update notifications if needed
|
||||
if policyChanged {
|
||||
ctx := types.NotifyCtx(context.Background(), "grpc-user-renamed", request.GetNewName())
|
||||
api.h.nodeNotifier.NotifyAll(ctx, types.UpdateFull())
|
||||
api.h.Change(change.PolicyChange())
|
||||
}
|
||||
|
||||
newUser, err := api.h.state.GetUserByName(request.GetNewName())
|
||||
@@ -107,6 +111,8 @@ func (api headscaleV1APIServer) DeleteUser(
|
||||
return nil, err
|
||||
}
|
||||
|
||||
api.h.Change(change.UserRemoved(types.UserID(user.ID)))
|
||||
|
||||
return &v1.DeleteUserResponse{}, nil
|
||||
}
|
||||
|
||||
@@ -246,7 +252,7 @@ func (api headscaleV1APIServer) RegisterNode(
|
||||
return nil, fmt.Errorf("looking up user: %w", err)
|
||||
}
|
||||
|
||||
node, _, err := api.h.state.HandleNodeFromAuthPath(
|
||||
node, nodeChange, err := api.h.state.HandleNodeFromAuthPath(
|
||||
registrationId,
|
||||
types.UserID(user.ID),
|
||||
nil,
|
||||
@@ -267,22 +273,13 @@ func (api headscaleV1APIServer) RegisterNode(
|
||||
// ensure we send an update.
|
||||
// This works, but might be another good candidate for doing some sort of
|
||||
// eventbus.
|
||||
routesChanged := api.h.state.AutoApproveRoutes(node)
|
||||
_, policyChanged, err := api.h.state.SaveNode(node)
|
||||
_ = api.h.state.AutoApproveRoutes(node)
|
||||
_, _, err = api.h.state.SaveNode(node)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("saving auto approved routes to node: %w", err)
|
||||
}
|
||||
|
||||
// Send policy update notifications if needed (from SaveNode or route changes)
|
||||
if policyChanged {
|
||||
ctx := types.NotifyCtx(context.Background(), "grpc-nodes-change", "all")
|
||||
api.h.nodeNotifier.NotifyAll(ctx, types.UpdateFull())
|
||||
}
|
||||
|
||||
if routesChanged {
|
||||
ctx = types.NotifyCtx(context.Background(), "web-node-login", node.Hostname)
|
||||
api.h.nodeNotifier.NotifyAll(ctx, types.UpdatePeerChanged(node.ID))
|
||||
}
|
||||
api.h.Change(nodeChange)
|
||||
|
||||
return &v1.RegisterNodeResponse{Node: node.Proto()}, nil
|
||||
}
|
||||
@@ -300,7 +297,7 @@ func (api headscaleV1APIServer) GetNode(
|
||||
|
||||
// Populate the online field based on
|
||||
// currently connected nodes.
|
||||
resp.Online = api.h.nodeNotifier.IsConnected(node.ID)
|
||||
resp.Online = api.h.mapBatcher.IsConnected(node.ID)
|
||||
|
||||
return &v1.GetNodeResponse{Node: resp}, nil
|
||||
}
|
||||
@@ -316,21 +313,14 @@ func (api headscaleV1APIServer) SetTags(
|
||||
}
|
||||
}
|
||||
|
||||
node, policyChanged, err := api.h.state.SetNodeTags(types.NodeID(request.GetNodeId()), request.GetTags())
|
||||
node, nodeChange, err := api.h.state.SetNodeTags(types.NodeID(request.GetNodeId()), request.GetTags())
|
||||
if err != nil {
|
||||
return &v1.SetTagsResponse{
|
||||
Node: nil,
|
||||
}, status.Error(codes.InvalidArgument, err.Error())
|
||||
}
|
||||
|
||||
// Send policy update notifications if needed
|
||||
if policyChanged {
|
||||
ctx := types.NotifyCtx(context.Background(), "grpc-node-tags", node.Hostname)
|
||||
api.h.nodeNotifier.NotifyAll(ctx, types.UpdateFull())
|
||||
}
|
||||
|
||||
ctx = types.NotifyCtx(ctx, "cli-settags", node.Hostname)
|
||||
api.h.nodeNotifier.NotifyWithIgnore(ctx, types.UpdatePeerChanged(node.ID), node.ID)
|
||||
api.h.Change(nodeChange)
|
||||
|
||||
log.Trace().
|
||||
Str("node", node.Hostname).
|
||||
@@ -362,23 +352,19 @@ func (api headscaleV1APIServer) SetApprovedRoutes(
|
||||
tsaddr.SortPrefixes(routes)
|
||||
routes = slices.Compact(routes)
|
||||
|
||||
node, policyChanged, err := api.h.state.SetApprovedRoutes(types.NodeID(request.GetNodeId()), routes)
|
||||
node, nodeChange, err := api.h.state.SetApprovedRoutes(types.NodeID(request.GetNodeId()), routes)
|
||||
if err != nil {
|
||||
return nil, status.Error(codes.InvalidArgument, err.Error())
|
||||
}
|
||||
|
||||
// Send policy update notifications if needed
|
||||
if policyChanged {
|
||||
ctx := types.NotifyCtx(context.Background(), "grpc-routes-approved", node.Hostname)
|
||||
api.h.nodeNotifier.NotifyAll(ctx, types.UpdateFull())
|
||||
}
|
||||
routeChange := api.h.state.SetNodeRoutes(node.ID, node.SubnetRoutes()...)
|
||||
|
||||
if api.h.state.SetNodeRoutes(node.ID, node.SubnetRoutes()...) {
|
||||
ctx := types.NotifyCtx(ctx, "poll-primary-change", node.Hostname)
|
||||
api.h.nodeNotifier.NotifyAll(ctx, types.UpdateFull())
|
||||
} else {
|
||||
ctx = types.NotifyCtx(ctx, "cli-approveroutes", node.Hostname)
|
||||
api.h.nodeNotifier.NotifyWithIgnore(ctx, types.UpdatePeerChanged(node.ID), node.ID)
|
||||
// Always propagate node changes from SetApprovedRoutes
|
||||
api.h.Change(nodeChange)
|
||||
|
||||
// If routes changed, propagate those changes too
|
||||
if !routeChange.Empty() {
|
||||
api.h.Change(routeChange)
|
||||
}
|
||||
|
||||
proto := node.Proto()
|
||||
@@ -409,19 +395,12 @@ func (api headscaleV1APIServer) DeleteNode(
|
||||
return nil, err
|
||||
}
|
||||
|
||||
policyChanged, err := api.h.state.DeleteNode(node)
|
||||
nodeChange, err := api.h.state.DeleteNode(node)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Send policy update notifications if needed
|
||||
if policyChanged {
|
||||
ctx := types.NotifyCtx(context.Background(), "grpc-node-deleted", node.Hostname)
|
||||
api.h.nodeNotifier.NotifyAll(ctx, types.UpdateFull())
|
||||
}
|
||||
|
||||
ctx = types.NotifyCtx(ctx, "cli-deletenode", node.Hostname)
|
||||
api.h.nodeNotifier.NotifyAll(ctx, types.UpdatePeerRemoved(node.ID))
|
||||
api.h.Change(nodeChange)
|
||||
|
||||
return &v1.DeleteNodeResponse{}, nil
|
||||
}
|
||||
@@ -432,25 +411,13 @@ func (api headscaleV1APIServer) ExpireNode(
|
||||
) (*v1.ExpireNodeResponse, error) {
|
||||
now := time.Now()
|
||||
|
||||
node, policyChanged, err := api.h.state.SetNodeExpiry(types.NodeID(request.GetNodeId()), now)
|
||||
node, nodeChange, err := api.h.state.SetNodeExpiry(types.NodeID(request.GetNodeId()), now)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Send policy update notifications if needed
|
||||
if policyChanged {
|
||||
ctx := types.NotifyCtx(context.Background(), "grpc-node-expired", node.Hostname)
|
||||
api.h.nodeNotifier.NotifyAll(ctx, types.UpdateFull())
|
||||
}
|
||||
|
||||
ctx = types.NotifyCtx(ctx, "cli-expirenode-self", node.Hostname)
|
||||
api.h.nodeNotifier.NotifyByNodeID(
|
||||
ctx,
|
||||
types.UpdateSelf(node.ID),
|
||||
node.ID)
|
||||
|
||||
ctx = types.NotifyCtx(ctx, "cli-expirenode-peers", node.Hostname)
|
||||
api.h.nodeNotifier.NotifyWithIgnore(ctx, types.UpdateExpire(node.ID, now), node.ID)
|
||||
// TODO(kradalby): Ensure that both the selfupdate and peer updates are sent
|
||||
api.h.Change(nodeChange)
|
||||
|
||||
log.Trace().
|
||||
Str("node", node.Hostname).
|
||||
@@ -464,22 +431,13 @@ func (api headscaleV1APIServer) RenameNode(
|
||||
ctx context.Context,
|
||||
request *v1.RenameNodeRequest,
|
||||
) (*v1.RenameNodeResponse, error) {
|
||||
node, policyChanged, err := api.h.state.RenameNode(types.NodeID(request.GetNodeId()), request.GetNewName())
|
||||
node, nodeChange, err := api.h.state.RenameNode(types.NodeID(request.GetNodeId()), request.GetNewName())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Send policy update notifications if needed
|
||||
if policyChanged {
|
||||
ctx := types.NotifyCtx(context.Background(), "grpc-node-renamed", node.Hostname)
|
||||
api.h.nodeNotifier.NotifyAll(ctx, types.UpdateFull())
|
||||
}
|
||||
|
||||
ctx = types.NotifyCtx(ctx, "cli-renamenode-self", node.Hostname)
|
||||
api.h.nodeNotifier.NotifyByNodeID(ctx, types.UpdateSelf(node.ID), node.ID)
|
||||
|
||||
ctx = types.NotifyCtx(ctx, "cli-renamenode-peers", node.Hostname)
|
||||
api.h.nodeNotifier.NotifyWithIgnore(ctx, types.UpdatePeerChanged(node.ID), node.ID)
|
||||
// TODO(kradalby): investigate if we need selfupdate
|
||||
api.h.Change(nodeChange)
|
||||
|
||||
log.Trace().
|
||||
Str("node", node.Hostname).
|
||||
@@ -498,7 +456,7 @@ func (api headscaleV1APIServer) ListNodes(
|
||||
// probably be done once.
|
||||
// TODO(kradalby): This should be done in one tx.
|
||||
|
||||
isLikelyConnected := api.h.nodeNotifier.LikelyConnectedMap()
|
||||
IsConnected := api.h.mapBatcher.ConnectedMap()
|
||||
if request.GetUser() != "" {
|
||||
user, err := api.h.state.GetUserByName(request.GetUser())
|
||||
if err != nil {
|
||||
@@ -510,7 +468,7 @@ func (api headscaleV1APIServer) ListNodes(
|
||||
return nil, err
|
||||
}
|
||||
|
||||
response := nodesToProto(api.h.state, isLikelyConnected, nodes)
|
||||
response := nodesToProto(api.h.state, IsConnected, nodes)
|
||||
return &v1.ListNodesResponse{Nodes: response}, nil
|
||||
}
|
||||
|
||||
@@ -523,18 +481,18 @@ func (api headscaleV1APIServer) ListNodes(
|
||||
return nodes[i].ID < nodes[j].ID
|
||||
})
|
||||
|
||||
response := nodesToProto(api.h.state, isLikelyConnected, nodes)
|
||||
response := nodesToProto(api.h.state, IsConnected, nodes)
|
||||
return &v1.ListNodesResponse{Nodes: response}, nil
|
||||
}
|
||||
|
||||
func nodesToProto(state *state.State, isLikelyConnected *xsync.MapOf[types.NodeID, bool], nodes types.Nodes) []*v1.Node {
|
||||
func nodesToProto(state *state.State, IsConnected *xsync.MapOf[types.NodeID, bool], nodes types.Nodes) []*v1.Node {
|
||||
response := make([]*v1.Node, len(nodes))
|
||||
for index, node := range nodes {
|
||||
resp := node.Proto()
|
||||
|
||||
// Populate the online field based on
|
||||
// currently connected nodes.
|
||||
if val, ok := isLikelyConnected.Load(node.ID); ok && val {
|
||||
if val, ok := IsConnected.Load(node.ID); ok && val {
|
||||
resp.Online = true
|
||||
}
|
||||
|
||||
@@ -556,24 +514,14 @@ func (api headscaleV1APIServer) MoveNode(
|
||||
ctx context.Context,
|
||||
request *v1.MoveNodeRequest,
|
||||
) (*v1.MoveNodeResponse, error) {
|
||||
node, policyChanged, err := api.h.state.AssignNodeToUser(types.NodeID(request.GetNodeId()), types.UserID(request.GetUser()))
|
||||
node, nodeChange, err := api.h.state.AssignNodeToUser(types.NodeID(request.GetNodeId()), types.UserID(request.GetUser()))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Send policy update notifications if needed
|
||||
if policyChanged {
|
||||
ctx := types.NotifyCtx(context.Background(), "grpc-node-moved", node.Hostname)
|
||||
api.h.nodeNotifier.NotifyAll(ctx, types.UpdateFull())
|
||||
}
|
||||
|
||||
ctx = types.NotifyCtx(ctx, "cli-movenode-self", node.Hostname)
|
||||
api.h.nodeNotifier.NotifyByNodeID(
|
||||
ctx,
|
||||
types.UpdateSelf(node.ID),
|
||||
node.ID)
|
||||
ctx = types.NotifyCtx(ctx, "cli-movenode", node.Hostname)
|
||||
api.h.nodeNotifier.NotifyWithIgnore(ctx, types.UpdatePeerChanged(node.ID), node.ID)
|
||||
// TODO(kradalby): Ensure the policy is also sent
|
||||
// TODO(kradalby): ensure that both the selfupdate and peer updates are sent
|
||||
api.h.Change(nodeChange)
|
||||
|
||||
return &v1.MoveNodeResponse{Node: node.Proto()}, nil
|
||||
}
|
||||
@@ -754,8 +702,7 @@ func (api headscaleV1APIServer) SetPolicy(
|
||||
return nil, err
|
||||
}
|
||||
|
||||
ctx := types.NotifyCtx(context.Background(), "acl-update", "na")
|
||||
api.h.nodeNotifier.NotifyAll(ctx, types.UpdateFull())
|
||||
api.h.Change(change.PolicyChange())
|
||||
}
|
||||
|
||||
response := &v1.SetPolicyResponse{
|
||||
|
155
hscontrol/mapper/batcher.go
Normal file
155
hscontrol/mapper/batcher.go
Normal file
@@ -0,0 +1,155 @@
|
||||
package mapper
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/juanfont/headscale/hscontrol/state"
|
||||
"github.com/juanfont/headscale/hscontrol/types"
|
||||
"github.com/juanfont/headscale/hscontrol/types/change"
|
||||
"github.com/puzpuzpuz/xsync/v4"
|
||||
"tailscale.com/tailcfg"
|
||||
"tailscale.com/types/ptr"
|
||||
)
|
||||
|
||||
type batcherFunc func(cfg *types.Config, state *state.State) Batcher
|
||||
|
||||
// Batcher defines the common interface for all batcher implementations.
|
||||
type Batcher interface {
|
||||
Start()
|
||||
Close()
|
||||
AddNode(id types.NodeID, c chan<- *tailcfg.MapResponse, isRouter bool, version tailcfg.CapabilityVersion) error
|
||||
RemoveNode(id types.NodeID, c chan<- *tailcfg.MapResponse, isRouter bool)
|
||||
IsConnected(id types.NodeID) bool
|
||||
ConnectedMap() *xsync.Map[types.NodeID, bool]
|
||||
AddWork(c change.ChangeSet)
|
||||
MapResponseFromChange(id types.NodeID, c change.ChangeSet) (*tailcfg.MapResponse, error)
|
||||
}
|
||||
|
||||
func NewBatcher(batchTime time.Duration, workers int, mapper *mapper) *LockFreeBatcher {
|
||||
return &LockFreeBatcher{
|
||||
mapper: mapper,
|
||||
workers: workers,
|
||||
tick: time.NewTicker(batchTime),
|
||||
|
||||
// The size of this channel is arbitrary chosen, the sizing should be revisited.
|
||||
workCh: make(chan work, workers*200),
|
||||
nodes: xsync.NewMap[types.NodeID, *nodeConn](),
|
||||
connected: xsync.NewMap[types.NodeID, *time.Time](),
|
||||
pendingChanges: xsync.NewMap[types.NodeID, []change.ChangeSet](),
|
||||
}
|
||||
}
|
||||
|
||||
// NewBatcherAndMapper creates a Batcher implementation.
|
||||
func NewBatcherAndMapper(cfg *types.Config, state *state.State) Batcher {
|
||||
m := newMapper(cfg, state)
|
||||
b := NewBatcher(cfg.Tuning.BatchChangeDelay, cfg.Tuning.BatcherWorkers, m)
|
||||
m.batcher = b
|
||||
return b
|
||||
}
|
||||
|
||||
// nodeConnection interface for different connection implementations.
|
||||
type nodeConnection interface {
|
||||
nodeID() types.NodeID
|
||||
version() tailcfg.CapabilityVersion
|
||||
send(data *tailcfg.MapResponse) error
|
||||
}
|
||||
|
||||
// generateMapResponse generates a [tailcfg.MapResponse] for the given NodeID that is based on the provided [change.ChangeSet].
|
||||
func generateMapResponse(nodeID types.NodeID, version tailcfg.CapabilityVersion, mapper *mapper, c change.ChangeSet) (*tailcfg.MapResponse, error) {
|
||||
if c.Empty() {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// Validate inputs before processing
|
||||
if nodeID == 0 {
|
||||
return nil, fmt.Errorf("invalid nodeID: %d", nodeID)
|
||||
}
|
||||
|
||||
if mapper == nil {
|
||||
return nil, fmt.Errorf("mapper is nil for nodeID %d", nodeID)
|
||||
}
|
||||
|
||||
var mapResp *tailcfg.MapResponse
|
||||
var err error
|
||||
|
||||
switch c.Change {
|
||||
case change.DERP:
|
||||
mapResp, err = mapper.derpMapResponse(nodeID)
|
||||
|
||||
case change.NodeCameOnline, change.NodeWentOffline:
|
||||
if c.IsSubnetRouter {
|
||||
// TODO(kradalby): This can potentially be a peer update of the old and new subnet router.
|
||||
mapResp, err = mapper.fullMapResponse(nodeID, version)
|
||||
} else {
|
||||
mapResp, err = mapper.peerChangedPatchResponse(nodeID, []*tailcfg.PeerChange{
|
||||
{
|
||||
NodeID: c.NodeID.NodeID(),
|
||||
Online: ptr.To(c.Change == change.NodeCameOnline),
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
case change.NodeNewOrUpdate:
|
||||
mapResp, err = mapper.fullMapResponse(nodeID, version)
|
||||
|
||||
case change.NodeRemove:
|
||||
mapResp, err = mapper.peerRemovedResponse(nodeID, c.NodeID)
|
||||
|
||||
default:
|
||||
// The following will always hit this:
|
||||
// change.Full, change.Policy
|
||||
mapResp, err = mapper.fullMapResponse(nodeID, version)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("generating map response for nodeID %d: %w", nodeID, err)
|
||||
}
|
||||
|
||||
// TODO(kradalby): Is this necessary?
|
||||
// Validate the generated map response - only check for nil response
|
||||
// Note: mapResp.Node can be nil for peer updates, which is valid
|
||||
if mapResp == nil && c.Change != change.DERP && c.Change != change.NodeRemove {
|
||||
return nil, fmt.Errorf("generated nil map response for nodeID %d change %s", nodeID, c.Change.String())
|
||||
}
|
||||
|
||||
return mapResp, nil
|
||||
}
|
||||
|
||||
// handleNodeChange generates and sends a [tailcfg.MapResponse] for a given node and [change.ChangeSet].
|
||||
func handleNodeChange(nc nodeConnection, mapper *mapper, c change.ChangeSet) error {
|
||||
if nc == nil {
|
||||
return fmt.Errorf("nodeConnection is nil")
|
||||
}
|
||||
|
||||
nodeID := nc.nodeID()
|
||||
data, err := generateMapResponse(nodeID, nc.version(), mapper, c)
|
||||
if err != nil {
|
||||
return fmt.Errorf("generating map response for node %d: %w", nodeID, err)
|
||||
}
|
||||
|
||||
if data == nil {
|
||||
// No data to send is valid for some change types
|
||||
return nil
|
||||
}
|
||||
|
||||
// Send the map response
|
||||
if err := nc.send(data); err != nil {
|
||||
return fmt.Errorf("sending map response to node %d: %w", nodeID, err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// workResult represents the result of processing a change.
|
||||
type workResult struct {
|
||||
mapResponse *tailcfg.MapResponse
|
||||
err error
|
||||
}
|
||||
|
||||
// work represents a unit of work to be processed by workers.
|
||||
type work struct {
|
||||
c change.ChangeSet
|
||||
nodeID types.NodeID
|
||||
resultCh chan<- workResult // optional channel for synchronous operations
|
||||
}
|
491
hscontrol/mapper/batcher_lockfree.go
Normal file
491
hscontrol/mapper/batcher_lockfree.go
Normal file
@@ -0,0 +1,491 @@
|
||||
package mapper
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/juanfont/headscale/hscontrol/types"
|
||||
"github.com/juanfont/headscale/hscontrol/types/change"
|
||||
"github.com/puzpuzpuz/xsync/v4"
|
||||
"github.com/rs/zerolog/log"
|
||||
"tailscale.com/tailcfg"
|
||||
"tailscale.com/types/ptr"
|
||||
)
|
||||
|
||||
// LockFreeBatcher uses atomic operations and concurrent maps to eliminate mutex contention.
|
||||
type LockFreeBatcher struct {
|
||||
tick *time.Ticker
|
||||
mapper *mapper
|
||||
workers int
|
||||
|
||||
// Lock-free concurrent maps
|
||||
nodes *xsync.Map[types.NodeID, *nodeConn]
|
||||
connected *xsync.Map[types.NodeID, *time.Time]
|
||||
|
||||
// Work queue channel
|
||||
workCh chan work
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
|
||||
// Batching state
|
||||
pendingChanges *xsync.Map[types.NodeID, []change.ChangeSet]
|
||||
batchMutex sync.RWMutex
|
||||
|
||||
// Metrics
|
||||
totalNodes atomic.Int64
|
||||
totalUpdates atomic.Int64
|
||||
workQueuedCount atomic.Int64
|
||||
workProcessed atomic.Int64
|
||||
workErrors atomic.Int64
|
||||
}
|
||||
|
||||
// AddNode registers a new node connection with the batcher and sends an initial map response.
|
||||
// It creates or updates the node's connection data, validates the initial map generation,
|
||||
// and notifies other nodes that this node has come online.
|
||||
// TODO(kradalby): See if we can move the isRouter argument somewhere else.
|
||||
func (b *LockFreeBatcher) AddNode(id types.NodeID, c chan<- *tailcfg.MapResponse, isRouter bool, version tailcfg.CapabilityVersion) error {
|
||||
// First validate that we can generate initial map before doing anything else
|
||||
fullSelfChange := change.FullSelf(id)
|
||||
|
||||
// TODO(kradalby): This should not be generated here, but rather in MapResponseFromChange.
|
||||
// This currently means that the goroutine for the node connection will do the processing
|
||||
// which means that we might have uncontrolled concurrency.
|
||||
// When we use MapResponseFromChange, it will be processed by the same worker pool, causing
|
||||
// it to be processed in a more controlled manner.
|
||||
initialMap, err := generateMapResponse(id, version, b.mapper, fullSelfChange)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to generate initial map for node %d: %w", id, err)
|
||||
}
|
||||
|
||||
// Only after validation succeeds, create or update node connection
|
||||
newConn := newNodeConn(id, c, version, b.mapper)
|
||||
|
||||
var conn *nodeConn
|
||||
if existing, loaded := b.nodes.LoadOrStore(id, newConn); loaded {
|
||||
// Update existing connection
|
||||
existing.updateConnection(c, version)
|
||||
conn = existing
|
||||
} else {
|
||||
b.totalNodes.Add(1)
|
||||
conn = newConn
|
||||
}
|
||||
|
||||
// Mark as connected only after validation succeeds
|
||||
b.connected.Store(id, nil) // nil = connected
|
||||
|
||||
log.Info().Uint64("node.id", id.Uint64()).Bool("isRouter", isRouter).Msg("Node connected to batcher")
|
||||
|
||||
// Send the validated initial map
|
||||
if initialMap != nil {
|
||||
if err := conn.send(initialMap); err != nil {
|
||||
// Clean up the connection state on send failure
|
||||
b.nodes.Delete(id)
|
||||
b.connected.Delete(id)
|
||||
return fmt.Errorf("failed to send initial map to node %d: %w", id, err)
|
||||
}
|
||||
|
||||
// Notify other nodes that this node came online
|
||||
b.addWork(change.ChangeSet{NodeID: id, Change: change.NodeCameOnline, IsSubnetRouter: isRouter})
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// RemoveNode disconnects a node from the batcher, marking it as offline and cleaning up its state.
|
||||
// It validates the connection channel matches the current one, closes the connection,
|
||||
// and notifies other nodes that this node has gone offline.
|
||||
func (b *LockFreeBatcher) RemoveNode(id types.NodeID, c chan<- *tailcfg.MapResponse, isRouter bool) {
|
||||
// Check if this is the current connection and mark it as closed
|
||||
if existing, ok := b.nodes.Load(id); ok {
|
||||
if !existing.matchesChannel(c) {
|
||||
log.Debug().Uint64("node.id", id.Uint64()).Msg("RemoveNode called for non-current connection, ignoring")
|
||||
return // Not the current connection, not an error
|
||||
}
|
||||
|
||||
// Mark the connection as closed to prevent further sends
|
||||
if connData := existing.connData.Load(); connData != nil {
|
||||
connData.closed.Store(true)
|
||||
}
|
||||
}
|
||||
|
||||
log.Info().Uint64("node.id", id.Uint64()).Bool("isRouter", isRouter).Msg("Node disconnected from batcher, marking as offline")
|
||||
|
||||
// Remove node and mark disconnected atomically
|
||||
b.nodes.Delete(id)
|
||||
b.connected.Store(id, ptr.To(time.Now()))
|
||||
b.totalNodes.Add(-1)
|
||||
|
||||
// Notify other nodes that this node went offline
|
||||
b.addWork(change.ChangeSet{NodeID: id, Change: change.NodeWentOffline, IsSubnetRouter: isRouter})
|
||||
}
|
||||
|
||||
// AddWork queues a change to be processed by the batcher.
|
||||
// Critical changes are processed immediately, while others are batched for efficiency.
|
||||
func (b *LockFreeBatcher) AddWork(c change.ChangeSet) {
|
||||
b.addWork(c)
|
||||
}
|
||||
|
||||
func (b *LockFreeBatcher) Start() {
|
||||
b.ctx, b.cancel = context.WithCancel(context.Background())
|
||||
go b.doWork()
|
||||
}
|
||||
|
||||
func (b *LockFreeBatcher) Close() {
|
||||
if b.cancel != nil {
|
||||
b.cancel()
|
||||
}
|
||||
close(b.workCh)
|
||||
}
|
||||
|
||||
func (b *LockFreeBatcher) doWork() {
|
||||
log.Debug().Msg("batcher doWork loop started")
|
||||
defer log.Debug().Msg("batcher doWork loop stopped")
|
||||
|
||||
for i := range b.workers {
|
||||
go b.worker(i + 1)
|
||||
}
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-b.tick.C:
|
||||
// Process batched changes
|
||||
b.processBatchedChanges()
|
||||
case <-b.ctx.Done():
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (b *LockFreeBatcher) worker(workerID int) {
|
||||
log.Debug().Int("workerID", workerID).Msg("batcher worker started")
|
||||
defer log.Debug().Int("workerID", workerID).Msg("batcher worker stopped")
|
||||
|
||||
for {
|
||||
select {
|
||||
case w, ok := <-b.workCh:
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
startTime := time.Now()
|
||||
b.workProcessed.Add(1)
|
||||
|
||||
// If the resultCh is set, it means that this is a work request
|
||||
// where there is a blocking function waiting for the map that
|
||||
// is being generated.
|
||||
// This is used for synchronous map generation.
|
||||
if w.resultCh != nil {
|
||||
var result workResult
|
||||
if nc, exists := b.nodes.Load(w.nodeID); exists {
|
||||
result.mapResponse, result.err = generateMapResponse(nc.nodeID(), nc.version(), b.mapper, w.c)
|
||||
if result.err != nil {
|
||||
b.workErrors.Add(1)
|
||||
log.Error().Err(result.err).
|
||||
Int("workerID", workerID).
|
||||
Uint64("node.id", w.nodeID.Uint64()).
|
||||
Str("change", w.c.Change.String()).
|
||||
Msg("failed to generate map response for synchronous work")
|
||||
}
|
||||
} else {
|
||||
result.err = fmt.Errorf("node %d not found", w.nodeID)
|
||||
b.workErrors.Add(1)
|
||||
log.Error().Err(result.err).
|
||||
Int("workerID", workerID).
|
||||
Uint64("node.id", w.nodeID.Uint64()).
|
||||
Msg("node not found for synchronous work")
|
||||
}
|
||||
|
||||
// Send result
|
||||
select {
|
||||
case w.resultCh <- result:
|
||||
case <-b.ctx.Done():
|
||||
return
|
||||
}
|
||||
|
||||
duration := time.Since(startTime)
|
||||
if duration > 100*time.Millisecond {
|
||||
log.Warn().
|
||||
Int("workerID", workerID).
|
||||
Uint64("node.id", w.nodeID.Uint64()).
|
||||
Str("change", w.c.Change.String()).
|
||||
Dur("duration", duration).
|
||||
Msg("slow synchronous work processing")
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
// If resultCh is nil, this is an asynchronous work request
|
||||
// that should be processed and sent to the node instead of
|
||||
// returned to the caller.
|
||||
if nc, exists := b.nodes.Load(w.nodeID); exists {
|
||||
// Check if this connection is still active before processing
|
||||
if connData := nc.connData.Load(); connData != nil && connData.closed.Load() {
|
||||
log.Debug().
|
||||
Int("workerID", workerID).
|
||||
Uint64("node.id", w.nodeID.Uint64()).
|
||||
Str("change", w.c.Change.String()).
|
||||
Msg("skipping work for closed connection")
|
||||
continue
|
||||
}
|
||||
|
||||
err := nc.change(w.c)
|
||||
if err != nil {
|
||||
b.workErrors.Add(1)
|
||||
log.Error().Err(err).
|
||||
Int("workerID", workerID).
|
||||
Uint64("node.id", w.c.NodeID.Uint64()).
|
||||
Str("change", w.c.Change.String()).
|
||||
Msg("failed to apply change")
|
||||
}
|
||||
} else {
|
||||
log.Debug().
|
||||
Int("workerID", workerID).
|
||||
Uint64("node.id", w.nodeID.Uint64()).
|
||||
Str("change", w.c.Change.String()).
|
||||
Msg("node not found for asynchronous work - node may have disconnected")
|
||||
}
|
||||
|
||||
duration := time.Since(startTime)
|
||||
if duration > 100*time.Millisecond {
|
||||
log.Warn().
|
||||
Int("workerID", workerID).
|
||||
Uint64("node.id", w.nodeID.Uint64()).
|
||||
Str("change", w.c.Change.String()).
|
||||
Dur("duration", duration).
|
||||
Msg("slow asynchronous work processing")
|
||||
}
|
||||
|
||||
case <-b.ctx.Done():
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (b *LockFreeBatcher) addWork(c change.ChangeSet) {
|
||||
// For critical changes that need immediate processing, send directly
|
||||
if b.shouldProcessImmediately(c) {
|
||||
if c.SelfUpdateOnly {
|
||||
b.queueWork(work{c: c, nodeID: c.NodeID, resultCh: nil})
|
||||
return
|
||||
}
|
||||
b.nodes.Range(func(nodeID types.NodeID, _ *nodeConn) bool {
|
||||
if c.NodeID == nodeID && !c.AlsoSelf() {
|
||||
return true
|
||||
}
|
||||
b.queueWork(work{c: c, nodeID: nodeID, resultCh: nil})
|
||||
return true
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// For non-critical changes, add to batch
|
||||
b.addToBatch(c)
|
||||
}
|
||||
|
||||
// queueWork safely queues work
|
||||
func (b *LockFreeBatcher) queueWork(w work) {
|
||||
b.workQueuedCount.Add(1)
|
||||
|
||||
select {
|
||||
case b.workCh <- w:
|
||||
// Successfully queued
|
||||
case <-b.ctx.Done():
|
||||
// Batcher is shutting down
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// shouldProcessImmediately determines if a change should bypass batching
|
||||
func (b *LockFreeBatcher) shouldProcessImmediately(c change.ChangeSet) bool {
|
||||
// Process these changes immediately to avoid delaying critical functionality
|
||||
switch c.Change {
|
||||
case change.Full, change.NodeRemove, change.NodeCameOnline, change.NodeWentOffline, change.Policy:
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// addToBatch adds a change to the pending batch
|
||||
func (b *LockFreeBatcher) addToBatch(c change.ChangeSet) {
|
||||
b.batchMutex.Lock()
|
||||
defer b.batchMutex.Unlock()
|
||||
|
||||
if c.SelfUpdateOnly {
|
||||
changes, _ := b.pendingChanges.LoadOrStore(c.NodeID, []change.ChangeSet{})
|
||||
changes = append(changes, c)
|
||||
b.pendingChanges.Store(c.NodeID, changes)
|
||||
return
|
||||
}
|
||||
|
||||
b.nodes.Range(func(nodeID types.NodeID, _ *nodeConn) bool {
|
||||
if c.NodeID == nodeID && !c.AlsoSelf() {
|
||||
return true
|
||||
}
|
||||
|
||||
changes, _ := b.pendingChanges.LoadOrStore(nodeID, []change.ChangeSet{})
|
||||
changes = append(changes, c)
|
||||
b.pendingChanges.Store(nodeID, changes)
|
||||
return true
|
||||
})
|
||||
}
|
||||
|
||||
// processBatchedChanges processes all pending batched changes
|
||||
func (b *LockFreeBatcher) processBatchedChanges() {
|
||||
b.batchMutex.Lock()
|
||||
defer b.batchMutex.Unlock()
|
||||
|
||||
if b.pendingChanges == nil {
|
||||
return
|
||||
}
|
||||
|
||||
// Process all pending changes
|
||||
b.pendingChanges.Range(func(nodeID types.NodeID, changes []change.ChangeSet) bool {
|
||||
if len(changes) == 0 {
|
||||
return true
|
||||
}
|
||||
|
||||
// Send all batched changes for this node
|
||||
for _, c := range changes {
|
||||
b.queueWork(work{c: c, nodeID: nodeID, resultCh: nil})
|
||||
}
|
||||
|
||||
// Clear the pending changes for this node
|
||||
b.pendingChanges.Delete(nodeID)
|
||||
return true
|
||||
})
|
||||
}
|
||||
|
||||
// IsConnected is lock-free read.
|
||||
func (b *LockFreeBatcher) IsConnected(id types.NodeID) bool {
|
||||
if val, ok := b.connected.Load(id); ok {
|
||||
// nil means connected
|
||||
return val == nil
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// ConnectedMap returns a lock-free map of all connected nodes.
|
||||
func (b *LockFreeBatcher) ConnectedMap() *xsync.Map[types.NodeID, bool] {
|
||||
ret := xsync.NewMap[types.NodeID, bool]()
|
||||
|
||||
b.connected.Range(func(id types.NodeID, val *time.Time) bool {
|
||||
// nil means connected
|
||||
ret.Store(id, val == nil)
|
||||
return true
|
||||
})
|
||||
|
||||
return ret
|
||||
}
|
||||
|
||||
// MapResponseFromChange queues work to generate a map response and waits for the result.
|
||||
// This allows synchronous map generation using the same worker pool.
|
||||
func (b *LockFreeBatcher) MapResponseFromChange(id types.NodeID, c change.ChangeSet) (*tailcfg.MapResponse, error) {
|
||||
resultCh := make(chan workResult, 1)
|
||||
|
||||
// Queue the work with a result channel using the safe queueing method
|
||||
b.queueWork(work{c: c, nodeID: id, resultCh: resultCh})
|
||||
|
||||
// Wait for the result
|
||||
select {
|
||||
case result := <-resultCh:
|
||||
return result.mapResponse, result.err
|
||||
case <-b.ctx.Done():
|
||||
return nil, fmt.Errorf("batcher shutting down while generating map response for node %d", id)
|
||||
}
|
||||
}
|
||||
|
||||
// connectionData holds the channel and connection parameters.
|
||||
type connectionData struct {
|
||||
c chan<- *tailcfg.MapResponse
|
||||
version tailcfg.CapabilityVersion
|
||||
closed atomic.Bool // Track if this connection has been closed
|
||||
}
|
||||
|
||||
// nodeConn described the node connection and its associated data.
|
||||
type nodeConn struct {
|
||||
id types.NodeID
|
||||
mapper *mapper
|
||||
|
||||
// Atomic pointer to connection data - allows lock-free updates
|
||||
connData atomic.Pointer[connectionData]
|
||||
|
||||
updateCount atomic.Int64
|
||||
}
|
||||
|
||||
func newNodeConn(id types.NodeID, c chan<- *tailcfg.MapResponse, version tailcfg.CapabilityVersion, mapper *mapper) *nodeConn {
|
||||
nc := &nodeConn{
|
||||
id: id,
|
||||
mapper: mapper,
|
||||
}
|
||||
|
||||
// Initialize connection data
|
||||
data := &connectionData{
|
||||
c: c,
|
||||
version: version,
|
||||
}
|
||||
nc.connData.Store(data)
|
||||
|
||||
return nc
|
||||
}
|
||||
|
||||
// updateConnection atomically updates connection parameters.
|
||||
func (nc *nodeConn) updateConnection(c chan<- *tailcfg.MapResponse, version tailcfg.CapabilityVersion) {
|
||||
newData := &connectionData{
|
||||
c: c,
|
||||
version: version,
|
||||
}
|
||||
nc.connData.Store(newData)
|
||||
}
|
||||
|
||||
// matchesChannel checks if the given channel matches current connection.
|
||||
func (nc *nodeConn) matchesChannel(c chan<- *tailcfg.MapResponse) bool {
|
||||
data := nc.connData.Load()
|
||||
if data == nil {
|
||||
return false
|
||||
}
|
||||
// Compare channel pointers directly
|
||||
return data.c == c
|
||||
}
|
||||
|
||||
// compressAndVersion atomically reads connection settings.
|
||||
func (nc *nodeConn) version() tailcfg.CapabilityVersion {
|
||||
data := nc.connData.Load()
|
||||
if data == nil {
|
||||
return 0
|
||||
}
|
||||
|
||||
return data.version
|
||||
}
|
||||
|
||||
func (nc *nodeConn) nodeID() types.NodeID {
|
||||
return nc.id
|
||||
}
|
||||
|
||||
func (nc *nodeConn) change(c change.ChangeSet) error {
|
||||
return handleNodeChange(nc, nc.mapper, c)
|
||||
}
|
||||
|
||||
// send sends data to the node's channel.
|
||||
// The node will pick it up and send it to the HTTP handler.
|
||||
func (nc *nodeConn) send(data *tailcfg.MapResponse) error {
|
||||
connData := nc.connData.Load()
|
||||
if connData == nil {
|
||||
return fmt.Errorf("node %d: no connection data", nc.id)
|
||||
}
|
||||
|
||||
// Check if connection has been closed
|
||||
if connData.closed.Load() {
|
||||
return fmt.Errorf("node %d: connection closed", nc.id)
|
||||
}
|
||||
|
||||
// TODO(kradalby): We might need some sort of timeout here if the client is not reading
|
||||
// the channel. That might mean that we are sending to a node that has gone offline, but
|
||||
// the channel is still open.
|
||||
connData.c <- data
|
||||
nc.updateCount.Add(1)
|
||||
return nil
|
||||
}
|
1977
hscontrol/mapper/batcher_test.go
Normal file
1977
hscontrol/mapper/batcher_test.go
Normal file
File diff suppressed because it is too large
Load Diff
259
hscontrol/mapper/builder.go
Normal file
259
hscontrol/mapper/builder.go
Normal file
@@ -0,0 +1,259 @@
|
||||
package mapper
|
||||
|
||||
import (
|
||||
"net/netip"
|
||||
"sort"
|
||||
"time"
|
||||
|
||||
"github.com/juanfont/headscale/hscontrol/policy"
|
||||
"github.com/juanfont/headscale/hscontrol/types"
|
||||
"tailscale.com/tailcfg"
|
||||
"tailscale.com/types/views"
|
||||
"tailscale.com/util/multierr"
|
||||
)
|
||||
|
||||
// MapResponseBuilder provides a fluent interface for building tailcfg.MapResponse
|
||||
type MapResponseBuilder struct {
|
||||
resp *tailcfg.MapResponse
|
||||
mapper *mapper
|
||||
nodeID types.NodeID
|
||||
capVer tailcfg.CapabilityVersion
|
||||
errs []error
|
||||
}
|
||||
|
||||
// NewMapResponseBuilder creates a new builder with basic fields set
|
||||
func (m *mapper) NewMapResponseBuilder(nodeID types.NodeID) *MapResponseBuilder {
|
||||
now := time.Now()
|
||||
return &MapResponseBuilder{
|
||||
resp: &tailcfg.MapResponse{
|
||||
KeepAlive: false,
|
||||
ControlTime: &now,
|
||||
},
|
||||
mapper: m,
|
||||
nodeID: nodeID,
|
||||
errs: nil,
|
||||
}
|
||||
}
|
||||
|
||||
// addError adds an error to the builder's error list
|
||||
func (b *MapResponseBuilder) addError(err error) {
|
||||
if err != nil {
|
||||
b.errs = append(b.errs, err)
|
||||
}
|
||||
}
|
||||
|
||||
// hasErrors returns true if the builder has accumulated any errors
|
||||
func (b *MapResponseBuilder) hasErrors() bool {
|
||||
return len(b.errs) > 0
|
||||
}
|
||||
|
||||
// WithCapabilityVersion sets the capability version for the response
|
||||
func (b *MapResponseBuilder) WithCapabilityVersion(capVer tailcfg.CapabilityVersion) *MapResponseBuilder {
|
||||
b.capVer = capVer
|
||||
return b
|
||||
}
|
||||
|
||||
// WithSelfNode adds the requesting node to the response
|
||||
func (b *MapResponseBuilder) WithSelfNode() *MapResponseBuilder {
|
||||
node, err := b.mapper.state.GetNodeByID(b.nodeID)
|
||||
if err != nil {
|
||||
b.addError(err)
|
||||
return b
|
||||
}
|
||||
|
||||
_, matchers := b.mapper.state.Filter()
|
||||
tailnode, err := tailNode(
|
||||
node.View(), b.capVer, b.mapper.state,
|
||||
func(id types.NodeID) []netip.Prefix {
|
||||
return policy.ReduceRoutes(node.View(), b.mapper.state.GetNodePrimaryRoutes(id), matchers)
|
||||
},
|
||||
b.mapper.cfg)
|
||||
if err != nil {
|
||||
b.addError(err)
|
||||
return b
|
||||
}
|
||||
|
||||
b.resp.Node = tailnode
|
||||
return b
|
||||
}
|
||||
|
||||
// WithDERPMap adds the DERP map to the response
|
||||
func (b *MapResponseBuilder) WithDERPMap() *MapResponseBuilder {
|
||||
b.resp.DERPMap = b.mapper.state.DERPMap()
|
||||
return b
|
||||
}
|
||||
|
||||
// WithDomain adds the domain configuration
|
||||
func (b *MapResponseBuilder) WithDomain() *MapResponseBuilder {
|
||||
b.resp.Domain = b.mapper.cfg.Domain()
|
||||
return b
|
||||
}
|
||||
|
||||
// WithCollectServicesDisabled sets the collect services flag to false
|
||||
func (b *MapResponseBuilder) WithCollectServicesDisabled() *MapResponseBuilder {
|
||||
b.resp.CollectServices.Set(false)
|
||||
return b
|
||||
}
|
||||
|
||||
// WithDebugConfig adds debug configuration
|
||||
// It disables log tailing if the mapper's LogTail is not enabled
|
||||
func (b *MapResponseBuilder) WithDebugConfig() *MapResponseBuilder {
|
||||
b.resp.Debug = &tailcfg.Debug{
|
||||
DisableLogTail: !b.mapper.cfg.LogTail.Enabled,
|
||||
}
|
||||
return b
|
||||
}
|
||||
|
||||
// WithSSHPolicy adds SSH policy configuration for the requesting node
|
||||
func (b *MapResponseBuilder) WithSSHPolicy() *MapResponseBuilder {
|
||||
node, err := b.mapper.state.GetNodeByID(b.nodeID)
|
||||
if err != nil {
|
||||
b.addError(err)
|
||||
return b
|
||||
}
|
||||
|
||||
sshPolicy, err := b.mapper.state.SSHPolicy(node.View())
|
||||
if err != nil {
|
||||
b.addError(err)
|
||||
return b
|
||||
}
|
||||
|
||||
b.resp.SSHPolicy = sshPolicy
|
||||
return b
|
||||
}
|
||||
|
||||
// WithDNSConfig adds DNS configuration for the requesting node
|
||||
func (b *MapResponseBuilder) WithDNSConfig() *MapResponseBuilder {
|
||||
node, err := b.mapper.state.GetNodeByID(b.nodeID)
|
||||
if err != nil {
|
||||
b.addError(err)
|
||||
return b
|
||||
}
|
||||
|
||||
b.resp.DNSConfig = generateDNSConfig(b.mapper.cfg, node)
|
||||
return b
|
||||
}
|
||||
|
||||
// WithUserProfiles adds user profiles for the requesting node and given peers
|
||||
func (b *MapResponseBuilder) WithUserProfiles(peers types.Nodes) *MapResponseBuilder {
|
||||
node, err := b.mapper.state.GetNodeByID(b.nodeID)
|
||||
if err != nil {
|
||||
b.addError(err)
|
||||
return b
|
||||
}
|
||||
|
||||
b.resp.UserProfiles = generateUserProfiles(node, peers)
|
||||
return b
|
||||
}
|
||||
|
||||
// WithPacketFilters adds packet filter rules based on policy
|
||||
func (b *MapResponseBuilder) WithPacketFilters() *MapResponseBuilder {
|
||||
node, err := b.mapper.state.GetNodeByID(b.nodeID)
|
||||
if err != nil {
|
||||
b.addError(err)
|
||||
return b
|
||||
}
|
||||
|
||||
filter, _ := b.mapper.state.Filter()
|
||||
|
||||
// CapVer 81: 2023-11-17: MapResponse.PacketFilters (incremental packet filter updates)
|
||||
// Currently, we do not send incremental package filters, however using the
|
||||
// new PacketFilters field and "base" allows us to send a full update when we
|
||||
// have to send an empty list, avoiding the hack in the else block.
|
||||
b.resp.PacketFilters = map[string][]tailcfg.FilterRule{
|
||||
"base": policy.ReduceFilterRules(node.View(), filter),
|
||||
}
|
||||
|
||||
return b
|
||||
}
|
||||
|
||||
// WithPeers adds full peer list with policy filtering (for full map response)
|
||||
func (b *MapResponseBuilder) WithPeers(peers types.Nodes) *MapResponseBuilder {
|
||||
|
||||
tailPeers, err := b.buildTailPeers(peers)
|
||||
if err != nil {
|
||||
b.addError(err)
|
||||
return b
|
||||
}
|
||||
|
||||
b.resp.Peers = tailPeers
|
||||
return b
|
||||
}
|
||||
|
||||
// WithPeerChanges adds changed peers with policy filtering (for incremental updates)
|
||||
func (b *MapResponseBuilder) WithPeerChanges(peers types.Nodes) *MapResponseBuilder {
|
||||
|
||||
tailPeers, err := b.buildTailPeers(peers)
|
||||
if err != nil {
|
||||
b.addError(err)
|
||||
return b
|
||||
}
|
||||
|
||||
b.resp.PeersChanged = tailPeers
|
||||
return b
|
||||
}
|
||||
|
||||
// buildTailPeers converts types.Nodes to []tailcfg.Node with policy filtering and sorting
|
||||
func (b *MapResponseBuilder) buildTailPeers(peers types.Nodes) ([]*tailcfg.Node, error) {
|
||||
node, err := b.mapper.state.GetNodeByID(b.nodeID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
filter, matchers := b.mapper.state.Filter()
|
||||
|
||||
// If there are filter rules present, see if there are any nodes that cannot
|
||||
// access each-other at all and remove them from the peers.
|
||||
var changedViews views.Slice[types.NodeView]
|
||||
if len(filter) > 0 {
|
||||
changedViews = policy.ReduceNodes(node.View(), peers.ViewSlice(), matchers)
|
||||
} else {
|
||||
changedViews = peers.ViewSlice()
|
||||
}
|
||||
|
||||
tailPeers, err := tailNodes(
|
||||
changedViews, b.capVer, b.mapper.state,
|
||||
func(id types.NodeID) []netip.Prefix {
|
||||
return policy.ReduceRoutes(node.View(), b.mapper.state.GetNodePrimaryRoutes(id), matchers)
|
||||
},
|
||||
b.mapper.cfg)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Peers is always returned sorted by Node.ID.
|
||||
sort.SliceStable(tailPeers, func(x, y int) bool {
|
||||
return tailPeers[x].ID < tailPeers[y].ID
|
||||
})
|
||||
|
||||
return tailPeers, nil
|
||||
}
|
||||
|
||||
// WithPeerChangedPatch adds peer change patches
|
||||
func (b *MapResponseBuilder) WithPeerChangedPatch(changes []*tailcfg.PeerChange) *MapResponseBuilder {
|
||||
b.resp.PeersChangedPatch = changes
|
||||
return b
|
||||
}
|
||||
|
||||
// WithPeersRemoved adds removed peer IDs
|
||||
func (b *MapResponseBuilder) WithPeersRemoved(removedIDs ...types.NodeID) *MapResponseBuilder {
|
||||
|
||||
var tailscaleIDs []tailcfg.NodeID
|
||||
for _, id := range removedIDs {
|
||||
tailscaleIDs = append(tailscaleIDs, id.NodeID())
|
||||
}
|
||||
b.resp.PeersRemoved = tailscaleIDs
|
||||
return b
|
||||
}
|
||||
|
||||
// Build finalizes the response and returns marshaled bytes
|
||||
func (b *MapResponseBuilder) Build(messages ...string) (*tailcfg.MapResponse, error) {
|
||||
if len(b.errs) > 0 {
|
||||
return nil, multierr.New(b.errs...)
|
||||
}
|
||||
if debugDumpMapResponsePath != "" {
|
||||
writeDebugMapResponse(b.resp, b.nodeID)
|
||||
}
|
||||
|
||||
return b.resp, nil
|
||||
}
|
347
hscontrol/mapper/builder_test.go
Normal file
347
hscontrol/mapper/builder_test.go
Normal file
@@ -0,0 +1,347 @@
|
||||
package mapper
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/juanfont/headscale/hscontrol/state"
|
||||
"github.com/juanfont/headscale/hscontrol/types"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"tailscale.com/tailcfg"
|
||||
)
|
||||
|
||||
func TestMapResponseBuilder_Basic(t *testing.T) {
|
||||
cfg := &types.Config{
|
||||
BaseDomain: "example.com",
|
||||
LogTail: types.LogTailConfig{
|
||||
Enabled: true,
|
||||
},
|
||||
}
|
||||
|
||||
mockState := &state.State{}
|
||||
m := &mapper{
|
||||
cfg: cfg,
|
||||
state: mockState,
|
||||
}
|
||||
|
||||
nodeID := types.NodeID(1)
|
||||
|
||||
builder := m.NewMapResponseBuilder(nodeID)
|
||||
|
||||
// Test basic builder creation
|
||||
assert.NotNil(t, builder)
|
||||
assert.Equal(t, nodeID, builder.nodeID)
|
||||
assert.NotNil(t, builder.resp)
|
||||
assert.False(t, builder.resp.KeepAlive)
|
||||
assert.NotNil(t, builder.resp.ControlTime)
|
||||
assert.WithinDuration(t, time.Now(), *builder.resp.ControlTime, time.Second)
|
||||
}
|
||||
|
||||
func TestMapResponseBuilder_WithCapabilityVersion(t *testing.T) {
|
||||
cfg := &types.Config{}
|
||||
mockState := &state.State{}
|
||||
m := &mapper{
|
||||
cfg: cfg,
|
||||
state: mockState,
|
||||
}
|
||||
|
||||
nodeID := types.NodeID(1)
|
||||
capVer := tailcfg.CapabilityVersion(42)
|
||||
|
||||
builder := m.NewMapResponseBuilder(nodeID).
|
||||
WithCapabilityVersion(capVer)
|
||||
|
||||
assert.Equal(t, capVer, builder.capVer)
|
||||
assert.False(t, builder.hasErrors())
|
||||
}
|
||||
|
||||
func TestMapResponseBuilder_WithDomain(t *testing.T) {
|
||||
domain := "test.example.com"
|
||||
cfg := &types.Config{
|
||||
ServerURL: "https://test.example.com",
|
||||
BaseDomain: domain,
|
||||
}
|
||||
|
||||
mockState := &state.State{}
|
||||
m := &mapper{
|
||||
cfg: cfg,
|
||||
state: mockState,
|
||||
}
|
||||
|
||||
nodeID := types.NodeID(1)
|
||||
|
||||
builder := m.NewMapResponseBuilder(nodeID).
|
||||
WithDomain()
|
||||
|
||||
assert.Equal(t, domain, builder.resp.Domain)
|
||||
assert.False(t, builder.hasErrors())
|
||||
}
|
||||
|
||||
func TestMapResponseBuilder_WithCollectServicesDisabled(t *testing.T) {
|
||||
cfg := &types.Config{}
|
||||
mockState := &state.State{}
|
||||
m := &mapper{
|
||||
cfg: cfg,
|
||||
state: mockState,
|
||||
}
|
||||
|
||||
nodeID := types.NodeID(1)
|
||||
|
||||
builder := m.NewMapResponseBuilder(nodeID).
|
||||
WithCollectServicesDisabled()
|
||||
|
||||
value, isSet := builder.resp.CollectServices.Get()
|
||||
assert.True(t, isSet)
|
||||
assert.False(t, value)
|
||||
assert.False(t, builder.hasErrors())
|
||||
}
|
||||
|
||||
func TestMapResponseBuilder_WithDebugConfig(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
logTailEnabled bool
|
||||
expected bool
|
||||
}{
|
||||
{
|
||||
name: "LogTail enabled",
|
||||
logTailEnabled: true,
|
||||
expected: false, // DisableLogTail should be false when LogTail is enabled
|
||||
},
|
||||
{
|
||||
name: "LogTail disabled",
|
||||
logTailEnabled: false,
|
||||
expected: true, // DisableLogTail should be true when LogTail is disabled
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
cfg := &types.Config{
|
||||
LogTail: types.LogTailConfig{
|
||||
Enabled: tt.logTailEnabled,
|
||||
},
|
||||
}
|
||||
mockState := &state.State{}
|
||||
m := &mapper{
|
||||
cfg: cfg,
|
||||
state: mockState,
|
||||
}
|
||||
|
||||
nodeID := types.NodeID(1)
|
||||
|
||||
builder := m.NewMapResponseBuilder(nodeID).
|
||||
WithDebugConfig()
|
||||
|
||||
require.NotNil(t, builder.resp.Debug)
|
||||
assert.Equal(t, tt.expected, builder.resp.Debug.DisableLogTail)
|
||||
assert.False(t, builder.hasErrors())
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestMapResponseBuilder_WithPeerChangedPatch(t *testing.T) {
|
||||
cfg := &types.Config{}
|
||||
mockState := &state.State{}
|
||||
m := &mapper{
|
||||
cfg: cfg,
|
||||
state: mockState,
|
||||
}
|
||||
|
||||
nodeID := types.NodeID(1)
|
||||
changes := []*tailcfg.PeerChange{
|
||||
{
|
||||
NodeID: 123,
|
||||
DERPRegion: 1,
|
||||
},
|
||||
{
|
||||
NodeID: 456,
|
||||
DERPRegion: 2,
|
||||
},
|
||||
}
|
||||
|
||||
builder := m.NewMapResponseBuilder(nodeID).
|
||||
WithPeerChangedPatch(changes)
|
||||
|
||||
assert.Equal(t, changes, builder.resp.PeersChangedPatch)
|
||||
assert.False(t, builder.hasErrors())
|
||||
}
|
||||
|
||||
func TestMapResponseBuilder_WithPeersRemoved(t *testing.T) {
|
||||
cfg := &types.Config{}
|
||||
mockState := &state.State{}
|
||||
m := &mapper{
|
||||
cfg: cfg,
|
||||
state: mockState,
|
||||
}
|
||||
|
||||
nodeID := types.NodeID(1)
|
||||
removedID1 := types.NodeID(123)
|
||||
removedID2 := types.NodeID(456)
|
||||
|
||||
builder := m.NewMapResponseBuilder(nodeID).
|
||||
WithPeersRemoved(removedID1, removedID2)
|
||||
|
||||
expected := []tailcfg.NodeID{
|
||||
removedID1.NodeID(),
|
||||
removedID2.NodeID(),
|
||||
}
|
||||
assert.Equal(t, expected, builder.resp.PeersRemoved)
|
||||
assert.False(t, builder.hasErrors())
|
||||
}
|
||||
|
||||
func TestMapResponseBuilder_ErrorHandling(t *testing.T) {
|
||||
cfg := &types.Config{}
|
||||
mockState := &state.State{}
|
||||
m := &mapper{
|
||||
cfg: cfg,
|
||||
state: mockState,
|
||||
}
|
||||
|
||||
nodeID := types.NodeID(1)
|
||||
|
||||
// Simulate an error in the builder
|
||||
builder := m.NewMapResponseBuilder(nodeID)
|
||||
builder.addError(assert.AnError)
|
||||
|
||||
// All subsequent calls should continue to work and accumulate errors
|
||||
result := builder.
|
||||
WithDomain().
|
||||
WithCollectServicesDisabled().
|
||||
WithDebugConfig()
|
||||
|
||||
assert.True(t, result.hasErrors())
|
||||
assert.Len(t, result.errs, 1)
|
||||
assert.Equal(t, assert.AnError, result.errs[0])
|
||||
|
||||
// Build should return the error
|
||||
data, err := result.Build("none")
|
||||
assert.Nil(t, data)
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
func TestMapResponseBuilder_ChainedCalls(t *testing.T) {
|
||||
domain := "chained.example.com"
|
||||
cfg := &types.Config{
|
||||
ServerURL: "https://chained.example.com",
|
||||
BaseDomain: domain,
|
||||
LogTail: types.LogTailConfig{
|
||||
Enabled: false,
|
||||
},
|
||||
}
|
||||
|
||||
mockState := &state.State{}
|
||||
m := &mapper{
|
||||
cfg: cfg,
|
||||
state: mockState,
|
||||
}
|
||||
|
||||
nodeID := types.NodeID(1)
|
||||
capVer := tailcfg.CapabilityVersion(99)
|
||||
|
||||
builder := m.NewMapResponseBuilder(nodeID).
|
||||
WithCapabilityVersion(capVer).
|
||||
WithDomain().
|
||||
WithCollectServicesDisabled().
|
||||
WithDebugConfig()
|
||||
|
||||
// Verify all fields are set correctly
|
||||
assert.Equal(t, capVer, builder.capVer)
|
||||
assert.Equal(t, domain, builder.resp.Domain)
|
||||
value, isSet := builder.resp.CollectServices.Get()
|
||||
assert.True(t, isSet)
|
||||
assert.False(t, value)
|
||||
assert.NotNil(t, builder.resp.Debug)
|
||||
assert.True(t, builder.resp.Debug.DisableLogTail)
|
||||
assert.False(t, builder.hasErrors())
|
||||
}
|
||||
|
||||
func TestMapResponseBuilder_MultipleWithPeersRemoved(t *testing.T) {
|
||||
cfg := &types.Config{}
|
||||
mockState := &state.State{}
|
||||
m := &mapper{
|
||||
cfg: cfg,
|
||||
state: mockState,
|
||||
}
|
||||
|
||||
nodeID := types.NodeID(1)
|
||||
removedID1 := types.NodeID(100)
|
||||
removedID2 := types.NodeID(200)
|
||||
|
||||
// Test calling WithPeersRemoved multiple times
|
||||
builder := m.NewMapResponseBuilder(nodeID).
|
||||
WithPeersRemoved(removedID1).
|
||||
WithPeersRemoved(removedID2)
|
||||
|
||||
// Second call should overwrite the first
|
||||
expected := []tailcfg.NodeID{removedID2.NodeID()}
|
||||
assert.Equal(t, expected, builder.resp.PeersRemoved)
|
||||
assert.False(t, builder.hasErrors())
|
||||
}
|
||||
|
||||
func TestMapResponseBuilder_EmptyPeerChangedPatch(t *testing.T) {
|
||||
cfg := &types.Config{}
|
||||
mockState := &state.State{}
|
||||
m := &mapper{
|
||||
cfg: cfg,
|
||||
state: mockState,
|
||||
}
|
||||
|
||||
nodeID := types.NodeID(1)
|
||||
|
||||
builder := m.NewMapResponseBuilder(nodeID).
|
||||
WithPeerChangedPatch([]*tailcfg.PeerChange{})
|
||||
|
||||
assert.Empty(t, builder.resp.PeersChangedPatch)
|
||||
assert.False(t, builder.hasErrors())
|
||||
}
|
||||
|
||||
func TestMapResponseBuilder_NilPeerChangedPatch(t *testing.T) {
|
||||
cfg := &types.Config{}
|
||||
mockState := &state.State{}
|
||||
m := &mapper{
|
||||
cfg: cfg,
|
||||
state: mockState,
|
||||
}
|
||||
|
||||
nodeID := types.NodeID(1)
|
||||
|
||||
builder := m.NewMapResponseBuilder(nodeID).
|
||||
WithPeerChangedPatch(nil)
|
||||
|
||||
assert.Nil(t, builder.resp.PeersChangedPatch)
|
||||
assert.False(t, builder.hasErrors())
|
||||
}
|
||||
|
||||
func TestMapResponseBuilder_MultipleErrors(t *testing.T) {
|
||||
cfg := &types.Config{}
|
||||
mockState := &state.State{}
|
||||
m := &mapper{
|
||||
cfg: cfg,
|
||||
state: mockState,
|
||||
}
|
||||
|
||||
nodeID := types.NodeID(1)
|
||||
|
||||
// Create a builder and add multiple errors
|
||||
builder := m.NewMapResponseBuilder(nodeID)
|
||||
builder.addError(assert.AnError)
|
||||
builder.addError(assert.AnError)
|
||||
builder.addError(nil) // This should be ignored
|
||||
|
||||
// All subsequent calls should continue to work
|
||||
result := builder.
|
||||
WithDomain().
|
||||
WithCollectServicesDisabled()
|
||||
|
||||
assert.True(t, result.hasErrors())
|
||||
assert.Len(t, result.errs, 2) // nil error should be ignored
|
||||
|
||||
// Build should return a multierr
|
||||
data, err := result.Build("none")
|
||||
assert.Nil(t, data)
|
||||
assert.Error(t, err)
|
||||
|
||||
// The error should contain information about multiple errors
|
||||
assert.Contains(t, err.Error(), "multiple errors")
|
||||
}
|
@@ -1,7 +1,6 @@
|
||||
package mapper
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io/fs"
|
||||
@@ -10,31 +9,21 @@ import (
|
||||
"os"
|
||||
"path"
|
||||
"slices"
|
||||
"sort"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/juanfont/headscale/hscontrol/notifier"
|
||||
"github.com/juanfont/headscale/hscontrol/policy"
|
||||
"github.com/juanfont/headscale/hscontrol/state"
|
||||
"github.com/juanfont/headscale/hscontrol/types"
|
||||
"github.com/juanfont/headscale/hscontrol/util"
|
||||
"github.com/klauspost/compress/zstd"
|
||||
"github.com/rs/zerolog/log"
|
||||
"tailscale.com/envknob"
|
||||
"tailscale.com/smallzstd"
|
||||
"tailscale.com/tailcfg"
|
||||
"tailscale.com/types/dnstype"
|
||||
"tailscale.com/types/views"
|
||||
)
|
||||
|
||||
const (
|
||||
nextDNSDoHPrefix = "https://dns.nextdns.io"
|
||||
reservedResponseHeaderSize = 4
|
||||
mapperIDLength = 8
|
||||
debugMapResponsePerm = 0o755
|
||||
nextDNSDoHPrefix = "https://dns.nextdns.io"
|
||||
mapperIDLength = 8
|
||||
debugMapResponsePerm = 0o755
|
||||
)
|
||||
|
||||
var debugDumpMapResponsePath = envknob.String("HEADSCALE_DEBUG_DUMP_MAPRESPONSE_PATH")
|
||||
@@ -50,15 +39,13 @@ var debugDumpMapResponsePath = envknob.String("HEADSCALE_DEBUG_DUMP_MAPRESPONSE_
|
||||
// - Create a "minifier" that removes info not needed for the node
|
||||
// - some sort of batching, wait for 5 or 60 seconds before sending
|
||||
|
||||
type Mapper struct {
|
||||
type mapper struct {
|
||||
// Configuration
|
||||
state *state.State
|
||||
cfg *types.Config
|
||||
notif *notifier.Notifier
|
||||
state *state.State
|
||||
cfg *types.Config
|
||||
batcher Batcher
|
||||
|
||||
uid string
|
||||
created time.Time
|
||||
seq uint64
|
||||
}
|
||||
|
||||
type patch struct {
|
||||
@@ -66,41 +53,31 @@ type patch struct {
|
||||
change *tailcfg.PeerChange
|
||||
}
|
||||
|
||||
func NewMapper(
|
||||
state *state.State,
|
||||
func newMapper(
|
||||
cfg *types.Config,
|
||||
notif *notifier.Notifier,
|
||||
) *Mapper {
|
||||
uid, _ := util.GenerateRandomStringDNSSafe(mapperIDLength)
|
||||
state *state.State,
|
||||
) *mapper {
|
||||
// uid, _ := util.GenerateRandomStringDNSSafe(mapperIDLength)
|
||||
|
||||
return &Mapper{
|
||||
return &mapper{
|
||||
state: state,
|
||||
cfg: cfg,
|
||||
notif: notif,
|
||||
|
||||
uid: uid,
|
||||
created: time.Now(),
|
||||
seq: 0,
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Mapper) String() string {
|
||||
return fmt.Sprintf("Mapper: { seq: %d, uid: %s, created: %s }", m.seq, m.uid, m.created)
|
||||
}
|
||||
|
||||
func generateUserProfiles(
|
||||
node types.NodeView,
|
||||
peers views.Slice[types.NodeView],
|
||||
node *types.Node,
|
||||
peers types.Nodes,
|
||||
) []tailcfg.UserProfile {
|
||||
userMap := make(map[uint]*types.User)
|
||||
ids := make([]uint, 0, peers.Len()+1)
|
||||
user := node.User()
|
||||
userMap[user.ID] = &user
|
||||
ids = append(ids, user.ID)
|
||||
for _, peer := range peers.All() {
|
||||
peerUser := peer.User()
|
||||
userMap[peerUser.ID] = &peerUser
|
||||
ids = append(ids, peerUser.ID)
|
||||
ids := make([]uint, 0, len(userMap))
|
||||
userMap[node.User.ID] = &node.User
|
||||
ids = append(ids, node.User.ID)
|
||||
for _, peer := range peers {
|
||||
userMap[peer.User.ID] = &peer.User
|
||||
ids = append(ids, peer.User.ID)
|
||||
}
|
||||
|
||||
slices.Sort(ids)
|
||||
@@ -117,7 +94,7 @@ func generateUserProfiles(
|
||||
|
||||
func generateDNSConfig(
|
||||
cfg *types.Config,
|
||||
node types.NodeView,
|
||||
node *types.Node,
|
||||
) *tailcfg.DNSConfig {
|
||||
if cfg.TailcfgDNSConfig == nil {
|
||||
return nil
|
||||
@@ -137,17 +114,16 @@ func generateDNSConfig(
|
||||
//
|
||||
// This will produce a resolver like:
|
||||
// `https://dns.nextdns.io/<nextdns-id>?device_name=node-name&device_model=linux&device_ip=100.64.0.1`
|
||||
func addNextDNSMetadata(resolvers []*dnstype.Resolver, node types.NodeView) {
|
||||
func addNextDNSMetadata(resolvers []*dnstype.Resolver, node *types.Node) {
|
||||
for _, resolver := range resolvers {
|
||||
if strings.HasPrefix(resolver.Addr, nextDNSDoHPrefix) {
|
||||
attrs := url.Values{
|
||||
"device_name": []string{node.Hostname()},
|
||||
"device_model": []string{node.Hostinfo().OS()},
|
||||
"device_name": []string{node.Hostname},
|
||||
"device_model": []string{node.Hostinfo.OS},
|
||||
}
|
||||
|
||||
nodeIPs := node.IPs()
|
||||
if len(nodeIPs) > 0 {
|
||||
attrs.Add("device_ip", nodeIPs[0].String())
|
||||
if len(node.IPs()) > 0 {
|
||||
attrs.Add("device_ip", node.IPs()[0].String())
|
||||
}
|
||||
|
||||
resolver.Addr = fmt.Sprintf("%s?%s", resolver.Addr, attrs.Encode())
|
||||
@@ -155,434 +131,151 @@ func addNextDNSMetadata(resolvers []*dnstype.Resolver, node types.NodeView) {
|
||||
}
|
||||
}
|
||||
|
||||
// fullMapResponse creates a complete MapResponse for a node.
|
||||
// It is a separate function to make testing easier.
|
||||
func (m *Mapper) fullMapResponse(
|
||||
node types.NodeView,
|
||||
peers views.Slice[types.NodeView],
|
||||
// fullMapResponse returns a MapResponse for the given node.
|
||||
func (m *mapper) fullMapResponse(
|
||||
nodeID types.NodeID,
|
||||
capVer tailcfg.CapabilityVersion,
|
||||
messages ...string,
|
||||
) (*tailcfg.MapResponse, error) {
|
||||
resp, err := m.baseWithConfigMapResponse(node, capVer)
|
||||
peers, err := m.listPeers(nodeID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
err = appendPeerChanges(
|
||||
resp,
|
||||
true, // full change
|
||||
m.state,
|
||||
node,
|
||||
capVer,
|
||||
peers,
|
||||
m.cfg,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return resp, nil
|
||||
return m.NewMapResponseBuilder(nodeID).
|
||||
WithCapabilityVersion(capVer).
|
||||
WithSelfNode().
|
||||
WithDERPMap().
|
||||
WithDomain().
|
||||
WithCollectServicesDisabled().
|
||||
WithDebugConfig().
|
||||
WithSSHPolicy().
|
||||
WithDNSConfig().
|
||||
WithUserProfiles(peers).
|
||||
WithPacketFilters().
|
||||
WithPeers(peers).
|
||||
Build(messages...)
|
||||
}
|
||||
|
||||
// FullMapResponse returns a MapResponse for the given node.
|
||||
func (m *Mapper) FullMapResponse(
|
||||
mapRequest tailcfg.MapRequest,
|
||||
node types.NodeView,
|
||||
messages ...string,
|
||||
) ([]byte, error) {
|
||||
peers, err := m.ListPeers(node.ID())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
resp, err := m.fullMapResponse(node, peers.ViewSlice(), mapRequest.Version)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return m.marshalMapResponse(mapRequest, resp, node, mapRequest.Compress, messages...)
|
||||
}
|
||||
|
||||
// ReadOnlyMapResponse returns a MapResponse for the given node.
|
||||
// Lite means that the peers has been omitted, this is intended
|
||||
// to be used to answer MapRequests with OmitPeers set to true.
|
||||
func (m *Mapper) ReadOnlyMapResponse(
|
||||
mapRequest tailcfg.MapRequest,
|
||||
node types.NodeView,
|
||||
messages ...string,
|
||||
) ([]byte, error) {
|
||||
resp, err := m.baseWithConfigMapResponse(node, mapRequest.Version)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return m.marshalMapResponse(mapRequest, resp, node, mapRequest.Compress, messages...)
|
||||
}
|
||||
|
||||
func (m *Mapper) KeepAliveResponse(
|
||||
mapRequest tailcfg.MapRequest,
|
||||
node types.NodeView,
|
||||
) ([]byte, error) {
|
||||
resp := m.baseMapResponse()
|
||||
resp.KeepAlive = true
|
||||
|
||||
return m.marshalMapResponse(mapRequest, &resp, node, mapRequest.Compress)
|
||||
}
|
||||
|
||||
func (m *Mapper) DERPMapResponse(
|
||||
mapRequest tailcfg.MapRequest,
|
||||
node types.NodeView,
|
||||
derpMap *tailcfg.DERPMap,
|
||||
) ([]byte, error) {
|
||||
resp := m.baseMapResponse()
|
||||
resp.DERPMap = derpMap
|
||||
|
||||
return m.marshalMapResponse(mapRequest, &resp, node, mapRequest.Compress)
|
||||
}
|
||||
|
||||
func (m *Mapper) PeerChangedResponse(
|
||||
mapRequest tailcfg.MapRequest,
|
||||
node types.NodeView,
|
||||
changed map[types.NodeID]bool,
|
||||
patches []*tailcfg.PeerChange,
|
||||
messages ...string,
|
||||
) ([]byte, error) {
|
||||
var err error
|
||||
resp := m.baseMapResponse()
|
||||
|
||||
var removedIDs []tailcfg.NodeID
|
||||
var changedIDs []types.NodeID
|
||||
for nodeID, nodeChanged := range changed {
|
||||
if nodeChanged {
|
||||
if nodeID != node.ID() {
|
||||
changedIDs = append(changedIDs, nodeID)
|
||||
}
|
||||
} else {
|
||||
removedIDs = append(removedIDs, nodeID.NodeID())
|
||||
}
|
||||
}
|
||||
changedNodes := types.Nodes{}
|
||||
if len(changedIDs) > 0 {
|
||||
changedNodes, err = m.ListNodes(changedIDs...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
err = appendPeerChanges(
|
||||
&resp,
|
||||
false, // partial change
|
||||
m.state,
|
||||
node,
|
||||
mapRequest.Version,
|
||||
changedNodes.ViewSlice(),
|
||||
m.cfg,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
resp.PeersRemoved = removedIDs
|
||||
|
||||
// Sending patches as a part of a PeersChanged response
|
||||
// is technically not suppose to be done, but they are
|
||||
// applied after the PeersChanged. The patch list
|
||||
// should _only_ contain Nodes that are not in the
|
||||
// PeersChanged or PeersRemoved list and the caller
|
||||
// should filter them out.
|
||||
//
|
||||
// From tailcfg docs:
|
||||
// These are applied after Peers* above, but in practice the
|
||||
// control server should only send these on their own, without
|
||||
// the Peers* fields also set.
|
||||
if patches != nil {
|
||||
resp.PeersChangedPatch = patches
|
||||
}
|
||||
|
||||
_, matchers := m.state.Filter()
|
||||
// Add the node itself, it might have changed, and particularly
|
||||
// if there are no patches or changes, this is a self update.
|
||||
tailnode, err := tailNode(
|
||||
node, mapRequest.Version, m.state,
|
||||
func(id types.NodeID) []netip.Prefix {
|
||||
return policy.ReduceRoutes(node, m.state.GetNodePrimaryRoutes(id), matchers)
|
||||
},
|
||||
m.cfg)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
resp.Node = tailnode
|
||||
|
||||
return m.marshalMapResponse(mapRequest, &resp, node, mapRequest.Compress, messages...)
|
||||
func (m *mapper) derpMapResponse(
|
||||
nodeID types.NodeID,
|
||||
) (*tailcfg.MapResponse, error) {
|
||||
return m.NewMapResponseBuilder(nodeID).
|
||||
WithDERPMap().
|
||||
Build()
|
||||
}
|
||||
|
||||
// PeerChangedPatchResponse creates a patch MapResponse with
|
||||
// incoming update from a state change.
|
||||
func (m *Mapper) PeerChangedPatchResponse(
|
||||
mapRequest tailcfg.MapRequest,
|
||||
node types.NodeView,
|
||||
func (m *mapper) peerChangedPatchResponse(
|
||||
nodeID types.NodeID,
|
||||
changed []*tailcfg.PeerChange,
|
||||
) ([]byte, error) {
|
||||
resp := m.baseMapResponse()
|
||||
resp.PeersChangedPatch = changed
|
||||
|
||||
return m.marshalMapResponse(mapRequest, &resp, node, mapRequest.Compress)
|
||||
}
|
||||
|
||||
func (m *Mapper) marshalMapResponse(
|
||||
mapRequest tailcfg.MapRequest,
|
||||
resp *tailcfg.MapResponse,
|
||||
node types.NodeView,
|
||||
compression string,
|
||||
messages ...string,
|
||||
) ([]byte, error) {
|
||||
atomic.AddUint64(&m.seq, 1)
|
||||
|
||||
jsonBody, err := json.Marshal(resp)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("marshalling map response: %w", err)
|
||||
}
|
||||
|
||||
if debugDumpMapResponsePath != "" {
|
||||
data := map[string]any{
|
||||
"Messages": messages,
|
||||
"MapRequest": mapRequest,
|
||||
"MapResponse": resp,
|
||||
}
|
||||
|
||||
responseType := "keepalive"
|
||||
|
||||
switch {
|
||||
case resp.Peers != nil && len(resp.Peers) > 0:
|
||||
responseType = "full"
|
||||
case resp.Peers == nil && resp.PeersChanged == nil && resp.PeersChangedPatch == nil && resp.DERPMap == nil && !resp.KeepAlive:
|
||||
responseType = "self"
|
||||
case resp.PeersChanged != nil && len(resp.PeersChanged) > 0:
|
||||
responseType = "changed"
|
||||
case resp.PeersChangedPatch != nil && len(resp.PeersChangedPatch) > 0:
|
||||
responseType = "patch"
|
||||
case resp.PeersRemoved != nil && len(resp.PeersRemoved) > 0:
|
||||
responseType = "removed"
|
||||
}
|
||||
|
||||
body, err := json.MarshalIndent(data, "", " ")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("marshalling map response: %w", err)
|
||||
}
|
||||
|
||||
perms := fs.FileMode(debugMapResponsePerm)
|
||||
mPath := path.Join(debugDumpMapResponsePath, node.Hostname())
|
||||
err = os.MkdirAll(mPath, perms)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
now := time.Now().Format("2006-01-02T15-04-05.999999999")
|
||||
|
||||
mapResponsePath := path.Join(
|
||||
mPath,
|
||||
fmt.Sprintf("%s-%s-%d-%s.json", now, m.uid, atomic.LoadUint64(&m.seq), responseType),
|
||||
)
|
||||
|
||||
log.Trace().Msgf("Writing MapResponse to %s", mapResponsePath)
|
||||
err = os.WriteFile(mapResponsePath, body, perms)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
|
||||
var respBody []byte
|
||||
if compression == util.ZstdCompression {
|
||||
respBody = zstdEncode(jsonBody)
|
||||
} else {
|
||||
respBody = jsonBody
|
||||
}
|
||||
|
||||
data := make([]byte, reservedResponseHeaderSize)
|
||||
binary.LittleEndian.PutUint32(data, uint32(len(respBody)))
|
||||
data = append(data, respBody...)
|
||||
|
||||
return data, nil
|
||||
}
|
||||
|
||||
func zstdEncode(in []byte) []byte {
|
||||
encoder, ok := zstdEncoderPool.Get().(*zstd.Encoder)
|
||||
if !ok {
|
||||
panic("invalid type in sync pool")
|
||||
}
|
||||
out := encoder.EncodeAll(in, nil)
|
||||
_ = encoder.Close()
|
||||
zstdEncoderPool.Put(encoder)
|
||||
|
||||
return out
|
||||
}
|
||||
|
||||
var zstdEncoderPool = &sync.Pool{
|
||||
New: func() any {
|
||||
encoder, err := smallzstd.NewEncoder(
|
||||
nil,
|
||||
zstd.WithEncoderLevel(zstd.SpeedFastest))
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
return encoder
|
||||
},
|
||||
}
|
||||
|
||||
// baseMapResponse returns a tailcfg.MapResponse with
|
||||
// KeepAlive false and ControlTime set to now.
|
||||
func (m *Mapper) baseMapResponse() tailcfg.MapResponse {
|
||||
now := time.Now()
|
||||
|
||||
resp := tailcfg.MapResponse{
|
||||
KeepAlive: false,
|
||||
ControlTime: &now,
|
||||
// TODO(kradalby): Implement PingRequest?
|
||||
}
|
||||
|
||||
return resp
|
||||
}
|
||||
|
||||
// baseWithConfigMapResponse returns a tailcfg.MapResponse struct
|
||||
// with the basic configuration from headscale set.
|
||||
// It is used in for bigger updates, such as full and lite, not
|
||||
// incremental.
|
||||
func (m *Mapper) baseWithConfigMapResponse(
|
||||
node types.NodeView,
|
||||
capVer tailcfg.CapabilityVersion,
|
||||
) (*tailcfg.MapResponse, error) {
|
||||
resp := m.baseMapResponse()
|
||||
return m.NewMapResponseBuilder(nodeID).
|
||||
WithPeerChangedPatch(changed).
|
||||
Build()
|
||||
}
|
||||
|
||||
_, matchers := m.state.Filter()
|
||||
tailnode, err := tailNode(
|
||||
node, capVer, m.state,
|
||||
func(id types.NodeID) []netip.Prefix {
|
||||
return policy.ReduceRoutes(node, m.state.GetNodePrimaryRoutes(id), matchers)
|
||||
},
|
||||
m.cfg)
|
||||
// peerChangeResponse returns a MapResponse with changed or added nodes.
|
||||
func (m *mapper) peerChangeResponse(
|
||||
nodeID types.NodeID,
|
||||
capVer tailcfg.CapabilityVersion,
|
||||
changedNodeID types.NodeID,
|
||||
) (*tailcfg.MapResponse, error) {
|
||||
peers, err := m.listPeers(nodeID, changedNodeID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
resp.Node = tailnode
|
||||
|
||||
resp.DERPMap = m.state.DERPMap()
|
||||
|
||||
resp.Domain = m.cfg.Domain()
|
||||
|
||||
// Do not instruct clients to collect services we do not
|
||||
// support or do anything with them
|
||||
resp.CollectServices = "false"
|
||||
|
||||
resp.KeepAlive = false
|
||||
|
||||
resp.Debug = &tailcfg.Debug{
|
||||
DisableLogTail: !m.cfg.LogTail.Enabled,
|
||||
}
|
||||
|
||||
return &resp, nil
|
||||
return m.NewMapResponseBuilder(nodeID).
|
||||
WithCapabilityVersion(capVer).
|
||||
WithSelfNode().
|
||||
WithUserProfiles(peers).
|
||||
WithPeerChanges(peers).
|
||||
Build()
|
||||
}
|
||||
|
||||
// ListPeers returns peers of node, regardless of any Policy or if the node is expired.
|
||||
// peerRemovedResponse creates a MapResponse indicating that a peer has been removed.
|
||||
func (m *mapper) peerRemovedResponse(
|
||||
nodeID types.NodeID,
|
||||
removedNodeID types.NodeID,
|
||||
) (*tailcfg.MapResponse, error) {
|
||||
return m.NewMapResponseBuilder(nodeID).
|
||||
WithPeersRemoved(removedNodeID).
|
||||
Build()
|
||||
}
|
||||
|
||||
func writeDebugMapResponse(
|
||||
resp *tailcfg.MapResponse,
|
||||
nodeID types.NodeID,
|
||||
messages ...string,
|
||||
) {
|
||||
data := map[string]any{
|
||||
"Messages": messages,
|
||||
"MapResponse": resp,
|
||||
}
|
||||
|
||||
responseType := "keepalive"
|
||||
|
||||
switch {
|
||||
case len(resp.Peers) > 0:
|
||||
responseType = "full"
|
||||
case resp.Peers == nil && resp.PeersChanged == nil && resp.PeersChangedPatch == nil && resp.DERPMap == nil && !resp.KeepAlive:
|
||||
responseType = "self"
|
||||
case len(resp.PeersChanged) > 0:
|
||||
responseType = "changed"
|
||||
case len(resp.PeersChangedPatch) > 0:
|
||||
responseType = "patch"
|
||||
case len(resp.PeersRemoved) > 0:
|
||||
responseType = "removed"
|
||||
}
|
||||
|
||||
body, err := json.MarshalIndent(data, "", " ")
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
perms := fs.FileMode(debugMapResponsePerm)
|
||||
mPath := path.Join(debugDumpMapResponsePath, nodeID.String())
|
||||
err = os.MkdirAll(mPath, perms)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
now := time.Now().Format("2006-01-02T15-04-05.999999999")
|
||||
|
||||
mapResponsePath := path.Join(
|
||||
mPath,
|
||||
fmt.Sprintf("%s-%s.json", now, responseType),
|
||||
)
|
||||
|
||||
log.Trace().Msgf("Writing MapResponse to %s", mapResponsePath)
|
||||
err = os.WriteFile(mapResponsePath, body, perms)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
|
||||
// listPeers returns peers of node, regardless of any Policy or if the node is expired.
|
||||
// If no peer IDs are given, all peers are returned.
|
||||
// If at least one peer ID is given, only these peer nodes will be returned.
|
||||
func (m *Mapper) ListPeers(nodeID types.NodeID, peerIDs ...types.NodeID) (types.Nodes, error) {
|
||||
func (m *mapper) listPeers(nodeID types.NodeID, peerIDs ...types.NodeID) (types.Nodes, error) {
|
||||
peers, err := m.state.ListPeers(nodeID, peerIDs...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// TODO(kradalby): Add back online via batcher. This was removed
|
||||
// to avoid a circular dependency between the mapper and the notification.
|
||||
for _, peer := range peers {
|
||||
online := m.notif.IsLikelyConnected(peer.ID)
|
||||
online := m.batcher.IsConnected(peer.ID)
|
||||
peer.IsOnline = &online
|
||||
}
|
||||
|
||||
return peers, nil
|
||||
}
|
||||
|
||||
// ListNodes queries the database for either all nodes if no parameters are given
|
||||
// or for the given nodes if at least one node ID is given as parameter.
|
||||
func (m *Mapper) ListNodes(nodeIDs ...types.NodeID) (types.Nodes, error) {
|
||||
nodes, err := m.state.ListNodes(nodeIDs...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for _, node := range nodes {
|
||||
online := m.notif.IsLikelyConnected(node.ID)
|
||||
node.IsOnline = &online
|
||||
}
|
||||
|
||||
return nodes, nil
|
||||
}
|
||||
|
||||
// routeFilterFunc is a function that takes a node ID and returns a list of
|
||||
// netip.Prefixes that are allowed for that node. It is used to filter routes
|
||||
// from the primary route manager to the node.
|
||||
type routeFilterFunc func(id types.NodeID) []netip.Prefix
|
||||
|
||||
// appendPeerChanges mutates a tailcfg.MapResponse with all the
|
||||
// necessary changes when peers have changed.
|
||||
func appendPeerChanges(
|
||||
resp *tailcfg.MapResponse,
|
||||
|
||||
fullChange bool,
|
||||
state *state.State,
|
||||
node types.NodeView,
|
||||
capVer tailcfg.CapabilityVersion,
|
||||
changed views.Slice[types.NodeView],
|
||||
cfg *types.Config,
|
||||
) error {
|
||||
filter, matchers := state.Filter()
|
||||
|
||||
sshPolicy, err := state.SSHPolicy(node)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// If there are filter rules present, see if there are any nodes that cannot
|
||||
// access each-other at all and remove them from the peers.
|
||||
var reducedChanged views.Slice[types.NodeView]
|
||||
if len(filter) > 0 {
|
||||
reducedChanged = policy.ReduceNodes(node, changed, matchers)
|
||||
} else {
|
||||
reducedChanged = changed
|
||||
}
|
||||
|
||||
profiles := generateUserProfiles(node, reducedChanged)
|
||||
|
||||
dnsConfig := generateDNSConfig(cfg, node)
|
||||
|
||||
tailPeers, err := tailNodes(
|
||||
reducedChanged, capVer, state,
|
||||
func(id types.NodeID) []netip.Prefix {
|
||||
return policy.ReduceRoutes(node, state.GetNodePrimaryRoutes(id), matchers)
|
||||
},
|
||||
cfg)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Peers is always returned sorted by Node.ID.
|
||||
sort.SliceStable(tailPeers, func(x, y int) bool {
|
||||
return tailPeers[x].ID < tailPeers[y].ID
|
||||
})
|
||||
|
||||
if fullChange {
|
||||
resp.Peers = tailPeers
|
||||
} else {
|
||||
resp.PeersChanged = tailPeers
|
||||
}
|
||||
resp.DNSConfig = dnsConfig
|
||||
resp.UserProfiles = profiles
|
||||
resp.SSHPolicy = sshPolicy
|
||||
|
||||
// CapVer 81: 2023-11-17: MapResponse.PacketFilters (incremental packet filter updates)
|
||||
// Currently, we do not send incremental package filters, however using the
|
||||
// new PacketFilters field and "base" allows us to send a full update when we
|
||||
// have to send an empty list, avoiding the hack in the else block.
|
||||
resp.PacketFilters = map[string][]tailcfg.FilterRule{
|
||||
"base": policy.ReduceFilterRules(node, filter),
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
@@ -3,6 +3,7 @@ package mapper
|
||||
import (
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"slices"
|
||||
"testing"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
@@ -70,7 +71,7 @@ func TestDNSConfigMapResponse(t *testing.T) {
|
||||
&types.Config{
|
||||
TailcfgDNSConfig: &dnsConfigOrig,
|
||||
},
|
||||
nodeInShared1.View(),
|
||||
nodeInShared1,
|
||||
)
|
||||
|
||||
if diff := cmp.Diff(tt.want, got, cmpopts.EquateEmpty()); diff != "" {
|
||||
@@ -126,11 +127,8 @@ func (m *mockState) ListPeers(nodeID types.NodeID, peerIDs ...types.NodeID) (typ
|
||||
// Filter peers by the provided IDs
|
||||
var filtered types.Nodes
|
||||
for _, peer := range m.peers {
|
||||
for _, id := range peerIDs {
|
||||
if peer.ID == id {
|
||||
filtered = append(filtered, peer)
|
||||
break
|
||||
}
|
||||
if slices.Contains(peerIDs, peer.ID) {
|
||||
filtered = append(filtered, peer)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -152,11 +150,8 @@ func (m *mockState) ListNodes(nodeIDs ...types.NodeID) (types.Nodes, error) {
|
||||
// Filter nodes by the provided IDs
|
||||
var filtered types.Nodes
|
||||
for _, node := range m.nodes {
|
||||
for _, id := range nodeIDs {
|
||||
if node.ID == id {
|
||||
filtered = append(filtered, node)
|
||||
break
|
||||
}
|
||||
if slices.Contains(nodeIDs, node.ID) {
|
||||
filtered = append(filtered, node)
|
||||
}
|
||||
}
|
||||
|
||||
|
47
hscontrol/mapper/utils.go
Normal file
47
hscontrol/mapper/utils.go
Normal file
@@ -0,0 +1,47 @@
|
||||
package mapper
|
||||
|
||||
import "tailscale.com/tailcfg"
|
||||
|
||||
// mergePatch takes the current patch and a newer patch
|
||||
// and override any field that has changed.
|
||||
func mergePatch(currPatch, newPatch *tailcfg.PeerChange) {
|
||||
if newPatch.DERPRegion != 0 {
|
||||
currPatch.DERPRegion = newPatch.DERPRegion
|
||||
}
|
||||
|
||||
if newPatch.Cap != 0 {
|
||||
currPatch.Cap = newPatch.Cap
|
||||
}
|
||||
|
||||
if newPatch.CapMap != nil {
|
||||
currPatch.CapMap = newPatch.CapMap
|
||||
}
|
||||
|
||||
if newPatch.Endpoints != nil {
|
||||
currPatch.Endpoints = newPatch.Endpoints
|
||||
}
|
||||
|
||||
if newPatch.Key != nil {
|
||||
currPatch.Key = newPatch.Key
|
||||
}
|
||||
|
||||
if newPatch.KeySignature != nil {
|
||||
currPatch.KeySignature = newPatch.KeySignature
|
||||
}
|
||||
|
||||
if newPatch.DiscoKey != nil {
|
||||
currPatch.DiscoKey = newPatch.DiscoKey
|
||||
}
|
||||
|
||||
if newPatch.Online != nil {
|
||||
currPatch.Online = newPatch.Online
|
||||
}
|
||||
|
||||
if newPatch.LastSeen != nil {
|
||||
currPatch.LastSeen = newPatch.LastSeen
|
||||
}
|
||||
|
||||
if newPatch.KeyExpiry != nil {
|
||||
currPatch.KeyExpiry = newPatch.KeyExpiry
|
||||
}
|
||||
}
|
@@ -221,7 +221,7 @@ func (ns *noiseServer) NoisePollNetMapHandler(
|
||||
|
||||
ns.nodeKey = nv.NodeKey()
|
||||
|
||||
sess := ns.headscale.newMapSession(req.Context(), mapRequest, writer, nv)
|
||||
sess := ns.headscale.newMapSession(req.Context(), mapRequest, writer, nv.AsStruct())
|
||||
sess.tracef("a node sending a MapRequest with Noise protocol")
|
||||
if !sess.isStreaming() {
|
||||
sess.serve()
|
||||
@@ -279,28 +279,33 @@ func (ns *noiseServer) NoiseRegistrationHandler(
|
||||
return
|
||||
}
|
||||
|
||||
respBody, err := json.Marshal(registerResponse)
|
||||
if err != nil {
|
||||
httpError(writer, err)
|
||||
writer.Header().Set("Content-Type", "application/json; charset=utf-8")
|
||||
writer.WriteHeader(http.StatusOK)
|
||||
|
||||
if err := json.NewEncoder(writer).Encode(registerResponse); err != nil {
|
||||
log.Error().Err(err).Msg("NoiseRegistrationHandler: failed to encode RegisterResponse")
|
||||
return
|
||||
}
|
||||
|
||||
writer.Header().Set("Content-Type", "application/json; charset=utf-8")
|
||||
writer.WriteHeader(http.StatusOK)
|
||||
writer.Write(respBody)
|
||||
// Ensure response is flushed to client
|
||||
if flusher, ok := writer.(http.Flusher); ok {
|
||||
flusher.Flush()
|
||||
}
|
||||
}
|
||||
|
||||
// getAndValidateNode retrieves the node from the database using the NodeKey
|
||||
// and validates that it matches the MachineKey from the Noise session.
|
||||
func (ns *noiseServer) getAndValidateNode(mapRequest tailcfg.MapRequest) (types.NodeView, error) {
|
||||
nv, err := ns.headscale.state.GetNodeViewByNodeKey(mapRequest.NodeKey)
|
||||
node, err := ns.headscale.state.GetNodeByNodeKey(mapRequest.NodeKey)
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return types.NodeView{}, NewHTTPError(http.StatusNotFound, "node not found", nil)
|
||||
}
|
||||
return types.NodeView{}, err
|
||||
return types.NodeView{}, NewHTTPError(http.StatusInternalServerError, fmt.Sprintf("lookup node: %s", err), nil)
|
||||
}
|
||||
|
||||
nv := node.View()
|
||||
|
||||
// Validate that the MachineKey in the Noise session matches the one associated with the NodeKey.
|
||||
if ns.machineKey != nv.MachineKey() {
|
||||
return types.NodeView{}, NewHTTPError(http.StatusNotFound, "node key in request does not match the one associated with this machine key", nil)
|
||||
|
@@ -1,68 +0,0 @@
|
||||
package notifier
|
||||
|
||||
import (
|
||||
"github.com/prometheus/client_golang/prometheus"
|
||||
"github.com/prometheus/client_golang/prometheus/promauto"
|
||||
"tailscale.com/envknob"
|
||||
)
|
||||
|
||||
const prometheusNamespace = "headscale"
|
||||
|
||||
var debugHighCardinalityMetrics = envknob.Bool("HEADSCALE_DEBUG_HIGH_CARDINALITY_METRICS")
|
||||
|
||||
var notifierUpdateSent *prometheus.CounterVec
|
||||
|
||||
func init() {
|
||||
if debugHighCardinalityMetrics {
|
||||
notifierUpdateSent = promauto.NewCounterVec(prometheus.CounterOpts{
|
||||
Namespace: prometheusNamespace,
|
||||
Name: "notifier_update_sent_total",
|
||||
Help: "total count of update sent on nodes channel",
|
||||
}, []string{"status", "type", "trigger", "id"})
|
||||
} else {
|
||||
notifierUpdateSent = promauto.NewCounterVec(prometheus.CounterOpts{
|
||||
Namespace: prometheusNamespace,
|
||||
Name: "notifier_update_sent_total",
|
||||
Help: "total count of update sent on nodes channel",
|
||||
}, []string{"status", "type", "trigger"})
|
||||
}
|
||||
}
|
||||
|
||||
var (
|
||||
notifierWaitersForLock = promauto.NewGaugeVec(prometheus.GaugeOpts{
|
||||
Namespace: prometheusNamespace,
|
||||
Name: "notifier_waiters_for_lock",
|
||||
Help: "gauge of waiters for the notifier lock",
|
||||
}, []string{"type", "action"})
|
||||
notifierWaitForLock = promauto.NewHistogramVec(prometheus.HistogramOpts{
|
||||
Namespace: prometheusNamespace,
|
||||
Name: "notifier_wait_for_lock_seconds",
|
||||
Help: "histogram of time spent waiting for the notifier lock",
|
||||
Buckets: []float64{0.001, 0.01, 0.1, 0.3, 0.5, 1, 3, 5, 10},
|
||||
}, []string{"action"})
|
||||
notifierUpdateReceived = promauto.NewCounterVec(prometheus.CounterOpts{
|
||||
Namespace: prometheusNamespace,
|
||||
Name: "notifier_update_received_total",
|
||||
Help: "total count of updates received by notifier",
|
||||
}, []string{"type", "trigger"})
|
||||
notifierNodeUpdateChans = promauto.NewGauge(prometheus.GaugeOpts{
|
||||
Namespace: prometheusNamespace,
|
||||
Name: "notifier_open_channels_total",
|
||||
Help: "total count open channels in notifier",
|
||||
})
|
||||
notifierBatcherWaitersForLock = promauto.NewGaugeVec(prometheus.GaugeOpts{
|
||||
Namespace: prometheusNamespace,
|
||||
Name: "notifier_batcher_waiters_for_lock",
|
||||
Help: "gauge of waiters for the notifier batcher lock",
|
||||
}, []string{"type", "action"})
|
||||
notifierBatcherChanges = promauto.NewGaugeVec(prometheus.GaugeOpts{
|
||||
Namespace: prometheusNamespace,
|
||||
Name: "notifier_batcher_changes_pending",
|
||||
Help: "gauge of full changes pending in the notifier batcher",
|
||||
}, []string{})
|
||||
notifierBatcherPatches = promauto.NewGaugeVec(prometheus.GaugeOpts{
|
||||
Namespace: prometheusNamespace,
|
||||
Name: "notifier_batcher_patches_pending",
|
||||
Help: "gauge of patches pending in the notifier batcher",
|
||||
}, []string{})
|
||||
)
|
@@ -1,488 +0,0 @@
|
||||
package notifier
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sort"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/juanfont/headscale/hscontrol/types"
|
||||
"github.com/puzpuzpuz/xsync/v4"
|
||||
"github.com/rs/zerolog/log"
|
||||
"github.com/sasha-s/go-deadlock"
|
||||
"tailscale.com/envknob"
|
||||
"tailscale.com/tailcfg"
|
||||
"tailscale.com/util/set"
|
||||
)
|
||||
|
||||
var (
|
||||
debugDeadlock = envknob.Bool("HEADSCALE_DEBUG_DEADLOCK")
|
||||
debugDeadlockTimeout = envknob.RegisterDuration("HEADSCALE_DEBUG_DEADLOCK_TIMEOUT")
|
||||
)
|
||||
|
||||
func init() {
|
||||
deadlock.Opts.Disable = !debugDeadlock
|
||||
if debugDeadlock {
|
||||
deadlock.Opts.DeadlockTimeout = debugDeadlockTimeout()
|
||||
deadlock.Opts.PrintAllCurrentGoroutines = true
|
||||
}
|
||||
}
|
||||
|
||||
type Notifier struct {
|
||||
l deadlock.Mutex
|
||||
nodes map[types.NodeID]chan<- types.StateUpdate
|
||||
connected *xsync.MapOf[types.NodeID, bool]
|
||||
b *batcher
|
||||
cfg *types.Config
|
||||
closed bool
|
||||
}
|
||||
|
||||
func NewNotifier(cfg *types.Config) *Notifier {
|
||||
n := &Notifier{
|
||||
nodes: make(map[types.NodeID]chan<- types.StateUpdate),
|
||||
connected: xsync.NewMapOf[types.NodeID, bool](),
|
||||
cfg: cfg,
|
||||
closed: false,
|
||||
}
|
||||
b := newBatcher(cfg.Tuning.BatchChangeDelay, n)
|
||||
n.b = b
|
||||
|
||||
go b.doWork()
|
||||
|
||||
return n
|
||||
}
|
||||
|
||||
// Close stops the batcher and closes all channels.
|
||||
func (n *Notifier) Close() {
|
||||
notifierWaitersForLock.WithLabelValues("lock", "close").Inc()
|
||||
n.l.Lock()
|
||||
defer n.l.Unlock()
|
||||
notifierWaitersForLock.WithLabelValues("lock", "close").Dec()
|
||||
|
||||
n.closed = true
|
||||
n.b.close()
|
||||
|
||||
// Close channels safely using the helper method
|
||||
for nodeID, c := range n.nodes {
|
||||
n.safeCloseChannel(nodeID, c)
|
||||
}
|
||||
|
||||
// Clear node map after closing channels
|
||||
n.nodes = make(map[types.NodeID]chan<- types.StateUpdate)
|
||||
}
|
||||
|
||||
// safeCloseChannel closes a channel and panic recovers if already closed.
|
||||
func (n *Notifier) safeCloseChannel(nodeID types.NodeID, c chan<- types.StateUpdate) {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
log.Error().
|
||||
Uint64("node.id", nodeID.Uint64()).
|
||||
Any("recover", r).
|
||||
Msg("recovered from panic when closing channel in Close()")
|
||||
}
|
||||
}()
|
||||
close(c)
|
||||
}
|
||||
|
||||
func (n *Notifier) tracef(nID types.NodeID, msg string, args ...any) {
|
||||
log.Trace().
|
||||
Uint64("node.id", nID.Uint64()).
|
||||
Int("open_chans", len(n.nodes)).Msgf(msg, args...)
|
||||
}
|
||||
|
||||
func (n *Notifier) AddNode(nodeID types.NodeID, c chan<- types.StateUpdate) {
|
||||
start := time.Now()
|
||||
notifierWaitersForLock.WithLabelValues("lock", "add").Inc()
|
||||
n.l.Lock()
|
||||
defer n.l.Unlock()
|
||||
notifierWaitersForLock.WithLabelValues("lock", "add").Dec()
|
||||
notifierWaitForLock.WithLabelValues("add").Observe(time.Since(start).Seconds())
|
||||
|
||||
if n.closed {
|
||||
return
|
||||
}
|
||||
|
||||
// If a channel exists, it means the node has opened a new
|
||||
// connection. Close the old channel and replace it.
|
||||
if curr, ok := n.nodes[nodeID]; ok {
|
||||
n.tracef(nodeID, "channel present, closing and replacing")
|
||||
// Use the safeCloseChannel helper in a goroutine to avoid deadlocks
|
||||
// if/when someone is waiting to send on this channel
|
||||
go func(ch chan<- types.StateUpdate) {
|
||||
n.safeCloseChannel(nodeID, ch)
|
||||
}(curr)
|
||||
}
|
||||
|
||||
n.nodes[nodeID] = c
|
||||
n.connected.Store(nodeID, true)
|
||||
|
||||
n.tracef(nodeID, "added new channel")
|
||||
notifierNodeUpdateChans.Inc()
|
||||
}
|
||||
|
||||
// RemoveNode removes a node and a given channel from the notifier.
|
||||
// It checks that the channel is the same as currently being updated
|
||||
// and ignores the removal if it is not.
|
||||
// RemoveNode reports if the node/chan was removed.
|
||||
func (n *Notifier) RemoveNode(nodeID types.NodeID, c chan<- types.StateUpdate) bool {
|
||||
start := time.Now()
|
||||
notifierWaitersForLock.WithLabelValues("lock", "remove").Inc()
|
||||
n.l.Lock()
|
||||
defer n.l.Unlock()
|
||||
notifierWaitersForLock.WithLabelValues("lock", "remove").Dec()
|
||||
notifierWaitForLock.WithLabelValues("remove").Observe(time.Since(start).Seconds())
|
||||
|
||||
if n.closed {
|
||||
return true
|
||||
}
|
||||
|
||||
if len(n.nodes) == 0 {
|
||||
return true
|
||||
}
|
||||
|
||||
// If the channel exist, but it does not belong
|
||||
// to the caller, ignore.
|
||||
if curr, ok := n.nodes[nodeID]; ok {
|
||||
if curr != c {
|
||||
n.tracef(nodeID, "channel has been replaced, not removing")
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
delete(n.nodes, nodeID)
|
||||
n.connected.Store(nodeID, false)
|
||||
|
||||
n.tracef(nodeID, "removed channel")
|
||||
notifierNodeUpdateChans.Dec()
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
// IsConnected reports if a node is connected to headscale and has a
|
||||
// poll session open.
|
||||
func (n *Notifier) IsConnected(nodeID types.NodeID) bool {
|
||||
notifierWaitersForLock.WithLabelValues("lock", "conncheck").Inc()
|
||||
n.l.Lock()
|
||||
defer n.l.Unlock()
|
||||
notifierWaitersForLock.WithLabelValues("lock", "conncheck").Dec()
|
||||
|
||||
if val, ok := n.connected.Load(nodeID); ok {
|
||||
return val
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// IsLikelyConnected reports if a node is connected to headscale and has a
|
||||
// poll session open, but doesn't lock, so might be wrong.
|
||||
func (n *Notifier) IsLikelyConnected(nodeID types.NodeID) bool {
|
||||
if val, ok := n.connected.Load(nodeID); ok {
|
||||
return val
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// LikelyConnectedMap returns a thread safe map of connected nodes.
|
||||
func (n *Notifier) LikelyConnectedMap() *xsync.MapOf[types.NodeID, bool] {
|
||||
return n.connected
|
||||
}
|
||||
|
||||
func (n *Notifier) NotifyAll(ctx context.Context, update types.StateUpdate) {
|
||||
n.NotifyWithIgnore(ctx, update)
|
||||
}
|
||||
|
||||
func (n *Notifier) NotifyWithIgnore(
|
||||
ctx context.Context,
|
||||
update types.StateUpdate,
|
||||
ignoreNodeIDs ...types.NodeID,
|
||||
) {
|
||||
if n.closed {
|
||||
return
|
||||
}
|
||||
|
||||
notifierUpdateReceived.WithLabelValues(update.Type.String(), types.NotifyOriginKey.Value(ctx)).Inc()
|
||||
n.b.addOrPassthrough(update)
|
||||
}
|
||||
|
||||
func (n *Notifier) NotifyByNodeID(
|
||||
ctx context.Context,
|
||||
update types.StateUpdate,
|
||||
nodeID types.NodeID,
|
||||
) {
|
||||
start := time.Now()
|
||||
notifierWaitersForLock.WithLabelValues("lock", "notify").Inc()
|
||||
n.l.Lock()
|
||||
defer n.l.Unlock()
|
||||
notifierWaitersForLock.WithLabelValues("lock", "notify").Dec()
|
||||
notifierWaitForLock.WithLabelValues("notify").Observe(time.Since(start).Seconds())
|
||||
|
||||
if n.closed {
|
||||
return
|
||||
}
|
||||
|
||||
if c, ok := n.nodes[nodeID]; ok {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
log.Error().
|
||||
Err(ctx.Err()).
|
||||
Uint64("node.id", nodeID.Uint64()).
|
||||
Any("origin", types.NotifyOriginKey.Value(ctx)).
|
||||
Any("origin-hostname", types.NotifyHostnameKey.Value(ctx)).
|
||||
Msgf("update not sent, context cancelled")
|
||||
if debugHighCardinalityMetrics {
|
||||
notifierUpdateSent.WithLabelValues("cancelled", update.Type.String(), types.NotifyOriginKey.Value(ctx), nodeID.String()).Inc()
|
||||
} else {
|
||||
notifierUpdateSent.WithLabelValues("cancelled", update.Type.String(), types.NotifyOriginKey.Value(ctx)).Inc()
|
||||
}
|
||||
|
||||
return
|
||||
case c <- update:
|
||||
n.tracef(nodeID, "update successfully sent on chan, origin: %s, origin-hostname: %s", ctx.Value("origin"), ctx.Value("hostname"))
|
||||
if debugHighCardinalityMetrics {
|
||||
notifierUpdateSent.WithLabelValues("ok", update.Type.String(), types.NotifyOriginKey.Value(ctx), nodeID.String()).Inc()
|
||||
} else {
|
||||
notifierUpdateSent.WithLabelValues("ok", update.Type.String(), types.NotifyOriginKey.Value(ctx)).Inc()
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (n *Notifier) sendAll(update types.StateUpdate) {
|
||||
start := time.Now()
|
||||
notifierWaitersForLock.WithLabelValues("lock", "send-all").Inc()
|
||||
n.l.Lock()
|
||||
defer n.l.Unlock()
|
||||
notifierWaitersForLock.WithLabelValues("lock", "send-all").Dec()
|
||||
notifierWaitForLock.WithLabelValues("send-all").Observe(time.Since(start).Seconds())
|
||||
|
||||
if n.closed {
|
||||
return
|
||||
}
|
||||
|
||||
for id, c := range n.nodes {
|
||||
// Whenever an update is sent to all nodes, there is a chance that the node
|
||||
// has disconnected and the goroutine that was supposed to consume the update
|
||||
// has shut down the channel and is waiting for the lock held here in RemoveNode.
|
||||
// This means that there is potential for a deadlock which would stop all updates
|
||||
// going out to clients. This timeout prevents that from happening by moving on to the
|
||||
// next node if the context is cancelled. After sendAll releases the lock, the add/remove
|
||||
// call will succeed and the update will go to the correct nodes on the next call.
|
||||
ctx, cancel := context.WithTimeout(context.Background(), n.cfg.Tuning.NotifierSendTimeout)
|
||||
defer cancel()
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
log.Error().
|
||||
Err(ctx.Err()).
|
||||
Uint64("node.id", id.Uint64()).
|
||||
Msgf("update not sent, context cancelled")
|
||||
if debugHighCardinalityMetrics {
|
||||
notifierUpdateSent.WithLabelValues("cancelled", update.Type.String(), "send-all", id.String()).Inc()
|
||||
} else {
|
||||
notifierUpdateSent.WithLabelValues("cancelled", update.Type.String(), "send-all").Inc()
|
||||
}
|
||||
|
||||
return
|
||||
case c <- update:
|
||||
if debugHighCardinalityMetrics {
|
||||
notifierUpdateSent.WithLabelValues("ok", update.Type.String(), "send-all", id.String()).Inc()
|
||||
} else {
|
||||
notifierUpdateSent.WithLabelValues("ok", update.Type.String(), "send-all").Inc()
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (n *Notifier) String() string {
|
||||
notifierWaitersForLock.WithLabelValues("lock", "string").Inc()
|
||||
n.l.Lock()
|
||||
defer n.l.Unlock()
|
||||
notifierWaitersForLock.WithLabelValues("lock", "string").Dec()
|
||||
|
||||
var b strings.Builder
|
||||
fmt.Fprintf(&b, "chans (%d):\n", len(n.nodes))
|
||||
|
||||
var keys []types.NodeID
|
||||
n.connected.Range(func(key types.NodeID, value bool) bool {
|
||||
keys = append(keys, key)
|
||||
return true
|
||||
})
|
||||
sort.Slice(keys, func(i, j int) bool {
|
||||
return keys[i] < keys[j]
|
||||
})
|
||||
|
||||
for _, key := range keys {
|
||||
fmt.Fprintf(&b, "\t%d: %p\n", key, n.nodes[key])
|
||||
}
|
||||
|
||||
b.WriteString("\n")
|
||||
fmt.Fprintf(&b, "connected (%d):\n", len(n.nodes))
|
||||
|
||||
for _, key := range keys {
|
||||
val, _ := n.connected.Load(key)
|
||||
fmt.Fprintf(&b, "\t%d: %t\n", key, val)
|
||||
}
|
||||
|
||||
return b.String()
|
||||
}
|
||||
|
||||
type batcher struct {
|
||||
tick *time.Ticker
|
||||
|
||||
mu sync.Mutex
|
||||
|
||||
cancelCh chan struct{}
|
||||
|
||||
changedNodeIDs set.Slice[types.NodeID]
|
||||
nodesChanged bool
|
||||
patches map[types.NodeID]tailcfg.PeerChange
|
||||
patchesChanged bool
|
||||
|
||||
n *Notifier
|
||||
}
|
||||
|
||||
func newBatcher(batchTime time.Duration, n *Notifier) *batcher {
|
||||
return &batcher{
|
||||
tick: time.NewTicker(batchTime),
|
||||
cancelCh: make(chan struct{}),
|
||||
patches: make(map[types.NodeID]tailcfg.PeerChange),
|
||||
n: n,
|
||||
}
|
||||
}
|
||||
|
||||
func (b *batcher) close() {
|
||||
b.cancelCh <- struct{}{}
|
||||
}
|
||||
|
||||
// addOrPassthrough adds the update to the batcher, if it is not a
|
||||
// type that is currently batched, it will be sent immediately.
|
||||
func (b *batcher) addOrPassthrough(update types.StateUpdate) {
|
||||
notifierBatcherWaitersForLock.WithLabelValues("lock", "add").Inc()
|
||||
b.mu.Lock()
|
||||
defer b.mu.Unlock()
|
||||
notifierBatcherWaitersForLock.WithLabelValues("lock", "add").Dec()
|
||||
|
||||
switch update.Type {
|
||||
case types.StatePeerChanged:
|
||||
b.changedNodeIDs.Add(update.ChangeNodes...)
|
||||
b.nodesChanged = true
|
||||
notifierBatcherChanges.WithLabelValues().Set(float64(b.changedNodeIDs.Len()))
|
||||
|
||||
case types.StatePeerChangedPatch:
|
||||
for _, newPatch := range update.ChangePatches {
|
||||
if curr, ok := b.patches[types.NodeID(newPatch.NodeID)]; ok {
|
||||
overwritePatch(&curr, newPatch)
|
||||
b.patches[types.NodeID(newPatch.NodeID)] = curr
|
||||
} else {
|
||||
b.patches[types.NodeID(newPatch.NodeID)] = *newPatch
|
||||
}
|
||||
}
|
||||
b.patchesChanged = true
|
||||
notifierBatcherPatches.WithLabelValues().Set(float64(len(b.patches)))
|
||||
|
||||
default:
|
||||
b.n.sendAll(update)
|
||||
}
|
||||
}
|
||||
|
||||
// flush sends all the accumulated patches to all
|
||||
// nodes in the notifier.
|
||||
func (b *batcher) flush() {
|
||||
notifierBatcherWaitersForLock.WithLabelValues("lock", "flush").Inc()
|
||||
b.mu.Lock()
|
||||
defer b.mu.Unlock()
|
||||
notifierBatcherWaitersForLock.WithLabelValues("lock", "flush").Dec()
|
||||
|
||||
if b.nodesChanged || b.patchesChanged {
|
||||
var patches []*tailcfg.PeerChange
|
||||
// If a node is getting a full update from a change
|
||||
// node update, then the patch can be dropped.
|
||||
for nodeID, patch := range b.patches {
|
||||
if b.changedNodeIDs.Contains(nodeID) {
|
||||
delete(b.patches, nodeID)
|
||||
} else {
|
||||
patches = append(patches, &patch)
|
||||
}
|
||||
}
|
||||
|
||||
changedNodes := b.changedNodeIDs.Slice().AsSlice()
|
||||
sort.Slice(changedNodes, func(i, j int) bool {
|
||||
return changedNodes[i] < changedNodes[j]
|
||||
})
|
||||
|
||||
if b.changedNodeIDs.Slice().Len() > 0 {
|
||||
update := types.UpdatePeerChanged(changedNodes...)
|
||||
|
||||
b.n.sendAll(update)
|
||||
}
|
||||
|
||||
if len(patches) > 0 {
|
||||
patchUpdate := types.UpdatePeerPatch(patches...)
|
||||
|
||||
b.n.sendAll(patchUpdate)
|
||||
}
|
||||
|
||||
b.changedNodeIDs = set.Slice[types.NodeID]{}
|
||||
notifierBatcherChanges.WithLabelValues().Set(0)
|
||||
b.nodesChanged = false
|
||||
b.patches = make(map[types.NodeID]tailcfg.PeerChange, len(b.patches))
|
||||
notifierBatcherPatches.WithLabelValues().Set(0)
|
||||
b.patchesChanged = false
|
||||
}
|
||||
}
|
||||
|
||||
func (b *batcher) doWork() {
|
||||
for {
|
||||
select {
|
||||
case <-b.cancelCh:
|
||||
return
|
||||
case <-b.tick.C:
|
||||
b.flush()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// overwritePatch takes the current patch and a newer patch
|
||||
// and override any field that has changed.
|
||||
func overwritePatch(currPatch, newPatch *tailcfg.PeerChange) {
|
||||
if newPatch.DERPRegion != 0 {
|
||||
currPatch.DERPRegion = newPatch.DERPRegion
|
||||
}
|
||||
|
||||
if newPatch.Cap != 0 {
|
||||
currPatch.Cap = newPatch.Cap
|
||||
}
|
||||
|
||||
if newPatch.CapMap != nil {
|
||||
currPatch.CapMap = newPatch.CapMap
|
||||
}
|
||||
|
||||
if newPatch.Endpoints != nil {
|
||||
currPatch.Endpoints = newPatch.Endpoints
|
||||
}
|
||||
|
||||
if newPatch.Key != nil {
|
||||
currPatch.Key = newPatch.Key
|
||||
}
|
||||
|
||||
if newPatch.KeySignature != nil {
|
||||
currPatch.KeySignature = newPatch.KeySignature
|
||||
}
|
||||
|
||||
if newPatch.DiscoKey != nil {
|
||||
currPatch.DiscoKey = newPatch.DiscoKey
|
||||
}
|
||||
|
||||
if newPatch.Online != nil {
|
||||
currPatch.Online = newPatch.Online
|
||||
}
|
||||
|
||||
if newPatch.LastSeen != nil {
|
||||
currPatch.LastSeen = newPatch.LastSeen
|
||||
}
|
||||
|
||||
if newPatch.KeyExpiry != nil {
|
||||
currPatch.KeyExpiry = newPatch.KeyExpiry
|
||||
}
|
||||
}
|
@@ -1,342 +0,0 @@
|
||||
package notifier
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"net/netip"
|
||||
"slices"
|
||||
"sort"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/juanfont/headscale/hscontrol/types"
|
||||
"github.com/juanfont/headscale/hscontrol/util"
|
||||
"tailscale.com/tailcfg"
|
||||
)
|
||||
|
||||
func TestBatcher(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
updates []types.StateUpdate
|
||||
want []types.StateUpdate
|
||||
}{
|
||||
{
|
||||
name: "full-passthrough",
|
||||
updates: []types.StateUpdate{
|
||||
{
|
||||
Type: types.StateFullUpdate,
|
||||
},
|
||||
},
|
||||
want: []types.StateUpdate{
|
||||
{
|
||||
Type: types.StateFullUpdate,
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "derp-passthrough",
|
||||
updates: []types.StateUpdate{
|
||||
{
|
||||
Type: types.StateDERPUpdated,
|
||||
},
|
||||
},
|
||||
want: []types.StateUpdate{
|
||||
{
|
||||
Type: types.StateDERPUpdated,
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "single-node-update",
|
||||
updates: []types.StateUpdate{
|
||||
{
|
||||
Type: types.StatePeerChanged,
|
||||
ChangeNodes: []types.NodeID{
|
||||
2,
|
||||
},
|
||||
},
|
||||
},
|
||||
want: []types.StateUpdate{
|
||||
{
|
||||
Type: types.StatePeerChanged,
|
||||
ChangeNodes: []types.NodeID{
|
||||
2,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "merge-node-update",
|
||||
updates: []types.StateUpdate{
|
||||
{
|
||||
Type: types.StatePeerChanged,
|
||||
ChangeNodes: []types.NodeID{
|
||||
2, 4,
|
||||
},
|
||||
},
|
||||
{
|
||||
Type: types.StatePeerChanged,
|
||||
ChangeNodes: []types.NodeID{
|
||||
2, 3,
|
||||
},
|
||||
},
|
||||
},
|
||||
want: []types.StateUpdate{
|
||||
{
|
||||
Type: types.StatePeerChanged,
|
||||
ChangeNodes: []types.NodeID{
|
||||
2, 3, 4,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "single-patch-update",
|
||||
updates: []types.StateUpdate{
|
||||
{
|
||||
Type: types.StatePeerChangedPatch,
|
||||
ChangePatches: []*tailcfg.PeerChange{
|
||||
{
|
||||
NodeID: 2,
|
||||
DERPRegion: 5,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
want: []types.StateUpdate{
|
||||
{
|
||||
Type: types.StatePeerChangedPatch,
|
||||
ChangePatches: []*tailcfg.PeerChange{
|
||||
{
|
||||
NodeID: 2,
|
||||
DERPRegion: 5,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "merge-patch-to-same-node-update",
|
||||
updates: []types.StateUpdate{
|
||||
{
|
||||
Type: types.StatePeerChangedPatch,
|
||||
ChangePatches: []*tailcfg.PeerChange{
|
||||
{
|
||||
NodeID: 2,
|
||||
DERPRegion: 5,
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
Type: types.StatePeerChangedPatch,
|
||||
ChangePatches: []*tailcfg.PeerChange{
|
||||
{
|
||||
NodeID: 2,
|
||||
DERPRegion: 6,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
want: []types.StateUpdate{
|
||||
{
|
||||
Type: types.StatePeerChangedPatch,
|
||||
ChangePatches: []*tailcfg.PeerChange{
|
||||
{
|
||||
NodeID: 2,
|
||||
DERPRegion: 6,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "merge-patch-to-multiple-node-update",
|
||||
updates: []types.StateUpdate{
|
||||
{
|
||||
Type: types.StatePeerChangedPatch,
|
||||
ChangePatches: []*tailcfg.PeerChange{
|
||||
{
|
||||
NodeID: 3,
|
||||
Endpoints: []netip.AddrPort{
|
||||
netip.MustParseAddrPort("1.1.1.1:9090"),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
Type: types.StatePeerChangedPatch,
|
||||
ChangePatches: []*tailcfg.PeerChange{
|
||||
{
|
||||
NodeID: 3,
|
||||
Endpoints: []netip.AddrPort{
|
||||
netip.MustParseAddrPort("1.1.1.1:9090"),
|
||||
netip.MustParseAddrPort("2.2.2.2:8080"),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
Type: types.StatePeerChangedPatch,
|
||||
ChangePatches: []*tailcfg.PeerChange{
|
||||
{
|
||||
NodeID: 4,
|
||||
DERPRegion: 6,
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
Type: types.StatePeerChangedPatch,
|
||||
ChangePatches: []*tailcfg.PeerChange{
|
||||
{
|
||||
NodeID: 4,
|
||||
Cap: tailcfg.CapabilityVersion(54),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
want: []types.StateUpdate{
|
||||
{
|
||||
Type: types.StatePeerChangedPatch,
|
||||
ChangePatches: []*tailcfg.PeerChange{
|
||||
{
|
||||
NodeID: 3,
|
||||
Endpoints: []netip.AddrPort{
|
||||
netip.MustParseAddrPort("1.1.1.1:9090"),
|
||||
netip.MustParseAddrPort("2.2.2.2:8080"),
|
||||
},
|
||||
},
|
||||
{
|
||||
NodeID: 4,
|
||||
DERPRegion: 6,
|
||||
Cap: tailcfg.CapabilityVersion(54),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
n := NewNotifier(&types.Config{
|
||||
Tuning: types.Tuning{
|
||||
// We will call flush manually for the tests,
|
||||
// so do not run the worker.
|
||||
BatchChangeDelay: time.Hour,
|
||||
|
||||
// Since we do not load the config, we won't get the
|
||||
// default, so set it manually so we dont time out
|
||||
// and have flakes.
|
||||
NotifierSendTimeout: time.Second,
|
||||
},
|
||||
})
|
||||
|
||||
ch := make(chan types.StateUpdate, 30)
|
||||
defer close(ch)
|
||||
n.AddNode(1, ch)
|
||||
defer n.RemoveNode(1, ch)
|
||||
|
||||
for _, u := range tt.updates {
|
||||
n.NotifyAll(t.Context(), u)
|
||||
}
|
||||
|
||||
n.b.flush()
|
||||
|
||||
var got []types.StateUpdate
|
||||
for len(ch) > 0 {
|
||||
out := <-ch
|
||||
got = append(got, out)
|
||||
}
|
||||
|
||||
// Make the inner order stable for comparison.
|
||||
for _, u := range got {
|
||||
slices.Sort(u.ChangeNodes)
|
||||
sort.Slice(u.ChangePatches, func(i, j int) bool {
|
||||
return u.ChangePatches[i].NodeID < u.ChangePatches[j].NodeID
|
||||
})
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(tt.want, got, util.Comparers...); diff != "" {
|
||||
t.Errorf("batcher() unexpected result (-want +got):\n%s", diff)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestIsLikelyConnectedRaceCondition tests for a race condition in IsLikelyConnected
|
||||
// Multiple goroutines calling AddNode and RemoveNode cause panics when trying to
|
||||
// close a channel that was already closed, which can happen when a node changes
|
||||
// network transport quickly (eg mobile->wifi) and reconnects whilst also disconnecting.
|
||||
func TestIsLikelyConnectedRaceCondition(t *testing.T) {
|
||||
// mock config for the notifier
|
||||
cfg := &types.Config{
|
||||
Tuning: types.Tuning{
|
||||
NotifierSendTimeout: 1 * time.Second,
|
||||
BatchChangeDelay: 1 * time.Second,
|
||||
NodeMapSessionBufferedChanSize: 30,
|
||||
},
|
||||
}
|
||||
|
||||
notifier := NewNotifier(cfg)
|
||||
defer notifier.Close()
|
||||
|
||||
nodeID := types.NodeID(1)
|
||||
updateChan := make(chan types.StateUpdate, 10)
|
||||
|
||||
var wg sync.WaitGroup
|
||||
|
||||
// Number of goroutines to spawn for concurrent access
|
||||
concurrentAccessors := 100
|
||||
iterations := 100
|
||||
|
||||
// Add node to notifier
|
||||
notifier.AddNode(nodeID, updateChan)
|
||||
|
||||
// Track errors
|
||||
errChan := make(chan string, concurrentAccessors*iterations)
|
||||
|
||||
// Start goroutines to cause a race
|
||||
wg.Add(concurrentAccessors)
|
||||
for i := range concurrentAccessors {
|
||||
go func(routineID int) {
|
||||
defer wg.Done()
|
||||
|
||||
for range iterations {
|
||||
// Simulate race by having some goroutines check IsLikelyConnected
|
||||
// while others add/remove the node
|
||||
switch routineID % 3 {
|
||||
case 0:
|
||||
// This goroutine checks connection status
|
||||
isConnected := notifier.IsLikelyConnected(nodeID)
|
||||
if isConnected != true && isConnected != false {
|
||||
errChan <- fmt.Sprintf("Invalid connection status: %v", isConnected)
|
||||
}
|
||||
case 1:
|
||||
// This goroutine removes the node
|
||||
notifier.RemoveNode(nodeID, updateChan)
|
||||
default:
|
||||
// This goroutine adds the node back
|
||||
notifier.AddNode(nodeID, updateChan)
|
||||
}
|
||||
|
||||
// Small random delay to increase chance of races
|
||||
time.Sleep(time.Duration(rand.Intn(100)) * time.Microsecond)
|
||||
}
|
||||
}(i)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
close(errChan)
|
||||
|
||||
// Collate errors
|
||||
var errors []string
|
||||
for err := range errChan {
|
||||
errors = append(errors, err)
|
||||
}
|
||||
|
||||
if len(errors) > 0 {
|
||||
t.Errorf("Detected %d race condition errors: %v", len(errors), errors)
|
||||
}
|
||||
}
|
@@ -16,9 +16,8 @@ import (
|
||||
"github.com/coreos/go-oidc/v3/oidc"
|
||||
"github.com/gorilla/mux"
|
||||
"github.com/juanfont/headscale/hscontrol/db"
|
||||
"github.com/juanfont/headscale/hscontrol/notifier"
|
||||
"github.com/juanfont/headscale/hscontrol/state"
|
||||
"github.com/juanfont/headscale/hscontrol/types"
|
||||
"github.com/juanfont/headscale/hscontrol/types/change"
|
||||
"github.com/juanfont/headscale/hscontrol/util"
|
||||
"github.com/rs/zerolog/log"
|
||||
"golang.org/x/oauth2"
|
||||
@@ -56,11 +55,10 @@ type RegistrationInfo struct {
|
||||
}
|
||||
|
||||
type AuthProviderOIDC struct {
|
||||
h *Headscale
|
||||
serverURL string
|
||||
cfg *types.OIDCConfig
|
||||
state *state.State
|
||||
registrationCache *zcache.Cache[string, RegistrationInfo]
|
||||
notifier *notifier.Notifier
|
||||
|
||||
oidcProvider *oidc.Provider
|
||||
oauth2Config *oauth2.Config
|
||||
@@ -68,10 +66,9 @@ type AuthProviderOIDC struct {
|
||||
|
||||
func NewAuthProviderOIDC(
|
||||
ctx context.Context,
|
||||
h *Headscale,
|
||||
serverURL string,
|
||||
cfg *types.OIDCConfig,
|
||||
state *state.State,
|
||||
notif *notifier.Notifier,
|
||||
) (*AuthProviderOIDC, error) {
|
||||
var err error
|
||||
// grab oidc config if it hasn't been already
|
||||
@@ -94,11 +91,10 @@ func NewAuthProviderOIDC(
|
||||
)
|
||||
|
||||
return &AuthProviderOIDC{
|
||||
h: h,
|
||||
serverURL: serverURL,
|
||||
cfg: cfg,
|
||||
state: state,
|
||||
registrationCache: registrationCache,
|
||||
notifier: notif,
|
||||
|
||||
oidcProvider: oidcProvider,
|
||||
oauth2Config: oauth2Config,
|
||||
@@ -318,8 +314,7 @@ func (a *AuthProviderOIDC) OIDCCallbackHandler(
|
||||
|
||||
// Send policy update notifications if needed
|
||||
if policyChanged {
|
||||
ctx := types.NotifyCtx(context.Background(), "oidc-user-created", user.Name)
|
||||
a.notifier.NotifyAll(ctx, types.UpdateFull())
|
||||
a.h.Change(change.PolicyChange())
|
||||
}
|
||||
|
||||
// TODO(kradalby): Is this comment right?
|
||||
@@ -360,8 +355,6 @@ func (a *AuthProviderOIDC) OIDCCallbackHandler(
|
||||
// Neither node nor machine key was found in the state cache meaning
|
||||
// that we could not reauth nor register the node.
|
||||
httpError(writer, NewHTTPError(http.StatusGone, "login session expired, try again", nil))
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func extractCodeAndStateParamFromRequest(
|
||||
@@ -490,12 +483,14 @@ func (a *AuthProviderOIDC) createOrUpdateUserFromClaim(
|
||||
var err error
|
||||
var newUser bool
|
||||
var policyChanged bool
|
||||
user, err = a.state.GetUserByOIDCIdentifier(claims.Identifier())
|
||||
user, err = a.h.state.GetUserByOIDCIdentifier(claims.Identifier())
|
||||
if err != nil && !errors.Is(err, db.ErrUserNotFound) {
|
||||
return nil, false, fmt.Errorf("creating or updating user: %w", err)
|
||||
}
|
||||
|
||||
// if the user is still not found, create a new empty user.
|
||||
// TODO(kradalby): This might cause us to not have an ID below which
|
||||
// is a problem.
|
||||
if user == nil {
|
||||
newUser = true
|
||||
user = &types.User{}
|
||||
@@ -504,12 +499,12 @@ func (a *AuthProviderOIDC) createOrUpdateUserFromClaim(
|
||||
user.FromClaim(claims)
|
||||
|
||||
if newUser {
|
||||
user, policyChanged, err = a.state.CreateUser(*user)
|
||||
user, policyChanged, err = a.h.state.CreateUser(*user)
|
||||
if err != nil {
|
||||
return nil, false, fmt.Errorf("creating user: %w", err)
|
||||
}
|
||||
} else {
|
||||
_, policyChanged, err = a.state.UpdateUser(types.UserID(user.ID), func(u *types.User) error {
|
||||
_, policyChanged, err = a.h.state.UpdateUser(types.UserID(user.ID), func(u *types.User) error {
|
||||
*u = *user
|
||||
return nil
|
||||
})
|
||||
@@ -526,7 +521,7 @@ func (a *AuthProviderOIDC) handleRegistration(
|
||||
registrationID types.RegistrationID,
|
||||
expiry time.Time,
|
||||
) (bool, error) {
|
||||
node, newNode, err := a.state.HandleNodeFromAuthPath(
|
||||
node, nodeChange, err := a.h.state.HandleNodeFromAuthPath(
|
||||
registrationID,
|
||||
types.UserID(user.ID),
|
||||
&expiry,
|
||||
@@ -547,31 +542,20 @@ func (a *AuthProviderOIDC) handleRegistration(
|
||||
// ensure we send an update.
|
||||
// This works, but might be another good candidate for doing some sort of
|
||||
// eventbus.
|
||||
routesChanged := a.state.AutoApproveRoutes(node)
|
||||
_, policyChanged, err := a.state.SaveNode(node)
|
||||
_ = a.h.state.AutoApproveRoutes(node)
|
||||
_, policyChange, err := a.h.state.SaveNode(node)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("saving auto approved routes to node: %w", err)
|
||||
}
|
||||
|
||||
// Send policy update notifications if needed (from SaveNode or route changes)
|
||||
if policyChanged {
|
||||
ctx := types.NotifyCtx(context.Background(), "oidc-nodes-change", "all")
|
||||
a.notifier.NotifyAll(ctx, types.UpdateFull())
|
||||
// Policy updates are full and take precedence over node changes.
|
||||
if !policyChange.Empty() {
|
||||
a.h.Change(policyChange)
|
||||
} else {
|
||||
a.h.Change(nodeChange)
|
||||
}
|
||||
|
||||
if routesChanged {
|
||||
ctx := types.NotifyCtx(context.Background(), "oidc-expiry-self", node.Hostname)
|
||||
a.notifier.NotifyByNodeID(
|
||||
ctx,
|
||||
types.UpdateSelf(node.ID),
|
||||
node.ID,
|
||||
)
|
||||
|
||||
ctx = types.NotifyCtx(context.Background(), "oidc-expiry-peers", node.Hostname)
|
||||
a.notifier.NotifyWithIgnore(ctx, types.UpdatePeerChanged(node.ID), node.ID)
|
||||
}
|
||||
|
||||
return newNode, nil
|
||||
return !nodeChange.Empty(), nil
|
||||
}
|
||||
|
||||
// TODO(kradalby):
|
||||
|
@@ -113,6 +113,17 @@ func ReduceFilterRules(node types.NodeView, rules []tailcfg.FilterRule) []tailcf
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Also check approved subnet routes - nodes should have access
|
||||
// to subnets they're approved to route traffic for.
|
||||
subnetRoutes := node.SubnetRoutes()
|
||||
|
||||
for _, subnetRoute := range subnetRoutes {
|
||||
if expanded.OverlapsPrefix(subnetRoute) {
|
||||
dests = append(dests, dest)
|
||||
continue DEST_LOOP
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if len(dests) > 0 {
|
||||
@@ -142,16 +153,23 @@ func AutoApproveRoutes(pm PolicyManager, node *types.Node) bool {
|
||||
newApproved = append(newApproved, route)
|
||||
}
|
||||
}
|
||||
if newApproved != nil {
|
||||
newApproved = append(newApproved, node.ApprovedRoutes...)
|
||||
tsaddr.SortPrefixes(newApproved)
|
||||
newApproved = slices.Compact(newApproved)
|
||||
newApproved = lo.Filter(newApproved, func(route netip.Prefix, index int) bool {
|
||||
|
||||
// Only modify ApprovedRoutes if we have new routes to approve.
|
||||
// This prevents clearing existing approved routes when nodes
|
||||
// temporarily don't have announced routes during policy changes.
|
||||
if len(newApproved) > 0 {
|
||||
combined := append(newApproved, node.ApprovedRoutes...)
|
||||
tsaddr.SortPrefixes(combined)
|
||||
combined = slices.Compact(combined)
|
||||
combined = lo.Filter(combined, func(route netip.Prefix, index int) bool {
|
||||
return route.IsValid()
|
||||
})
|
||||
node.ApprovedRoutes = newApproved
|
||||
|
||||
return true
|
||||
// Only update if the routes actually changed
|
||||
if !slices.Equal(node.ApprovedRoutes, combined) {
|
||||
node.ApprovedRoutes = combined
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
|
@@ -56,10 +56,13 @@ func (pol *Policy) compileFilterRules(
|
||||
}
|
||||
|
||||
if ips == nil {
|
||||
log.Debug().Msgf("destination resolved to nil ips: %v", dest)
|
||||
continue
|
||||
}
|
||||
|
||||
for _, pref := range ips.Prefixes() {
|
||||
prefixes := ips.Prefixes()
|
||||
|
||||
for _, pref := range prefixes {
|
||||
for _, port := range dest.Ports {
|
||||
pr := tailcfg.NetPortRange{
|
||||
IP: pref.String(),
|
||||
@@ -103,6 +106,8 @@ func (pol *Policy) compileSSHPolicy(
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
log.Trace().Msgf("compiling SSH policy for node %q", node.Hostname())
|
||||
|
||||
var rules []*tailcfg.SSHRule
|
||||
|
||||
for index, rule := range pol.SSHs {
|
||||
@@ -137,7 +142,8 @@ func (pol *Policy) compileSSHPolicy(
|
||||
var principals []*tailcfg.SSHPrincipal
|
||||
srcIPs, err := rule.Sources.Resolve(pol, users, nodes)
|
||||
if err != nil {
|
||||
log.Trace().Err(err).Msgf("resolving source ips")
|
||||
log.Trace().Err(err).Msgf("SSH policy compilation failed resolving source ips for rule %+v", rule)
|
||||
continue // Skip this rule if we can't resolve sources
|
||||
}
|
||||
|
||||
for addr := range util.IPSetAddrIter(srcIPs) {
|
||||
|
@@ -70,7 +70,7 @@ func (pm *PolicyManager) updateLocked() (bool, error) {
|
||||
// TODO(kradalby): This could potentially be optimized by only clearing the
|
||||
// policies for nodes that have changed. Particularly if the only difference is
|
||||
// that nodes has been added or removed.
|
||||
defer clear(pm.sshPolicyMap)
|
||||
clear(pm.sshPolicyMap)
|
||||
|
||||
filter, err := pm.pol.compileFilterRules(pm.users, pm.nodes)
|
||||
if err != nil {
|
||||
|
@@ -1730,7 +1730,7 @@ func (u SSHUser) MarshalJSON() ([]byte, error) {
|
||||
// In addition to unmarshalling, it will also validate the policy.
|
||||
// This is the only entrypoint of reading a policy from a file or other source.
|
||||
func unmarshalPolicy(b []byte) (*Policy, error) {
|
||||
if b == nil || len(b) == 0 {
|
||||
if len(b) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
|
@@ -2,20 +2,20 @@ package hscontrol
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/binary"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"math/rand/v2"
|
||||
"net/http"
|
||||
"net/netip"
|
||||
"slices"
|
||||
"time"
|
||||
|
||||
"github.com/juanfont/headscale/hscontrol/mapper"
|
||||
"github.com/juanfont/headscale/hscontrol/types"
|
||||
"github.com/juanfont/headscale/hscontrol/types/change"
|
||||
"github.com/juanfont/headscale/hscontrol/util"
|
||||
"github.com/rs/zerolog/log"
|
||||
"github.com/sasha-s/go-deadlock"
|
||||
xslices "golang.org/x/exp/slices"
|
||||
"tailscale.com/net/tsaddr"
|
||||
"tailscale.com/tailcfg"
|
||||
"tailscale.com/util/zstdframe"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -31,18 +31,17 @@ type mapSession struct {
|
||||
req tailcfg.MapRequest
|
||||
ctx context.Context
|
||||
capVer tailcfg.CapabilityVersion
|
||||
mapper *mapper.Mapper
|
||||
|
||||
cancelChMu deadlock.Mutex
|
||||
|
||||
ch chan types.StateUpdate
|
||||
ch chan *tailcfg.MapResponse
|
||||
cancelCh chan struct{}
|
||||
cancelChOpen bool
|
||||
|
||||
keepAlive time.Duration
|
||||
keepAliveTicker *time.Ticker
|
||||
|
||||
node types.NodeView
|
||||
node *types.Node
|
||||
w http.ResponseWriter
|
||||
|
||||
warnf func(string, ...any)
|
||||
@@ -55,18 +54,9 @@ func (h *Headscale) newMapSession(
|
||||
ctx context.Context,
|
||||
req tailcfg.MapRequest,
|
||||
w http.ResponseWriter,
|
||||
nv types.NodeView,
|
||||
node *types.Node,
|
||||
) *mapSession {
|
||||
warnf, infof, tracef, errf := logPollFuncView(req, nv)
|
||||
|
||||
var updateChan chan types.StateUpdate
|
||||
if req.Stream {
|
||||
// Use a buffered channel in case a node is not fully ready
|
||||
// to receive a message to make sure we dont block the entire
|
||||
// notifier.
|
||||
updateChan = make(chan types.StateUpdate, h.cfg.Tuning.NodeMapSessionBufferedChanSize)
|
||||
updateChan <- types.UpdateFull()
|
||||
}
|
||||
warnf, infof, tracef, errf := logPollFunc(req, node)
|
||||
|
||||
ka := keepAliveInterval + (time.Duration(rand.IntN(9000)) * time.Millisecond)
|
||||
|
||||
@@ -75,11 +65,10 @@ func (h *Headscale) newMapSession(
|
||||
ctx: ctx,
|
||||
req: req,
|
||||
w: w,
|
||||
node: nv,
|
||||
node: node,
|
||||
capVer: req.Version,
|
||||
mapper: h.mapper,
|
||||
|
||||
ch: updateChan,
|
||||
ch: make(chan *tailcfg.MapResponse, h.cfg.Tuning.NodeMapSessionBufferedChanSize),
|
||||
cancelCh: make(chan struct{}),
|
||||
cancelChOpen: true,
|
||||
|
||||
@@ -95,15 +84,11 @@ func (h *Headscale) newMapSession(
|
||||
}
|
||||
|
||||
func (m *mapSession) isStreaming() bool {
|
||||
return m.req.Stream && !m.req.ReadOnly
|
||||
return m.req.Stream
|
||||
}
|
||||
|
||||
func (m *mapSession) isEndpointUpdate() bool {
|
||||
return !m.req.Stream && !m.req.ReadOnly && m.req.OmitPeers
|
||||
}
|
||||
|
||||
func (m *mapSession) isReadOnlyUpdate() bool {
|
||||
return !m.req.Stream && m.req.OmitPeers && m.req.ReadOnly
|
||||
return !m.req.Stream && m.req.OmitPeers
|
||||
}
|
||||
|
||||
func (m *mapSession) resetKeepAlive() {
|
||||
@@ -112,25 +97,22 @@ func (m *mapSession) resetKeepAlive() {
|
||||
|
||||
func (m *mapSession) beforeServeLongPoll() {
|
||||
if m.node.IsEphemeral() {
|
||||
m.h.ephemeralGC.Cancel(m.node.ID())
|
||||
m.h.ephemeralGC.Cancel(m.node.ID)
|
||||
}
|
||||
}
|
||||
|
||||
func (m *mapSession) afterServeLongPoll() {
|
||||
if m.node.IsEphemeral() {
|
||||
m.h.ephemeralGC.Schedule(m.node.ID(), m.h.cfg.EphemeralNodeInactivityTimeout)
|
||||
m.h.ephemeralGC.Schedule(m.node.ID, m.h.cfg.EphemeralNodeInactivityTimeout)
|
||||
}
|
||||
}
|
||||
|
||||
// serve handles non-streaming requests.
|
||||
func (m *mapSession) serve() {
|
||||
// TODO(kradalby): A set todos to harden:
|
||||
// - func to tell the stream to die, readonly -> false, !stream && omitpeers -> false, true
|
||||
|
||||
// This is the mechanism where the node gives us information about its
|
||||
// current configuration.
|
||||
//
|
||||
// If OmitPeers is true, Stream is false, and ReadOnly is false,
|
||||
// If OmitPeers is true and Stream is false
|
||||
// then the server will let clients update their endpoints without
|
||||
// breaking existing long-polling (Stream == true) connections.
|
||||
// In this case, the server can omit the entire response; the client
|
||||
@@ -138,26 +120,18 @@ func (m *mapSession) serve() {
|
||||
//
|
||||
// This is what Tailscale calls a Lite update, the client ignores
|
||||
// the response and just wants a 200.
|
||||
// !req.stream && !req.ReadOnly && req.OmitPeers
|
||||
//
|
||||
// TODO(kradalby): remove ReadOnly when we only support capVer 68+
|
||||
// !req.stream && req.OmitPeers
|
||||
if m.isEndpointUpdate() {
|
||||
m.handleEndpointUpdate()
|
||||
c, err := m.h.state.UpdateNodeFromMapRequest(m.node, m.req)
|
||||
if err != nil {
|
||||
httpError(m.w, err)
|
||||
return
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
m.h.Change(c)
|
||||
|
||||
// ReadOnly is whether the client just wants to fetch the
|
||||
// MapResponse, without updating their Endpoints. The
|
||||
// Endpoints field will be ignored and LastSeen will not be
|
||||
// updated and peers will not be notified of changes.
|
||||
//
|
||||
// The intended use is for clients to discover the DERP map at
|
||||
// start-up before their first real endpoint update.
|
||||
if m.isReadOnlyUpdate() {
|
||||
m.handleReadOnlyRequest()
|
||||
|
||||
return
|
||||
m.w.WriteHeader(http.StatusOK)
|
||||
mapResponseEndpointUpdates.WithLabelValues("ok").Inc()
|
||||
}
|
||||
}
|
||||
|
||||
@@ -175,23 +149,15 @@ func (m *mapSession) serveLongPoll() {
|
||||
close(m.cancelCh)
|
||||
m.cancelChMu.Unlock()
|
||||
|
||||
// only update node status if the node channel was removed.
|
||||
// in principal, it will be removed, but the client rapidly
|
||||
// reconnects, the channel might be of another connection.
|
||||
// In that case, it is not closed and the node is still online.
|
||||
if m.h.nodeNotifier.RemoveNode(m.node.ID(), m.ch) {
|
||||
// TODO(kradalby): This can likely be made more effective, but likely most
|
||||
// nodes has access to the same routes, so it might not be a big deal.
|
||||
change, err := m.h.state.Disconnect(m.node.ID())
|
||||
if err != nil {
|
||||
m.errf(err, "Failed to disconnect node %s", m.node.Hostname())
|
||||
}
|
||||
|
||||
if change {
|
||||
ctx := types.NotifyCtx(context.Background(), "poll-primary-change", m.node.Hostname())
|
||||
m.h.nodeNotifier.NotifyAll(ctx, types.UpdateFull())
|
||||
}
|
||||
// TODO(kradalby): This can likely be made more effective, but likely most
|
||||
// nodes has access to the same routes, so it might not be a big deal.
|
||||
disconnectChange, err := m.h.state.Disconnect(m.node)
|
||||
if err != nil {
|
||||
m.errf(err, "Failed to disconnect node %s", m.node.Hostname)
|
||||
}
|
||||
m.h.Change(disconnectChange)
|
||||
|
||||
m.h.mapBatcher.RemoveNode(m.node.ID, m.ch, m.node.IsSubnetRouter())
|
||||
|
||||
m.afterServeLongPoll()
|
||||
m.infof("node has disconnected, mapSession: %p, chan: %p", m, m.ch)
|
||||
@@ -201,21 +167,30 @@ func (m *mapSession) serveLongPoll() {
|
||||
m.h.pollNetMapStreamWG.Add(1)
|
||||
defer m.h.pollNetMapStreamWG.Done()
|
||||
|
||||
m.h.state.Connect(m.node.ID())
|
||||
|
||||
// Upgrade the writer to a ResponseController
|
||||
rc := http.NewResponseController(m.w)
|
||||
|
||||
// Longpolling will break if there is a write timeout,
|
||||
// so it needs to be disabled.
|
||||
rc.SetWriteDeadline(time.Time{})
|
||||
|
||||
ctx, cancel := context.WithCancel(context.WithValue(m.ctx, nodeNameContextKey, m.node.Hostname()))
|
||||
ctx, cancel := context.WithCancel(context.WithValue(m.ctx, nodeNameContextKey, m.node.Hostname))
|
||||
defer cancel()
|
||||
|
||||
m.keepAliveTicker = time.NewTicker(m.keepAlive)
|
||||
|
||||
m.h.nodeNotifier.AddNode(m.node.ID(), m.ch)
|
||||
// Add node to batcher BEFORE sending Connect change to prevent race condition
|
||||
// where the change is sent before the node is in the batcher's node map
|
||||
if err := m.h.mapBatcher.AddNode(m.node.ID, m.ch, m.node.IsSubnetRouter(), m.capVer); err != nil {
|
||||
m.errf(err, "failed to add node to batcher")
|
||||
// Send empty response to client to fail fast for invalid/non-existent nodes
|
||||
select {
|
||||
case m.ch <- &tailcfg.MapResponse{}:
|
||||
default:
|
||||
// Channel might be closed
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Now send the Connect change - the batcher handles NodeCameOnline internally
|
||||
// but we still need to update routes and other state-level changes
|
||||
connectChange := m.h.state.Connect(m.node)
|
||||
if !connectChange.Empty() && connectChange.Change != change.NodeCameOnline {
|
||||
m.h.Change(connectChange)
|
||||
}
|
||||
|
||||
m.infof("node has connected, mapSession: %p, chan: %p", m, m.ch)
|
||||
|
||||
@@ -236,290 +211,94 @@ func (m *mapSession) serveLongPoll() {
|
||||
|
||||
// Consume updates sent to node
|
||||
case update, ok := <-m.ch:
|
||||
m.tracef("received update from channel, ok: %t", ok)
|
||||
if !ok {
|
||||
m.tracef("update channel closed, streaming session is likely being replaced")
|
||||
return
|
||||
}
|
||||
|
||||
// If the node has been removed from headscale, close the stream
|
||||
if slices.Contains(update.Removed, m.node.ID()) {
|
||||
m.tracef("node removed, closing stream")
|
||||
if err := m.writeMap(update); err != nil {
|
||||
m.errf(err, "cannot write update to client")
|
||||
return
|
||||
}
|
||||
|
||||
m.tracef("received stream update: %s %s", update.Type.String(), update.Message)
|
||||
mapResponseUpdateReceived.WithLabelValues(update.Type.String()).Inc()
|
||||
|
||||
var data []byte
|
||||
var err error
|
||||
var lastMessage string
|
||||
|
||||
// Ensure the node view is updated, for example, there
|
||||
// might have been a hostinfo update in a sidechannel
|
||||
// which contains data needed to generate a map response.
|
||||
m.node, err = m.h.state.GetNodeViewByID(m.node.ID())
|
||||
if err != nil {
|
||||
m.errf(err, "Could not get machine from db")
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
updateType := "full"
|
||||
switch update.Type {
|
||||
case types.StateFullUpdate:
|
||||
m.tracef("Sending Full MapResponse")
|
||||
data, err = m.mapper.FullMapResponse(m.req, m.node, fmt.Sprintf("from mapSession: %p, stream: %t", m, m.isStreaming()))
|
||||
case types.StatePeerChanged:
|
||||
changed := make(map[types.NodeID]bool, len(update.ChangeNodes))
|
||||
|
||||
for _, nodeID := range update.ChangeNodes {
|
||||
changed[nodeID] = true
|
||||
}
|
||||
|
||||
lastMessage = update.Message
|
||||
m.tracef(fmt.Sprintf("Sending Changed MapResponse: %v", lastMessage))
|
||||
data, err = m.mapper.PeerChangedResponse(m.req, m.node, changed, update.ChangePatches, lastMessage)
|
||||
updateType = "change"
|
||||
|
||||
case types.StatePeerChangedPatch:
|
||||
m.tracef(fmt.Sprintf("Sending Changed Patch MapResponse: %v", lastMessage))
|
||||
data, err = m.mapper.PeerChangedPatchResponse(m.req, m.node, update.ChangePatches)
|
||||
updateType = "patch"
|
||||
case types.StatePeerRemoved:
|
||||
changed := make(map[types.NodeID]bool, len(update.Removed))
|
||||
|
||||
for _, nodeID := range update.Removed {
|
||||
changed[nodeID] = false
|
||||
}
|
||||
m.tracef(fmt.Sprintf("Sending Changed MapResponse: %v", lastMessage))
|
||||
data, err = m.mapper.PeerChangedResponse(m.req, m.node, changed, update.ChangePatches, lastMessage)
|
||||
updateType = "remove"
|
||||
case types.StateSelfUpdate:
|
||||
lastMessage = update.Message
|
||||
m.tracef(fmt.Sprintf("Sending Changed MapResponse: %v", lastMessage))
|
||||
// create the map so an empty (self) update is sent
|
||||
data, err = m.mapper.PeerChangedResponse(m.req, m.node, make(map[types.NodeID]bool), update.ChangePatches, lastMessage)
|
||||
updateType = "remove"
|
||||
case types.StateDERPUpdated:
|
||||
m.tracef("Sending DERPUpdate MapResponse")
|
||||
data, err = m.mapper.DERPMapResponse(m.req, m.node, m.h.state.DERPMap())
|
||||
updateType = "derp"
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
m.errf(err, "Could not get the create map update")
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// Only send update if there is change
|
||||
if data != nil {
|
||||
startWrite := time.Now()
|
||||
_, err = m.w.Write(data)
|
||||
if err != nil {
|
||||
mapResponseSent.WithLabelValues("error", updateType).Inc()
|
||||
m.errf(err, "could not write the map response(%s), for mapSession: %p", update.Type.String(), m)
|
||||
return
|
||||
}
|
||||
|
||||
err = rc.Flush()
|
||||
if err != nil {
|
||||
mapResponseSent.WithLabelValues("error", updateType).Inc()
|
||||
m.errf(err, "flushing the map response to client, for mapSession: %p", m)
|
||||
return
|
||||
}
|
||||
|
||||
log.Trace().Str("node", m.node.Hostname()).TimeDiff("timeSpent", time.Now(), startWrite).Str("mkey", m.node.MachineKey().String()).Msg("finished writing mapresp to node")
|
||||
|
||||
if debugHighCardinalityMetrics {
|
||||
mapResponseLastSentSeconds.WithLabelValues(updateType, m.node.ID().String()).Set(float64(time.Now().Unix()))
|
||||
}
|
||||
mapResponseSent.WithLabelValues("ok", updateType).Inc()
|
||||
m.tracef("update sent")
|
||||
m.resetKeepAlive()
|
||||
}
|
||||
m.tracef("update sent")
|
||||
m.resetKeepAlive()
|
||||
|
||||
case <-m.keepAliveTicker.C:
|
||||
data, err := m.mapper.KeepAliveResponse(m.req, m.node)
|
||||
if err != nil {
|
||||
m.errf(err, "Error generating the keep alive msg")
|
||||
mapResponseSent.WithLabelValues("error", "keepalive").Inc()
|
||||
return
|
||||
}
|
||||
_, err = m.w.Write(data)
|
||||
if err != nil {
|
||||
m.errf(err, "Cannot write keep alive message")
|
||||
mapResponseSent.WithLabelValues("error", "keepalive").Inc()
|
||||
return
|
||||
}
|
||||
err = rc.Flush()
|
||||
if err != nil {
|
||||
m.errf(err, "flushing keep alive to client, for mapSession: %p", m)
|
||||
mapResponseSent.WithLabelValues("error", "keepalive").Inc()
|
||||
if err := m.writeMap(&keepAlive); err != nil {
|
||||
m.errf(err, "cannot write keep alive")
|
||||
return
|
||||
}
|
||||
|
||||
if debugHighCardinalityMetrics {
|
||||
mapResponseLastSentSeconds.WithLabelValues("keepalive", m.node.ID().String()).Set(float64(time.Now().Unix()))
|
||||
mapResponseLastSentSeconds.WithLabelValues("keepalive", m.node.ID.String()).Set(float64(time.Now().Unix()))
|
||||
}
|
||||
mapResponseSent.WithLabelValues("ok", "keepalive").Inc()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (m *mapSession) handleEndpointUpdate() {
|
||||
m.tracef("received endpoint update")
|
||||
|
||||
// Get fresh node state from database for accurate route calculations
|
||||
node, err := m.h.state.GetNodeByID(m.node.ID())
|
||||
// writeMap writes the map response to the client.
|
||||
// It handles compression if requested and any headers that need to be set.
|
||||
// It also handles flushing the response if the ResponseWriter
|
||||
// implements http.Flusher.
|
||||
func (m *mapSession) writeMap(msg *tailcfg.MapResponse) error {
|
||||
jsonBody, err := json.Marshal(msg)
|
||||
if err != nil {
|
||||
m.errf(err, "Failed to get fresh node from database for endpoint update")
|
||||
http.Error(m.w, "", http.StatusInternalServerError)
|
||||
mapResponseEndpointUpdates.WithLabelValues("error").Inc()
|
||||
return
|
||||
return fmt.Errorf("marshalling map response: %w", err)
|
||||
}
|
||||
|
||||
change := m.node.PeerChangeFromMapRequest(m.req)
|
||||
|
||||
online := m.h.nodeNotifier.IsLikelyConnected(m.node.ID())
|
||||
change.Online = &online
|
||||
|
||||
node.ApplyPeerChange(&change)
|
||||
|
||||
sendUpdate, routesChanged := hostInfoChanged(node.Hostinfo, m.req.Hostinfo)
|
||||
|
||||
// The node might not set NetInfo if it has not changed and if
|
||||
// the full HostInfo object is overwritten, the information is lost.
|
||||
// If there is no NetInfo, keep the previous one.
|
||||
// From 1.66 the client only sends it if changed:
|
||||
// https://github.com/tailscale/tailscale/commit/e1011f138737286ecf5123ff887a7a5800d129a2
|
||||
// TODO(kradalby): evaluate if we need better comparing of hostinfo
|
||||
// before we take the changes.
|
||||
if m.req.Hostinfo.NetInfo == nil && node.Hostinfo != nil {
|
||||
m.req.Hostinfo.NetInfo = node.Hostinfo.NetInfo
|
||||
}
|
||||
node.Hostinfo = m.req.Hostinfo
|
||||
|
||||
logTracePeerChange(node.Hostname, sendUpdate, &change)
|
||||
|
||||
// If there is no changes and nothing to save,
|
||||
// return early.
|
||||
if peerChangeEmpty(change) && !sendUpdate {
|
||||
mapResponseEndpointUpdates.WithLabelValues("noop").Inc()
|
||||
return
|
||||
if m.req.Compress == util.ZstdCompression {
|
||||
jsonBody = zstdframe.AppendEncode(nil, jsonBody, zstdframe.FastestCompression)
|
||||
}
|
||||
|
||||
// Auto approve any routes that have been defined in policy as
|
||||
// auto approved. Check if this actually changed the node.
|
||||
routesAutoApproved := m.h.state.AutoApproveRoutes(node)
|
||||
data := make([]byte, reservedResponseHeaderSize)
|
||||
binary.LittleEndian.PutUint32(data, uint32(len(jsonBody)))
|
||||
data = append(data, jsonBody...)
|
||||
|
||||
// Always update routes for connected nodes to handle reconnection scenarios
|
||||
// where routes need to be restored to the primary routes system
|
||||
routesToSet := node.SubnetRoutes()
|
||||
startWrite := time.Now()
|
||||
|
||||
if m.h.state.SetNodeRoutes(node.ID, routesToSet...) {
|
||||
ctx := types.NotifyCtx(m.ctx, "poll-primary-change", node.Hostname)
|
||||
m.h.nodeNotifier.NotifyAll(ctx, types.UpdateFull())
|
||||
} else if routesChanged {
|
||||
// Only send peer changed notification if routes actually changed
|
||||
ctx := types.NotifyCtx(m.ctx, "cli-approveroutes", node.Hostname)
|
||||
m.h.nodeNotifier.NotifyWithIgnore(ctx, types.UpdatePeerChanged(node.ID), node.ID)
|
||||
|
||||
// TODO(kradalby): I am not sure if we need this?
|
||||
// Send an update to the node itself with to ensure it
|
||||
// has an updated packetfilter allowing the new route
|
||||
// if it is defined in the ACL.
|
||||
ctx = types.NotifyCtx(m.ctx, "poll-nodeupdate-self-hostinfochange", node.Hostname)
|
||||
m.h.nodeNotifier.NotifyByNodeID(
|
||||
ctx,
|
||||
types.UpdateSelf(node.ID),
|
||||
node.ID)
|
||||
_, err = m.w.Write(data)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// If routes were auto-approved, we need to save the node to persist the changes
|
||||
if routesAutoApproved {
|
||||
if _, _, err := m.h.state.SaveNode(node); err != nil {
|
||||
m.errf(err, "Failed to save auto-approved routes to node")
|
||||
http.Error(m.w, "", http.StatusInternalServerError)
|
||||
mapResponseEndpointUpdates.WithLabelValues("error").Inc()
|
||||
return
|
||||
if m.isStreaming() {
|
||||
if f, ok := m.w.(http.Flusher); ok {
|
||||
f.Flush()
|
||||
} else {
|
||||
m.errf(nil, "ResponseWriter does not implement http.Flusher, cannot flush")
|
||||
}
|
||||
}
|
||||
|
||||
// Check if there has been a change to Hostname and update them
|
||||
// in the database. Then send a Changed update
|
||||
// (containing the whole node object) to peers to inform about
|
||||
// the hostname change.
|
||||
node.ApplyHostnameFromHostInfo(m.req.Hostinfo)
|
||||
log.Trace().Str("node", m.node.Hostname).TimeDiff("timeSpent", time.Now(), startWrite).Str("mkey", m.node.MachineKey.String()).Msg("finished writing mapresp to node")
|
||||
|
||||
_, policyChanged, err := m.h.state.SaveNode(node)
|
||||
if err != nil {
|
||||
m.errf(err, "Failed to persist/update node in the database")
|
||||
http.Error(m.w, "", http.StatusInternalServerError)
|
||||
mapResponseEndpointUpdates.WithLabelValues("error").Inc()
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// Send policy update notifications if needed
|
||||
if policyChanged {
|
||||
ctx := types.NotifyCtx(context.Background(), "poll-nodeupdate-policy", node.Hostname)
|
||||
m.h.nodeNotifier.NotifyAll(ctx, types.UpdateFull())
|
||||
}
|
||||
|
||||
ctx := types.NotifyCtx(context.Background(), "poll-nodeupdate-peers-patch", node.Hostname)
|
||||
m.h.nodeNotifier.NotifyWithIgnore(
|
||||
ctx,
|
||||
types.UpdatePeerChanged(node.ID),
|
||||
node.ID,
|
||||
)
|
||||
|
||||
m.w.WriteHeader(http.StatusOK)
|
||||
mapResponseEndpointUpdates.WithLabelValues("ok").Inc()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mapSession) handleReadOnlyRequest() {
|
||||
m.tracef("Client asked for a lite update, responding without peers")
|
||||
|
||||
mapResp, err := m.mapper.ReadOnlyMapResponse(m.req, m.node)
|
||||
if err != nil {
|
||||
m.errf(err, "Failed to create MapResponse")
|
||||
http.Error(m.w, "", http.StatusInternalServerError)
|
||||
mapResponseReadOnly.WithLabelValues("error").Inc()
|
||||
return
|
||||
}
|
||||
|
||||
m.w.Header().Set("Content-Type", "application/json; charset=utf-8")
|
||||
m.w.WriteHeader(http.StatusOK)
|
||||
_, err = m.w.Write(mapResp)
|
||||
if err != nil {
|
||||
m.errf(err, "Failed to write response")
|
||||
mapResponseReadOnly.WithLabelValues("error").Inc()
|
||||
return
|
||||
}
|
||||
|
||||
m.w.WriteHeader(http.StatusOK)
|
||||
mapResponseReadOnly.WithLabelValues("ok").Inc()
|
||||
var keepAlive = tailcfg.MapResponse{
|
||||
KeepAlive: true,
|
||||
}
|
||||
|
||||
func logTracePeerChange(hostname string, hostinfoChange bool, change *tailcfg.PeerChange) {
|
||||
trace := log.Trace().Uint64("node.id", uint64(change.NodeID)).Str("hostname", hostname)
|
||||
func logTracePeerChange(hostname string, hostinfoChange bool, peerChange *tailcfg.PeerChange) {
|
||||
trace := log.Trace().Uint64("node.id", uint64(peerChange.NodeID)).Str("hostname", hostname)
|
||||
|
||||
if change.Key != nil {
|
||||
trace = trace.Str("node_key", change.Key.ShortString())
|
||||
if peerChange.Key != nil {
|
||||
trace = trace.Str("node_key", peerChange.Key.ShortString())
|
||||
}
|
||||
|
||||
if change.DiscoKey != nil {
|
||||
trace = trace.Str("disco_key", change.DiscoKey.ShortString())
|
||||
if peerChange.DiscoKey != nil {
|
||||
trace = trace.Str("disco_key", peerChange.DiscoKey.ShortString())
|
||||
}
|
||||
|
||||
if change.Online != nil {
|
||||
trace = trace.Bool("online", *change.Online)
|
||||
if peerChange.Online != nil {
|
||||
trace = trace.Bool("online", *peerChange.Online)
|
||||
}
|
||||
|
||||
if change.Endpoints != nil {
|
||||
eps := make([]string, len(change.Endpoints))
|
||||
for idx, ep := range change.Endpoints {
|
||||
if peerChange.Endpoints != nil {
|
||||
eps := make([]string, len(peerChange.Endpoints))
|
||||
for idx, ep := range peerChange.Endpoints {
|
||||
eps[idx] = ep.String()
|
||||
}
|
||||
|
||||
@@ -530,21 +309,11 @@ func logTracePeerChange(hostname string, hostinfoChange bool, change *tailcfg.Pe
|
||||
trace = trace.Bool("hostinfo_changed", hostinfoChange)
|
||||
}
|
||||
|
||||
if change.DERPRegion != 0 {
|
||||
trace = trace.Int("derp_region", change.DERPRegion)
|
||||
if peerChange.DERPRegion != 0 {
|
||||
trace = trace.Int("derp_region", peerChange.DERPRegion)
|
||||
}
|
||||
|
||||
trace.Time("last_seen", *change.LastSeen).Msg("PeerChange received")
|
||||
}
|
||||
|
||||
func peerChangeEmpty(chng tailcfg.PeerChange) bool {
|
||||
return chng.Key == nil &&
|
||||
chng.DiscoKey == nil &&
|
||||
chng.Online == nil &&
|
||||
chng.Endpoints == nil &&
|
||||
chng.DERPRegion == 0 &&
|
||||
chng.LastSeen == nil &&
|
||||
chng.KeyExpiry == nil
|
||||
trace.Time("last_seen", *peerChange.LastSeen).Msg("PeerChange received")
|
||||
}
|
||||
|
||||
func logPollFunc(
|
||||
@@ -554,7 +323,6 @@ func logPollFunc(
|
||||
return func(msg string, a ...any) {
|
||||
log.Warn().
|
||||
Caller().
|
||||
Bool("readOnly", mapRequest.ReadOnly).
|
||||
Bool("omitPeers", mapRequest.OmitPeers).
|
||||
Bool("stream", mapRequest.Stream).
|
||||
Uint64("node.id", node.ID.Uint64()).
|
||||
@@ -564,7 +332,6 @@ func logPollFunc(
|
||||
func(msg string, a ...any) {
|
||||
log.Info().
|
||||
Caller().
|
||||
Bool("readOnly", mapRequest.ReadOnly).
|
||||
Bool("omitPeers", mapRequest.OmitPeers).
|
||||
Bool("stream", mapRequest.Stream).
|
||||
Uint64("node.id", node.ID.Uint64()).
|
||||
@@ -574,7 +341,6 @@ func logPollFunc(
|
||||
func(msg string, a ...any) {
|
||||
log.Trace().
|
||||
Caller().
|
||||
Bool("readOnly", mapRequest.ReadOnly).
|
||||
Bool("omitPeers", mapRequest.OmitPeers).
|
||||
Bool("stream", mapRequest.Stream).
|
||||
Uint64("node.id", node.ID.Uint64()).
|
||||
@@ -584,7 +350,6 @@ func logPollFunc(
|
||||
func(err error, msg string, a ...any) {
|
||||
log.Error().
|
||||
Caller().
|
||||
Bool("readOnly", mapRequest.ReadOnly).
|
||||
Bool("omitPeers", mapRequest.OmitPeers).
|
||||
Bool("stream", mapRequest.Stream).
|
||||
Uint64("node.id", node.ID.Uint64()).
|
||||
@@ -593,91 +358,3 @@ func logPollFunc(
|
||||
Msgf(msg, a...)
|
||||
}
|
||||
}
|
||||
|
||||
func logPollFuncView(
|
||||
mapRequest tailcfg.MapRequest,
|
||||
nodeView types.NodeView,
|
||||
) (func(string, ...any), func(string, ...any), func(string, ...any), func(error, string, ...any)) {
|
||||
return func(msg string, a ...any) {
|
||||
log.Warn().
|
||||
Caller().
|
||||
Bool("readOnly", mapRequest.ReadOnly).
|
||||
Bool("omitPeers", mapRequest.OmitPeers).
|
||||
Bool("stream", mapRequest.Stream).
|
||||
Uint64("node.id", nodeView.ID().Uint64()).
|
||||
Str("node", nodeView.Hostname()).
|
||||
Msgf(msg, a...)
|
||||
},
|
||||
func(msg string, a ...any) {
|
||||
log.Info().
|
||||
Caller().
|
||||
Bool("readOnly", mapRequest.ReadOnly).
|
||||
Bool("omitPeers", mapRequest.OmitPeers).
|
||||
Bool("stream", mapRequest.Stream).
|
||||
Uint64("node.id", nodeView.ID().Uint64()).
|
||||
Str("node", nodeView.Hostname()).
|
||||
Msgf(msg, a...)
|
||||
},
|
||||
func(msg string, a ...any) {
|
||||
log.Trace().
|
||||
Caller().
|
||||
Bool("readOnly", mapRequest.ReadOnly).
|
||||
Bool("omitPeers", mapRequest.OmitPeers).
|
||||
Bool("stream", mapRequest.Stream).
|
||||
Uint64("node.id", nodeView.ID().Uint64()).
|
||||
Str("node", nodeView.Hostname()).
|
||||
Msgf(msg, a...)
|
||||
},
|
||||
func(err error, msg string, a ...any) {
|
||||
log.Error().
|
||||
Caller().
|
||||
Bool("readOnly", mapRequest.ReadOnly).
|
||||
Bool("omitPeers", mapRequest.OmitPeers).
|
||||
Bool("stream", mapRequest.Stream).
|
||||
Uint64("node.id", nodeView.ID().Uint64()).
|
||||
Str("node", nodeView.Hostname()).
|
||||
Err(err).
|
||||
Msgf(msg, a...)
|
||||
}
|
||||
}
|
||||
|
||||
// hostInfoChanged reports if hostInfo has changed in two ways,
|
||||
// - first bool reports if an update needs to be sent to nodes
|
||||
// - second reports if there has been changes to routes
|
||||
// the caller can then use this info to save and update nodes
|
||||
// and routes as needed.
|
||||
func hostInfoChanged(old, new *tailcfg.Hostinfo) (bool, bool) {
|
||||
if old.Equal(new) {
|
||||
return false, false
|
||||
}
|
||||
|
||||
if old == nil && new != nil {
|
||||
return true, true
|
||||
}
|
||||
|
||||
// Routes
|
||||
oldRoutes := make([]netip.Prefix, 0)
|
||||
if old != nil {
|
||||
oldRoutes = old.RoutableIPs
|
||||
}
|
||||
newRoutes := new.RoutableIPs
|
||||
|
||||
tsaddr.SortPrefixes(oldRoutes)
|
||||
tsaddr.SortPrefixes(newRoutes)
|
||||
|
||||
if !xslices.Equal(oldRoutes, newRoutes) {
|
||||
return true, true
|
||||
}
|
||||
|
||||
// Services is mostly useful for discovery and not critical,
|
||||
// except for peerapi, which is how nodes talk to each other.
|
||||
// If peerapi was not part of the initial mapresponse, we
|
||||
// need to make sure its sent out later as it is needed for
|
||||
// Taildrop.
|
||||
// TODO(kradalby): Length comparison is a bit naive, replace.
|
||||
if len(old.Services) != len(new.Services) {
|
||||
return true, false
|
||||
}
|
||||
|
||||
return false, false
|
||||
}
|
||||
|
@@ -17,10 +17,13 @@ import (
|
||||
"github.com/juanfont/headscale/hscontrol/policy/matcher"
|
||||
"github.com/juanfont/headscale/hscontrol/routes"
|
||||
"github.com/juanfont/headscale/hscontrol/types"
|
||||
"github.com/juanfont/headscale/hscontrol/types/change"
|
||||
"github.com/juanfont/headscale/hscontrol/util"
|
||||
"github.com/rs/zerolog/log"
|
||||
"github.com/sasha-s/go-deadlock"
|
||||
xslices "golang.org/x/exp/slices"
|
||||
"gorm.io/gorm"
|
||||
"tailscale.com/net/tsaddr"
|
||||
"tailscale.com/tailcfg"
|
||||
"tailscale.com/types/key"
|
||||
"tailscale.com/types/ptr"
|
||||
@@ -46,12 +49,6 @@ type State struct {
|
||||
// cfg holds the current Headscale configuration
|
||||
cfg *types.Config
|
||||
|
||||
// in-memory data, protected by mu
|
||||
// nodes contains the current set of registered nodes
|
||||
nodes types.Nodes
|
||||
// users contains the current set of users/namespaces
|
||||
users types.Users
|
||||
|
||||
// subsystem keeping state
|
||||
// db provides persistent storage and database operations
|
||||
db *hsdb.HSDatabase
|
||||
@@ -113,9 +110,6 @@ func NewState(cfg *types.Config) (*State, error) {
|
||||
return &State{
|
||||
cfg: cfg,
|
||||
|
||||
nodes: nodes,
|
||||
users: users,
|
||||
|
||||
db: db,
|
||||
ipAlloc: ipAlloc,
|
||||
// TODO(kradalby): Update DERPMap
|
||||
@@ -215,6 +209,7 @@ func (s *State) CreateUser(user types.User) (*types.User, bool, error) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
|
||||
if err := s.db.DB.Save(&user).Error; err != nil {
|
||||
return nil, false, fmt.Errorf("creating user: %w", err)
|
||||
}
|
||||
@@ -226,6 +221,18 @@ func (s *State) CreateUser(user types.User) (*types.User, bool, error) {
|
||||
return &user, false, fmt.Errorf("failed to update policy manager after user creation: %w", err)
|
||||
}
|
||||
|
||||
// Even if the policy manager doesn't detect a filter change, SSH policies
|
||||
// might now be resolvable when they weren't before. If there are existing
|
||||
// nodes, we should send a policy change to ensure they get updated SSH policies.
|
||||
if !policyChanged {
|
||||
nodes, err := s.ListNodes()
|
||||
if err == nil && len(nodes) > 0 {
|
||||
policyChanged = true
|
||||
}
|
||||
}
|
||||
|
||||
log.Info().Str("user", user.Name).Bool("policyChanged", policyChanged).Msg("User created, policy manager updated")
|
||||
|
||||
// TODO(kradalby): implement the user in-memory cache
|
||||
|
||||
return &user, policyChanged, nil
|
||||
@@ -329,7 +336,7 @@ func (s *State) CreateNode(node *types.Node) (*types.Node, bool, error) {
|
||||
}
|
||||
|
||||
// updateNodeTx performs a database transaction to update a node and refresh the policy manager.
|
||||
func (s *State) updateNodeTx(nodeID types.NodeID, updateFn func(tx *gorm.DB) error) (*types.Node, bool, error) {
|
||||
func (s *State) updateNodeTx(nodeID types.NodeID, updateFn func(tx *gorm.DB) error) (*types.Node, change.ChangeSet, error) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
@@ -350,72 +357,100 @@ func (s *State) updateNodeTx(nodeID types.NodeID, updateFn func(tx *gorm.DB) err
|
||||
return node, nil
|
||||
})
|
||||
if err != nil {
|
||||
return nil, false, err
|
||||
return nil, change.EmptySet, err
|
||||
}
|
||||
|
||||
// Check if policy manager needs updating
|
||||
policyChanged, err := s.updatePolicyManagerNodes()
|
||||
if err != nil {
|
||||
return node, false, fmt.Errorf("failed to update policy manager after node update: %w", err)
|
||||
return node, change.EmptySet, fmt.Errorf("failed to update policy manager after node update: %w", err)
|
||||
}
|
||||
|
||||
// TODO(kradalby): implement the node in-memory cache
|
||||
|
||||
return node, policyChanged, nil
|
||||
var c change.ChangeSet
|
||||
if policyChanged {
|
||||
c = change.PolicyChange()
|
||||
} else {
|
||||
// Basic node change without specific details since this is a generic update
|
||||
c = change.NodeAdded(node.ID)
|
||||
}
|
||||
|
||||
return node, c, nil
|
||||
}
|
||||
|
||||
// SaveNode persists an existing node to the database and updates the policy manager.
|
||||
func (s *State) SaveNode(node *types.Node) (*types.Node, bool, error) {
|
||||
func (s *State) SaveNode(node *types.Node) (*types.Node, change.ChangeSet, error) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
if err := s.db.DB.Save(node).Error; err != nil {
|
||||
return nil, false, fmt.Errorf("saving node: %w", err)
|
||||
return nil, change.EmptySet, fmt.Errorf("saving node: %w", err)
|
||||
}
|
||||
|
||||
// Check if policy manager needs updating
|
||||
policyChanged, err := s.updatePolicyManagerNodes()
|
||||
if err != nil {
|
||||
return node, false, fmt.Errorf("failed to update policy manager after node save: %w", err)
|
||||
return node, change.EmptySet, fmt.Errorf("failed to update policy manager after node save: %w", err)
|
||||
}
|
||||
|
||||
// TODO(kradalby): implement the node in-memory cache
|
||||
|
||||
return node, policyChanged, nil
|
||||
if policyChanged {
|
||||
return node, change.PolicyChange(), nil
|
||||
}
|
||||
|
||||
return node, change.EmptySet, nil
|
||||
}
|
||||
|
||||
// DeleteNode permanently removes a node and cleans up associated resources.
|
||||
// Returns whether policies changed and any error. This operation is irreversible.
|
||||
func (s *State) DeleteNode(node *types.Node) (bool, error) {
|
||||
func (s *State) DeleteNode(node *types.Node) (change.ChangeSet, error) {
|
||||
err := s.db.DeleteNode(node)
|
||||
if err != nil {
|
||||
return false, err
|
||||
return change.EmptySet, err
|
||||
}
|
||||
|
||||
c := change.NodeRemoved(node.ID)
|
||||
|
||||
// Check if policy manager needs updating after node deletion
|
||||
policyChanged, err := s.updatePolicyManagerNodes()
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("failed to update policy manager after node deletion: %w", err)
|
||||
return change.EmptySet, fmt.Errorf("failed to update policy manager after node deletion: %w", err)
|
||||
}
|
||||
|
||||
return policyChanged, nil
|
||||
if policyChanged {
|
||||
c = change.PolicyChange()
|
||||
}
|
||||
|
||||
return c, nil
|
||||
}
|
||||
|
||||
func (s *State) Connect(id types.NodeID) {
|
||||
func (s *State) Connect(node *types.Node) change.ChangeSet {
|
||||
c := change.NodeOnline(node.ID)
|
||||
routeChange := s.primaryRoutes.SetRoutes(node.ID, node.SubnetRoutes()...)
|
||||
|
||||
if routeChange {
|
||||
c = change.NodeAdded(node.ID)
|
||||
}
|
||||
|
||||
return c
|
||||
}
|
||||
|
||||
func (s *State) Disconnect(id types.NodeID) (bool, error) {
|
||||
// TODO(kradalby): This node should update the in memory state
|
||||
_, polChanged, err := s.SetLastSeen(id, time.Now())
|
||||
func (s *State) Disconnect(node *types.Node) (change.ChangeSet, error) {
|
||||
c := change.NodeOffline(node.ID)
|
||||
|
||||
_, _, err := s.SetLastSeen(node.ID, time.Now())
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("disconnecting node: %w", err)
|
||||
return c, fmt.Errorf("disconnecting node: %w", err)
|
||||
}
|
||||
|
||||
changed := s.primaryRoutes.SetRoutes(id)
|
||||
if routeChange := s.primaryRoutes.SetRoutes(node.ID); routeChange {
|
||||
c = change.PolicyChange()
|
||||
}
|
||||
|
||||
// TODO(kradalby): the returned change should be more nuanced allowing us to
|
||||
// send more directed updates.
|
||||
return changed || polChanged, nil
|
||||
// TODO(kradalby): This node should update the in memory state
|
||||
return c, nil
|
||||
}
|
||||
|
||||
// GetNodeByID retrieves a node by ID.
|
||||
@@ -475,45 +510,93 @@ func (s *State) ListEphemeralNodes() (types.Nodes, error) {
|
||||
}
|
||||
|
||||
// SetNodeExpiry updates the expiration time for a node.
|
||||
func (s *State) SetNodeExpiry(nodeID types.NodeID, expiry time.Time) (*types.Node, bool, error) {
|
||||
return s.updateNodeTx(nodeID, func(tx *gorm.DB) error {
|
||||
func (s *State) SetNodeExpiry(nodeID types.NodeID, expiry time.Time) (*types.Node, change.ChangeSet, error) {
|
||||
n, c, err := s.updateNodeTx(nodeID, func(tx *gorm.DB) error {
|
||||
return hsdb.NodeSetExpiry(tx, nodeID, expiry)
|
||||
})
|
||||
if err != nil {
|
||||
return nil, change.EmptySet, fmt.Errorf("setting node expiry: %w", err)
|
||||
}
|
||||
|
||||
if !c.IsFull() {
|
||||
c = change.KeyExpiry(nodeID)
|
||||
}
|
||||
|
||||
return n, c, nil
|
||||
}
|
||||
|
||||
// SetNodeTags assigns tags to a node for use in access control policies.
|
||||
func (s *State) SetNodeTags(nodeID types.NodeID, tags []string) (*types.Node, bool, error) {
|
||||
return s.updateNodeTx(nodeID, func(tx *gorm.DB) error {
|
||||
func (s *State) SetNodeTags(nodeID types.NodeID, tags []string) (*types.Node, change.ChangeSet, error) {
|
||||
n, c, err := s.updateNodeTx(nodeID, func(tx *gorm.DB) error {
|
||||
return hsdb.SetTags(tx, nodeID, tags)
|
||||
})
|
||||
if err != nil {
|
||||
return nil, change.EmptySet, fmt.Errorf("setting node tags: %w", err)
|
||||
}
|
||||
|
||||
if !c.IsFull() {
|
||||
c = change.NodeAdded(nodeID)
|
||||
}
|
||||
|
||||
return n, c, nil
|
||||
}
|
||||
|
||||
// SetApprovedRoutes sets the network routes that a node is approved to advertise.
|
||||
func (s *State) SetApprovedRoutes(nodeID types.NodeID, routes []netip.Prefix) (*types.Node, bool, error) {
|
||||
return s.updateNodeTx(nodeID, func(tx *gorm.DB) error {
|
||||
func (s *State) SetApprovedRoutes(nodeID types.NodeID, routes []netip.Prefix) (*types.Node, change.ChangeSet, error) {
|
||||
n, c, err := s.updateNodeTx(nodeID, func(tx *gorm.DB) error {
|
||||
return hsdb.SetApprovedRoutes(tx, nodeID, routes)
|
||||
})
|
||||
if err != nil {
|
||||
return nil, change.EmptySet, fmt.Errorf("setting approved routes: %w", err)
|
||||
}
|
||||
|
||||
// Update primary routes after changing approved routes
|
||||
routeChange := s.primaryRoutes.SetRoutes(nodeID, n.SubnetRoutes()...)
|
||||
|
||||
if routeChange || !c.IsFull() {
|
||||
c = change.PolicyChange()
|
||||
}
|
||||
|
||||
return n, c, nil
|
||||
}
|
||||
|
||||
// RenameNode changes the display name of a node.
|
||||
func (s *State) RenameNode(nodeID types.NodeID, newName string) (*types.Node, bool, error) {
|
||||
return s.updateNodeTx(nodeID, func(tx *gorm.DB) error {
|
||||
func (s *State) RenameNode(nodeID types.NodeID, newName string) (*types.Node, change.ChangeSet, error) {
|
||||
n, c, err := s.updateNodeTx(nodeID, func(tx *gorm.DB) error {
|
||||
return hsdb.RenameNode(tx, nodeID, newName)
|
||||
})
|
||||
if err != nil {
|
||||
return nil, change.EmptySet, fmt.Errorf("renaming node: %w", err)
|
||||
}
|
||||
|
||||
if !c.IsFull() {
|
||||
c = change.NodeAdded(nodeID)
|
||||
}
|
||||
|
||||
return n, c, nil
|
||||
}
|
||||
|
||||
// SetLastSeen updates when a node was last seen, used for connectivity monitoring.
|
||||
func (s *State) SetLastSeen(nodeID types.NodeID, lastSeen time.Time) (*types.Node, bool, error) {
|
||||
func (s *State) SetLastSeen(nodeID types.NodeID, lastSeen time.Time) (*types.Node, change.ChangeSet, error) {
|
||||
return s.updateNodeTx(nodeID, func(tx *gorm.DB) error {
|
||||
return hsdb.SetLastSeen(tx, nodeID, lastSeen)
|
||||
})
|
||||
}
|
||||
|
||||
// AssignNodeToUser transfers a node to a different user.
|
||||
func (s *State) AssignNodeToUser(nodeID types.NodeID, userID types.UserID) (*types.Node, bool, error) {
|
||||
return s.updateNodeTx(nodeID, func(tx *gorm.DB) error {
|
||||
func (s *State) AssignNodeToUser(nodeID types.NodeID, userID types.UserID) (*types.Node, change.ChangeSet, error) {
|
||||
n, c, err := s.updateNodeTx(nodeID, func(tx *gorm.DB) error {
|
||||
return hsdb.AssignNodeToUser(tx, nodeID, userID)
|
||||
})
|
||||
if err != nil {
|
||||
return nil, change.EmptySet, fmt.Errorf("assigning node to user: %w", err)
|
||||
}
|
||||
|
||||
if !c.IsFull() {
|
||||
c = change.NodeAdded(nodeID)
|
||||
}
|
||||
|
||||
return n, c, nil
|
||||
}
|
||||
|
||||
// BackfillNodeIPs assigns IP addresses to nodes that don't have them.
|
||||
@@ -523,7 +606,7 @@ func (s *State) BackfillNodeIPs() ([]string, error) {
|
||||
|
||||
// ExpireExpiredNodes finds and processes expired nodes since the last check.
|
||||
// Returns next check time, state update with expired nodes, and whether any were found.
|
||||
func (s *State) ExpireExpiredNodes(lastCheck time.Time) (time.Time, types.StateUpdate, bool) {
|
||||
func (s *State) ExpireExpiredNodes(lastCheck time.Time) (time.Time, []change.ChangeSet, bool) {
|
||||
return hsdb.ExpireExpiredNodes(s.db.DB, lastCheck)
|
||||
}
|
||||
|
||||
@@ -568,8 +651,14 @@ func (s *State) SetPolicyInDB(data string) (*types.Policy, error) {
|
||||
}
|
||||
|
||||
// SetNodeRoutes sets the primary routes for a node.
|
||||
func (s *State) SetNodeRoutes(nodeID types.NodeID, routes ...netip.Prefix) bool {
|
||||
return s.primaryRoutes.SetRoutes(nodeID, routes...)
|
||||
func (s *State) SetNodeRoutes(nodeID types.NodeID, routes ...netip.Prefix) change.ChangeSet {
|
||||
if s.primaryRoutes.SetRoutes(nodeID, routes...) {
|
||||
// Route changes affect packet filters for all nodes, so trigger a policy change
|
||||
// to ensure filters are regenerated across the entire network
|
||||
return change.PolicyChange()
|
||||
}
|
||||
|
||||
return change.EmptySet
|
||||
}
|
||||
|
||||
// GetNodePrimaryRoutes returns the primary routes for a node.
|
||||
@@ -653,10 +742,10 @@ func (s *State) HandleNodeFromAuthPath(
|
||||
userID types.UserID,
|
||||
expiry *time.Time,
|
||||
registrationMethod string,
|
||||
) (*types.Node, bool, error) {
|
||||
) (*types.Node, change.ChangeSet, error) {
|
||||
ipv4, ipv6, err := s.ipAlloc.Next()
|
||||
if err != nil {
|
||||
return nil, false, err
|
||||
return nil, change.EmptySet, err
|
||||
}
|
||||
|
||||
return s.db.HandleNodeFromAuthPath(
|
||||
@@ -672,12 +761,15 @@ func (s *State) HandleNodeFromAuthPath(
|
||||
func (s *State) HandleNodeFromPreAuthKey(
|
||||
regReq tailcfg.RegisterRequest,
|
||||
machineKey key.MachinePublic,
|
||||
) (*types.Node, bool, error) {
|
||||
) (*types.Node, change.ChangeSet, bool, error) {
|
||||
pak, err := s.GetPreAuthKey(regReq.Auth.AuthKey)
|
||||
if err != nil {
|
||||
return nil, change.EmptySet, false, err
|
||||
}
|
||||
|
||||
err = pak.Validate()
|
||||
if err != nil {
|
||||
return nil, false, err
|
||||
return nil, change.EmptySet, false, err
|
||||
}
|
||||
|
||||
nodeToRegister := types.Node{
|
||||
@@ -698,22 +790,13 @@ func (s *State) HandleNodeFromPreAuthKey(
|
||||
AuthKeyID: &pak.ID,
|
||||
}
|
||||
|
||||
// For auth key registration, ensure we don't keep an expired node
|
||||
// This is especially important for re-registration after logout
|
||||
if !regReq.Expiry.IsZero() && regReq.Expiry.After(time.Now()) {
|
||||
if !regReq.Expiry.IsZero() {
|
||||
nodeToRegister.Expiry = ®Req.Expiry
|
||||
} else if !regReq.Expiry.IsZero() {
|
||||
// If client is sending an expired time (e.g., after logout),
|
||||
// don't set expiry so the node won't be considered expired
|
||||
log.Debug().
|
||||
Time("requested_expiry", regReq.Expiry).
|
||||
Str("node", regReq.Hostinfo.Hostname).
|
||||
Msg("Ignoring expired expiry time from auth key registration")
|
||||
}
|
||||
|
||||
ipv4, ipv6, err := s.ipAlloc.Next()
|
||||
if err != nil {
|
||||
return nil, false, fmt.Errorf("allocating IPs: %w", err)
|
||||
return nil, change.EmptySet, false, fmt.Errorf("allocating IPs: %w", err)
|
||||
}
|
||||
|
||||
node, err := hsdb.Write(s.db.DB, func(tx *gorm.DB) (*types.Node, error) {
|
||||
@@ -735,18 +818,38 @@ func (s *State) HandleNodeFromPreAuthKey(
|
||||
return node, nil
|
||||
})
|
||||
if err != nil {
|
||||
return nil, false, fmt.Errorf("writing node to database: %w", err)
|
||||
return nil, change.EmptySet, false, fmt.Errorf("writing node to database: %w", err)
|
||||
}
|
||||
|
||||
// Check if this is a logout request for an ephemeral node
|
||||
if !regReq.Expiry.IsZero() && regReq.Expiry.Before(time.Now()) && pak.Ephemeral {
|
||||
// This is a logout request for an ephemeral node, delete it immediately
|
||||
c, err := s.DeleteNode(node)
|
||||
if err != nil {
|
||||
return nil, change.EmptySet, false, fmt.Errorf("deleting ephemeral node during logout: %w", err)
|
||||
}
|
||||
return nil, c, false, nil
|
||||
}
|
||||
|
||||
// Check if policy manager needs updating
|
||||
// This is necessary because we just created a new node.
|
||||
// We need to ensure that the policy manager is aware of this new node.
|
||||
policyChanged, err := s.updatePolicyManagerNodes()
|
||||
// Also update users to ensure all users are known when evaluating policies.
|
||||
usersChanged, err := s.updatePolicyManagerUsers()
|
||||
if err != nil {
|
||||
return nil, false, fmt.Errorf("failed to update policy manager after node registration: %w", err)
|
||||
return nil, change.EmptySet, false, fmt.Errorf("failed to update policy manager users after node registration: %w", err)
|
||||
}
|
||||
|
||||
return node, policyChanged, nil
|
||||
nodesChanged, err := s.updatePolicyManagerNodes()
|
||||
if err != nil {
|
||||
return nil, change.EmptySet, false, fmt.Errorf("failed to update policy manager nodes after node registration: %w", err)
|
||||
}
|
||||
|
||||
policyChanged := usersChanged || nodesChanged
|
||||
|
||||
c := change.NodeAdded(node.ID)
|
||||
|
||||
return node, c, policyChanged, nil
|
||||
}
|
||||
|
||||
// AllocateNextIPs allocates the next available IPv4 and IPv6 addresses.
|
||||
@@ -766,11 +869,15 @@ func (s *State) updatePolicyManagerUsers() (bool, error) {
|
||||
return false, fmt.Errorf("listing users for policy update: %w", err)
|
||||
}
|
||||
|
||||
log.Debug().Int("userCount", len(users)).Msg("Updating policy manager with users")
|
||||
|
||||
changed, err := s.polMan.SetUsers(users)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("updating policy manager users: %w", err)
|
||||
}
|
||||
|
||||
log.Debug().Bool("changed", changed).Msg("Policy manager users updated")
|
||||
|
||||
return changed, nil
|
||||
}
|
||||
|
||||
@@ -835,3 +942,125 @@ func (s *State) autoApproveNodes() error {
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// TODO(kradalby): This should just take the node ID?
|
||||
func (s *State) UpdateNodeFromMapRequest(node *types.Node, req tailcfg.MapRequest) (change.ChangeSet, error) {
|
||||
// TODO(kradalby): This is essentially a patch update that could be sent directly to nodes,
|
||||
// which means we could shortcut the whole change thing if there are no other important updates.
|
||||
peerChange := node.PeerChangeFromMapRequest(req)
|
||||
|
||||
node.ApplyPeerChange(&peerChange)
|
||||
|
||||
sendUpdate, routesChanged := hostInfoChanged(node.Hostinfo, req.Hostinfo)
|
||||
|
||||
// The node might not set NetInfo if it has not changed and if
|
||||
// the full HostInfo object is overwritten, the information is lost.
|
||||
// If there is no NetInfo, keep the previous one.
|
||||
// From 1.66 the client only sends it if changed:
|
||||
// https://github.com/tailscale/tailscale/commit/e1011f138737286ecf5123ff887a7a5800d129a2
|
||||
// TODO(kradalby): evaluate if we need better comparing of hostinfo
|
||||
// before we take the changes.
|
||||
if req.Hostinfo.NetInfo == nil && node.Hostinfo != nil {
|
||||
req.Hostinfo.NetInfo = node.Hostinfo.NetInfo
|
||||
}
|
||||
node.Hostinfo = req.Hostinfo
|
||||
|
||||
// If there is no changes and nothing to save,
|
||||
// return early.
|
||||
if peerChangeEmpty(peerChange) && !sendUpdate {
|
||||
// mapResponseEndpointUpdates.WithLabelValues("noop").Inc()
|
||||
return change.EmptySet, nil
|
||||
}
|
||||
|
||||
c := change.EmptySet
|
||||
|
||||
// Check if the Hostinfo of the node has changed.
|
||||
// If it has changed, check if there has been a change to
|
||||
// the routable IPs of the host and update them in
|
||||
// the database. Then send a Changed update
|
||||
// (containing the whole node object) to peers to inform about
|
||||
// the route change.
|
||||
// If the hostinfo has changed, but not the routes, just update
|
||||
// hostinfo and let the function continue.
|
||||
if routesChanged {
|
||||
// Auto approve any routes that have been defined in policy as
|
||||
// auto approved. Check if this actually changed the node.
|
||||
_ = s.AutoApproveRoutes(node)
|
||||
|
||||
// Update the routes of the given node in the route manager to
|
||||
// see if an update needs to be sent.
|
||||
c = s.SetNodeRoutes(node.ID, node.SubnetRoutes()...)
|
||||
}
|
||||
|
||||
// Check if there has been a change to Hostname and update them
|
||||
// in the database. Then send a Changed update
|
||||
// (containing the whole node object) to peers to inform about
|
||||
// the hostname change.
|
||||
node.ApplyHostnameFromHostInfo(req.Hostinfo)
|
||||
|
||||
_, policyChange, err := s.SaveNode(node)
|
||||
if err != nil {
|
||||
return change.EmptySet, err
|
||||
}
|
||||
|
||||
if policyChange.IsFull() {
|
||||
c = policyChange
|
||||
}
|
||||
|
||||
if c.Empty() {
|
||||
c = change.NodeAdded(node.ID)
|
||||
}
|
||||
|
||||
return c, nil
|
||||
}
|
||||
|
||||
// hostInfoChanged reports if hostInfo has changed in two ways,
|
||||
// - first bool reports if an update needs to be sent to nodes
|
||||
// - second reports if there has been changes to routes
|
||||
// the caller can then use this info to save and update nodes
|
||||
// and routes as needed.
|
||||
func hostInfoChanged(old, new *tailcfg.Hostinfo) (bool, bool) {
|
||||
if old.Equal(new) {
|
||||
return false, false
|
||||
}
|
||||
|
||||
if old == nil && new != nil {
|
||||
return true, true
|
||||
}
|
||||
|
||||
// Routes
|
||||
oldRoutes := make([]netip.Prefix, 0)
|
||||
if old != nil {
|
||||
oldRoutes = old.RoutableIPs
|
||||
}
|
||||
newRoutes := new.RoutableIPs
|
||||
|
||||
tsaddr.SortPrefixes(oldRoutes)
|
||||
tsaddr.SortPrefixes(newRoutes)
|
||||
|
||||
if !xslices.Equal(oldRoutes, newRoutes) {
|
||||
return true, true
|
||||
}
|
||||
|
||||
// Services is mostly useful for discovery and not critical,
|
||||
// except for peerapi, which is how nodes talk to each other.
|
||||
// If peerapi was not part of the initial mapresponse, we
|
||||
// need to make sure its sent out later as it is needed for
|
||||
// Taildrop.
|
||||
// TODO(kradalby): Length comparison is a bit naive, replace.
|
||||
if len(old.Services) != len(new.Services) {
|
||||
return true, false
|
||||
}
|
||||
|
||||
return false, false
|
||||
}
|
||||
|
||||
func peerChangeEmpty(peerChange tailcfg.PeerChange) bool {
|
||||
return peerChange.Key == nil &&
|
||||
peerChange.DiscoKey == nil &&
|
||||
peerChange.Online == nil &&
|
||||
peerChange.Endpoints == nil &&
|
||||
peerChange.DERPRegion == 0 &&
|
||||
peerChange.LastSeen == nil &&
|
||||
peerChange.KeyExpiry == nil
|
||||
}
|
||||
|
183
hscontrol/types/change/change.go
Normal file
183
hscontrol/types/change/change.go
Normal file
@@ -0,0 +1,183 @@
|
||||
//go:generate go tool stringer -type=Change
|
||||
package change
|
||||
|
||||
import (
|
||||
"errors"
|
||||
|
||||
"github.com/juanfont/headscale/hscontrol/types"
|
||||
)
|
||||
|
||||
type (
|
||||
NodeID = types.NodeID
|
||||
UserID = types.UserID
|
||||
)
|
||||
|
||||
type Change int
|
||||
|
||||
const (
|
||||
ChangeUnknown Change = 0
|
||||
|
||||
// Deprecated: Use specific change instead
|
||||
// Full is a legacy change to ensure places where we
|
||||
// have not yet determined the specific update, can send.
|
||||
Full Change = 9
|
||||
|
||||
// Server changes.
|
||||
Policy Change = 11
|
||||
DERP Change = 12
|
||||
ExtraRecords Change = 13
|
||||
|
||||
// Node changes.
|
||||
NodeCameOnline Change = 21
|
||||
NodeWentOffline Change = 22
|
||||
NodeRemove Change = 23
|
||||
NodeKeyExpiry Change = 24
|
||||
NodeNewOrUpdate Change = 25
|
||||
|
||||
// User changes.
|
||||
UserNewOrUpdate Change = 51
|
||||
UserRemove Change = 52
|
||||
)
|
||||
|
||||
// AlsoSelf reports whether this change should also be sent to the node itself.
|
||||
func (c Change) AlsoSelf() bool {
|
||||
switch c {
|
||||
case NodeRemove, NodeKeyExpiry, NodeNewOrUpdate:
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
type ChangeSet struct {
|
||||
Change Change
|
||||
|
||||
// SelfUpdateOnly indicates that this change should only be sent
|
||||
// to the node itself, and not to other nodes.
|
||||
// This is used for changes that are not relevant to other nodes.
|
||||
// NodeID must be set if this is true.
|
||||
SelfUpdateOnly bool
|
||||
|
||||
// NodeID if set, is the ID of the node that is being changed.
|
||||
// It must be set if this is a node change.
|
||||
NodeID types.NodeID
|
||||
|
||||
// UserID if set, is the ID of the user that is being changed.
|
||||
// It must be set if this is a user change.
|
||||
UserID types.UserID
|
||||
|
||||
// IsSubnetRouter indicates whether the node is a subnet router.
|
||||
IsSubnetRouter bool
|
||||
}
|
||||
|
||||
func (c *ChangeSet) Validate() error {
|
||||
if c.Change >= NodeCameOnline || c.Change <= NodeNewOrUpdate {
|
||||
if c.NodeID == 0 {
|
||||
return errors.New("ChangeSet.NodeID must be set for node updates")
|
||||
}
|
||||
}
|
||||
|
||||
if c.Change >= UserNewOrUpdate || c.Change <= UserRemove {
|
||||
if c.UserID == 0 {
|
||||
return errors.New("ChangeSet.UserID must be set for user updates")
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Empty reports whether the ChangeSet is empty, meaning it does not
|
||||
// represent any change.
|
||||
func (c ChangeSet) Empty() bool {
|
||||
return c.Change == ChangeUnknown && c.NodeID == 0 && c.UserID == 0
|
||||
}
|
||||
|
||||
// IsFull reports whether the ChangeSet represents a full update.
|
||||
func (c ChangeSet) IsFull() bool {
|
||||
return c.Change == Full || c.Change == Policy
|
||||
}
|
||||
|
||||
func (c ChangeSet) AlsoSelf() bool {
|
||||
// If NodeID is 0, it means this ChangeSet is not related to a specific node,
|
||||
// so we consider it as a change that should be sent to all nodes.
|
||||
if c.NodeID == 0 {
|
||||
return true
|
||||
}
|
||||
return c.Change.AlsoSelf() || c.SelfUpdateOnly
|
||||
}
|
||||
|
||||
var (
|
||||
EmptySet = ChangeSet{Change: ChangeUnknown}
|
||||
FullSet = ChangeSet{Change: Full}
|
||||
DERPSet = ChangeSet{Change: DERP}
|
||||
PolicySet = ChangeSet{Change: Policy}
|
||||
ExtraRecordsSet = ChangeSet{Change: ExtraRecords}
|
||||
)
|
||||
|
||||
func FullSelf(id types.NodeID) ChangeSet {
|
||||
return ChangeSet{
|
||||
Change: Full,
|
||||
SelfUpdateOnly: true,
|
||||
NodeID: id,
|
||||
}
|
||||
}
|
||||
|
||||
func NodeAdded(id types.NodeID) ChangeSet {
|
||||
return ChangeSet{
|
||||
Change: NodeNewOrUpdate,
|
||||
NodeID: id,
|
||||
}
|
||||
}
|
||||
|
||||
func NodeRemoved(id types.NodeID) ChangeSet {
|
||||
return ChangeSet{
|
||||
Change: NodeRemove,
|
||||
NodeID: id,
|
||||
}
|
||||
}
|
||||
|
||||
func NodeOnline(id types.NodeID) ChangeSet {
|
||||
return ChangeSet{
|
||||
Change: NodeCameOnline,
|
||||
NodeID: id,
|
||||
}
|
||||
}
|
||||
|
||||
func NodeOffline(id types.NodeID) ChangeSet {
|
||||
return ChangeSet{
|
||||
Change: NodeWentOffline,
|
||||
NodeID: id,
|
||||
}
|
||||
}
|
||||
|
||||
func KeyExpiry(id types.NodeID) ChangeSet {
|
||||
return ChangeSet{
|
||||
Change: NodeKeyExpiry,
|
||||
NodeID: id,
|
||||
}
|
||||
}
|
||||
|
||||
func UserAdded(id types.UserID) ChangeSet {
|
||||
return ChangeSet{
|
||||
Change: UserNewOrUpdate,
|
||||
UserID: id,
|
||||
}
|
||||
}
|
||||
|
||||
func UserRemoved(id types.UserID) ChangeSet {
|
||||
return ChangeSet{
|
||||
Change: UserRemove,
|
||||
UserID: id,
|
||||
}
|
||||
}
|
||||
|
||||
func PolicyChange() ChangeSet {
|
||||
return ChangeSet{
|
||||
Change: Policy,
|
||||
}
|
||||
}
|
||||
|
||||
func DERPChange() ChangeSet {
|
||||
return ChangeSet{
|
||||
Change: DERP,
|
||||
}
|
||||
}
|
57
hscontrol/types/change/change_string.go
Normal file
57
hscontrol/types/change/change_string.go
Normal file
@@ -0,0 +1,57 @@
|
||||
// Code generated by "stringer -type=Change"; DO NOT EDIT.
|
||||
|
||||
package change
|
||||
|
||||
import "strconv"
|
||||
|
||||
func _() {
|
||||
// An "invalid array index" compiler error signifies that the constant values have changed.
|
||||
// Re-run the stringer command to generate them again.
|
||||
var x [1]struct{}
|
||||
_ = x[ChangeUnknown-0]
|
||||
_ = x[Full-9]
|
||||
_ = x[Policy-11]
|
||||
_ = x[DERP-12]
|
||||
_ = x[ExtraRecords-13]
|
||||
_ = x[NodeCameOnline-21]
|
||||
_ = x[NodeWentOffline-22]
|
||||
_ = x[NodeRemove-23]
|
||||
_ = x[NodeKeyExpiry-24]
|
||||
_ = x[NodeNewOrUpdate-25]
|
||||
_ = x[UserNewOrUpdate-51]
|
||||
_ = x[UserRemove-52]
|
||||
}
|
||||
|
||||
const (
|
||||
_Change_name_0 = "ChangeUnknown"
|
||||
_Change_name_1 = "Full"
|
||||
_Change_name_2 = "PolicyDERPExtraRecords"
|
||||
_Change_name_3 = "NodeCameOnlineNodeWentOfflineNodeRemoveNodeKeyExpiryNodeNewOrUpdate"
|
||||
_Change_name_4 = "UserNewOrUpdateUserRemove"
|
||||
)
|
||||
|
||||
var (
|
||||
_Change_index_2 = [...]uint8{0, 6, 10, 22}
|
||||
_Change_index_3 = [...]uint8{0, 14, 29, 39, 52, 67}
|
||||
_Change_index_4 = [...]uint8{0, 15, 25}
|
||||
)
|
||||
|
||||
func (i Change) String() string {
|
||||
switch {
|
||||
case i == 0:
|
||||
return _Change_name_0
|
||||
case i == 9:
|
||||
return _Change_name_1
|
||||
case 11 <= i && i <= 13:
|
||||
i -= 11
|
||||
return _Change_name_2[_Change_index_2[i]:_Change_index_2[i+1]]
|
||||
case 21 <= i && i <= 25:
|
||||
i -= 21
|
||||
return _Change_name_3[_Change_index_3[i]:_Change_index_3[i+1]]
|
||||
case 51 <= i && i <= 52:
|
||||
i -= 51
|
||||
return _Change_name_4[_Change_index_4[i]:_Change_index_4[i+1]]
|
||||
default:
|
||||
return "Change(" + strconv.FormatInt(int64(i), 10) + ")"
|
||||
}
|
||||
}
|
@@ -1,16 +1,16 @@
|
||||
//go:generate go run tailscale.com/cmd/viewer --type=User,Node,PreAuthKey
|
||||
|
||||
//go:generate go tool viewer --type=User,Node,PreAuthKey
|
||||
package types
|
||||
|
||||
//go:generate go run tailscale.com/cmd/viewer --type=User,Node,PreAuthKey
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"runtime"
|
||||
"time"
|
||||
|
||||
"github.com/juanfont/headscale/hscontrol/util"
|
||||
"tailscale.com/tailcfg"
|
||||
"tailscale.com/util/ctxkey"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -150,18 +150,6 @@ func UpdateExpire(nodeID NodeID, expiry time.Time) StateUpdate {
|
||||
}
|
||||
}
|
||||
|
||||
var (
|
||||
NotifyOriginKey = ctxkey.New("notify.origin", "")
|
||||
NotifyHostnameKey = ctxkey.New("notify.hostname", "")
|
||||
)
|
||||
|
||||
func NotifyCtx(ctx context.Context, origin, hostname string) context.Context {
|
||||
ctx2, _ := context.WithTimeout(ctx, 3*time.Second)
|
||||
ctx2 = NotifyOriginKey.WithValue(ctx2, origin)
|
||||
ctx2 = NotifyHostnameKey.WithValue(ctx2, hostname)
|
||||
return ctx2
|
||||
}
|
||||
|
||||
const RegistrationIDLength = 24
|
||||
|
||||
type RegistrationID string
|
||||
@@ -199,3 +187,20 @@ type RegisterNode struct {
|
||||
Node Node
|
||||
Registered chan *Node
|
||||
}
|
||||
|
||||
// DefaultBatcherWorkers returns the default number of batcher workers.
|
||||
// Default to 3/4 of CPU cores, minimum 1, no maximum.
|
||||
func DefaultBatcherWorkers() int {
|
||||
return DefaultBatcherWorkersFor(runtime.NumCPU())
|
||||
}
|
||||
|
||||
// DefaultBatcherWorkersFor returns the default number of batcher workers for a given CPU count.
|
||||
// Default to 3/4 of CPU cores, minimum 1, no maximum.
|
||||
func DefaultBatcherWorkersFor(cpuCount int) int {
|
||||
defaultWorkers := (cpuCount * 3) / 4
|
||||
if defaultWorkers < 1 {
|
||||
defaultWorkers = 1
|
||||
}
|
||||
|
||||
return defaultWorkers
|
||||
}
|
||||
|
36
hscontrol/types/common_test.go
Normal file
36
hscontrol/types/common_test.go
Normal file
@@ -0,0 +1,36 @@
|
||||
package types
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestDefaultBatcherWorkersFor(t *testing.T) {
|
||||
tests := []struct {
|
||||
cpuCount int
|
||||
expected int
|
||||
}{
|
||||
{1, 1}, // (1*3)/4 = 0, should be minimum 1
|
||||
{2, 1}, // (2*3)/4 = 1
|
||||
{4, 3}, // (4*3)/4 = 3
|
||||
{8, 6}, // (8*3)/4 = 6
|
||||
{12, 9}, // (12*3)/4 = 9
|
||||
{16, 12}, // (16*3)/4 = 12
|
||||
{20, 15}, // (20*3)/4 = 15
|
||||
{24, 18}, // (24*3)/4 = 18
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
result := DefaultBatcherWorkersFor(test.cpuCount)
|
||||
if result != test.expected {
|
||||
t.Errorf("DefaultBatcherWorkersFor(%d) = %d, expected %d", test.cpuCount, result, test.expected)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestDefaultBatcherWorkers(t *testing.T) {
|
||||
// Just verify it returns a valid value (>= 1)
|
||||
result := DefaultBatcherWorkers()
|
||||
if result < 1 {
|
||||
t.Errorf("DefaultBatcherWorkers() = %d, expected value >= 1", result)
|
||||
}
|
||||
}
|
@@ -234,6 +234,7 @@ type Tuning struct {
|
||||
NotifierSendTimeout time.Duration
|
||||
BatchChangeDelay time.Duration
|
||||
NodeMapSessionBufferedChanSize int
|
||||
BatcherWorkers int
|
||||
}
|
||||
|
||||
func validatePKCEMethod(method string) error {
|
||||
@@ -991,6 +992,12 @@ func LoadServerConfig() (*Config, error) {
|
||||
NodeMapSessionBufferedChanSize: viper.GetInt(
|
||||
"tuning.node_mapsession_buffered_chan_size",
|
||||
),
|
||||
BatcherWorkers: func() int {
|
||||
if workers := viper.GetInt("tuning.batcher_workers"); workers > 0 {
|
||||
return workers
|
||||
}
|
||||
return DefaultBatcherWorkers()
|
||||
}(),
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
@@ -431,6 +431,11 @@ func (node *Node) SubnetRoutes() []netip.Prefix {
|
||||
return routes
|
||||
}
|
||||
|
||||
// IsSubnetRouter reports if the node has any subnet routes.
|
||||
func (node *Node) IsSubnetRouter() bool {
|
||||
return len(node.SubnetRoutes()) > 0
|
||||
}
|
||||
|
||||
func (node *Node) String() string {
|
||||
return node.Hostname
|
||||
}
|
||||
@@ -669,6 +674,13 @@ func (v NodeView) SubnetRoutes() []netip.Prefix {
|
||||
return v.ж.SubnetRoutes()
|
||||
}
|
||||
|
||||
func (v NodeView) IsSubnetRouter() bool {
|
||||
if !v.Valid() {
|
||||
return false
|
||||
}
|
||||
return v.ж.IsSubnetRouter()
|
||||
}
|
||||
|
||||
func (v NodeView) AppendToIPSet(build *netipx.IPSetBuilder) {
|
||||
if !v.Valid() {
|
||||
return
|
||||
|
@@ -1,17 +1,16 @@
|
||||
package types
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
|
||||
"github.com/rs/zerolog/log"
|
||||
"google.golang.org/protobuf/types/known/timestamppb"
|
||||
)
|
||||
|
||||
type PAKError string
|
||||
|
||||
func (e PAKError) Error() string { return string(e) }
|
||||
func (e PAKError) Unwrap() error { return fmt.Errorf("preauth key error: %w", e) }
|
||||
|
||||
// PreAuthKey describes a pre-authorization key usable in a particular user.
|
||||
type PreAuthKey struct {
|
||||
@@ -60,6 +59,21 @@ func (pak *PreAuthKey) Validate() error {
|
||||
if pak == nil {
|
||||
return PAKError("invalid authkey")
|
||||
}
|
||||
|
||||
log.Debug().
|
||||
Str("key", pak.Key).
|
||||
Bool("hasExpiration", pak.Expiration != nil).
|
||||
Time("expiration", func() time.Time {
|
||||
if pak.Expiration != nil {
|
||||
return *pak.Expiration
|
||||
}
|
||||
return time.Time{}
|
||||
}()).
|
||||
Time("now", time.Now()).
|
||||
Bool("reusable", pak.Reusable).
|
||||
Bool("used", pak.Used).
|
||||
Msg("PreAuthKey.Validate: checking key")
|
||||
|
||||
if pak.Expiration != nil && pak.Expiration.Before(time.Now()) {
|
||||
return PAKError("authkey expired")
|
||||
}
|
||||
|
@@ -5,6 +5,8 @@ import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"tailscale.com/util/dnsname"
|
||||
"tailscale.com/util/must"
|
||||
)
|
||||
|
||||
func TestCheckForFQDNRules(t *testing.T) {
|
||||
@@ -102,59 +104,16 @@ func TestConvertWithFQDNRules(t *testing.T) {
|
||||
func TestMagicDNSRootDomains100(t *testing.T) {
|
||||
domains := GenerateIPv4DNSRootDomain(netip.MustParsePrefix("100.64.0.0/10"))
|
||||
|
||||
found := false
|
||||
for _, domain := range domains {
|
||||
if domain == "64.100.in-addr.arpa." {
|
||||
found = true
|
||||
|
||||
break
|
||||
}
|
||||
}
|
||||
assert.True(t, found)
|
||||
|
||||
found = false
|
||||
for _, domain := range domains {
|
||||
if domain == "100.100.in-addr.arpa." {
|
||||
found = true
|
||||
|
||||
break
|
||||
}
|
||||
}
|
||||
assert.True(t, found)
|
||||
|
||||
found = false
|
||||
for _, domain := range domains {
|
||||
if domain == "127.100.in-addr.arpa." {
|
||||
found = true
|
||||
|
||||
break
|
||||
}
|
||||
}
|
||||
assert.True(t, found)
|
||||
assert.Contains(t, domains, must.Get(dnsname.ToFQDN("64.100.in-addr.arpa.")))
|
||||
assert.Contains(t, domains, must.Get(dnsname.ToFQDN("100.100.in-addr.arpa.")))
|
||||
assert.Contains(t, domains, must.Get(dnsname.ToFQDN("127.100.in-addr.arpa.")))
|
||||
}
|
||||
|
||||
func TestMagicDNSRootDomains172(t *testing.T) {
|
||||
domains := GenerateIPv4DNSRootDomain(netip.MustParsePrefix("172.16.0.0/16"))
|
||||
|
||||
found := false
|
||||
for _, domain := range domains {
|
||||
if domain == "0.16.172.in-addr.arpa." {
|
||||
found = true
|
||||
|
||||
break
|
||||
}
|
||||
}
|
||||
assert.True(t, found)
|
||||
|
||||
found = false
|
||||
for _, domain := range domains {
|
||||
if domain == "255.16.172.in-addr.arpa." {
|
||||
found = true
|
||||
|
||||
break
|
||||
}
|
||||
}
|
||||
assert.True(t, found)
|
||||
assert.Contains(t, domains, must.Get(dnsname.ToFQDN("0.16.172.in-addr.arpa.")))
|
||||
assert.Contains(t, domains, must.Get(dnsname.ToFQDN("255.16.172.in-addr.arpa.")))
|
||||
}
|
||||
|
||||
// Happens when netmask is a multiple of 4 bits (sounds likely).
|
||||
|
@@ -143,7 +143,7 @@ func ParseTraceroute(output string) (Traceroute, error) {
|
||||
|
||||
// Parse latencies
|
||||
for j := 5; j <= 7; j++ {
|
||||
if matches[j] != "" {
|
||||
if j < len(matches) && matches[j] != "" {
|
||||
ms, err := strconv.ParseFloat(matches[j], 64)
|
||||
if err != nil {
|
||||
return Traceroute{}, fmt.Errorf("parsing latency: %w", err)
|
||||
|
Reference in New Issue
Block a user