Compare commits

...

15 Commits

Author SHA1 Message Date
Juan Font
7bdf364e51 Clarify relation with Tailscale 2024-04-22 10:14:47 +00:00
Juan Font
bd047928f7 Move pprof to metrics router (#1902) 2024-04-21 22:08:59 +02:00
ChengenH
9375b09206 chore: use errors.New to replace fmt.Errorf with no parameters will much better
Signed-off-by: ChengenH <hce19970702@gmail.com>
2024-04-21 20:23:25 +02:00
Kristoffer Dalby
ba614a5e6c metrics, tuning in tests, db cleanups, fix concurrency issue (#1895) 2024-04-21 18:28:17 +02:00
oftenoccur
7d8178406d chore: fix function names in comment (#1866)
* chore: fix function names in comment

Signed-off-by: oftenoccur <ezc5@sina.com>

---------

Signed-off-by: oftenoccur <ezc5@sina.com>
Co-authored-by: ohdearaugustin <ohdearaugustin@users.noreply.github.com>
2024-04-21 18:19:38 +02:00
ohdearaugustin
8394208856 fix prettier 2024-04-21 17:32:41 +02:00
Arnaud Dezandee
803269a64c docs(readme): change contributors section (#1889) 2024-04-21 16:48:33 +02:00
Carson Yang
d6ec31c4e0 docs: Add docs for running headscale on sealos (#1666)
* docs: Add docs for running headscale on sealos

Signed-off-by: Carson Yang <yangchuansheng33@gmail.com>

* run prettier

---------

Signed-off-by: Carson Yang <yangchuansheng33@gmail.com>
Co-authored-by: ohdearaugustin <ohdearaugustin@users.noreply.github.com>
2024-04-21 16:43:31 +02:00
Juan Font
68503581a0 Add test stage to docs (#1893)
* Add test stage to docs

Add new file with docs tets

Run only in pulls

* set explicit python version

* Revert "set explicit python version"

This reverts commit 4dd7b81f26.

* docs/requirements: update mkdocs-material

---------

Co-authored-by: ohdearaugustin <ohdearaugustin@users.noreply.github.com>
2024-04-21 16:33:22 +02:00
Juan Font
e2afd30b1c Add the latest UI to the website 2024-04-18 14:55:59 +02:00
Juan Font
c906aaf927 Allow to remove forced tags of a node
Set as empty StringList
2024-04-18 09:55:55 +02:00
Juan Font
580f96ce83 Remove unused node check interval 2024-04-17 20:20:44 +02:00
Juan Font
c4c8cfe5ea Fix crash when a prefix family was empty 2024-04-17 15:28:06 +02:00
Kristoffer Dalby
40953727cf fix ip migration
Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>
2024-04-17 12:36:12 +02:00
Cas de Reuver
d4af0c386c Log available update as warning (#1877) 2024-04-17 11:22:53 +02:00
42 changed files with 543 additions and 1318 deletions

View File

@@ -1,36 +0,0 @@
name: Contributors
on:
push:
branches:
- main
workflow_dispatch:
jobs:
add-contributors:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- name: Delete upstream contributor branch
# Allow continue on failure to account for when the
# upstream branch is deleted or does not exist.
continue-on-error: true
run: git push origin --delete update-contributors
- name: Create up-to-date contributors branch
run: git checkout -B update-contributors
- name: Push empty contributors branch
run: git push origin update-contributors
- name: Switch back to main
run: git checkout main
- uses: BobAnkh/add-contributors@v0.2.2
with:
CONTRIBUTOR: "## Contributors"
COLUMN_PER_ROW: "6"
ACCESS_TOKEN: ${{secrets.GITHUB_TOKEN}}
IMG_WIDTH: "100"
FONT_SIZE: "14"
PATH: "/README.md"
COMMIT_MESSAGE: "docs(README): update contributors"
AVATAR_SHAPE: "round"
BRANCH: "update-contributors"
PULL_REQUEST: "main"

27
.github/workflows/docs-test.yml vendored Normal file
View File

@@ -0,0 +1,27 @@
name: Test documentation build
on: [pull_request]
concurrency:
group: ${{ github.workflow }}-$${{ github.head_ref || github.run_id }}
cancel-in-progress: true
jobs:
test:
runs-on: ubuntu-latest
steps:
- name: Checkout repository
uses: actions/checkout@v4
- name: Install python
uses: actions/setup-python@v4
with:
python-version: 3.x
- name: Setup cache
uses: actions/cache@v2
with:
key: ${{ github.ref }}
path: .cache
- name: Setup dependencies
run: pip install -r docs/requirements.txt
- name: Build docs
run: mkdocs build --strict

View File

@@ -55,6 +55,7 @@ after improving the test harness as part of adopting [#1460](https://github.com/
- Added the possibility to manually create a DERP-map entry which can be customized, instead of automatically creating it. [#1565](https://github.com/juanfont/headscale/pull/1565) - Added the possibility to manually create a DERP-map entry which can be customized, instead of automatically creating it. [#1565](https://github.com/juanfont/headscale/pull/1565)
- Add support for deleting api keys [#1702](https://github.com/juanfont/headscale/pull/1702) - Add support for deleting api keys [#1702](https://github.com/juanfont/headscale/pull/1702)
- Add command to backfill IP addresses for nodes missing IPs from configured prefixes. [#1869](https://github.com/juanfont/headscale/pull/1869) - Add command to backfill IP addresses for nodes missing IPs from configured prefixes. [#1869](https://github.com/juanfont/headscale/pull/1869)
- Log available update as warning [#1877](https://github.com/juanfont/headscale/pull/1877)
## 0.22.3 (2023-05-12) ## 0.22.3 (2023-05-12)

1042
README.md

File diff suppressed because it is too large Load Diff

View File

@@ -78,7 +78,7 @@ func initConfig() {
res, err := latest.Check(githubTag, Version) res, err := latest.Check(githubTag, Version)
if err == nil && res.Outdated { if err == nil && res.Outdated {
//nolint //nolint
fmt.Printf( log.Warn().Msgf(
"An updated version of Headscale has been found (%s vs. your current %s). Check it out https://github.com/juanfont/headscale/releases\n", "An updated version of Headscale has been found (%s vs. your current %s). Check it out https://github.com/juanfont/headscale/releases\n",
res.Current, res.Current,
Version, Version,

View File

@@ -137,12 +137,6 @@ disable_check_updates: false
# Time before an inactive ephemeral node is deleted? # Time before an inactive ephemeral node is deleted?
ephemeral_node_inactivity_timeout: 30m ephemeral_node_inactivity_timeout: 30m
# Period to check for node updates within the tailnet. A value too low will severely affect
# CPU consumption of Headscale. A value too high (over 60s) will cause problems
# for the nodes, as they won't get updates or keep alive messages frequently enough.
# In case of doubts, do not touch the default 10s.
node_update_check_interval: 10s
database: database:
type: sqlite type: sqlite

Binary file not shown.

After

Width:  |  Height:  |  Size: 35 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 35 KiB

View File

@@ -1,5 +1,4 @@
cairosvg~=2.7.1 cairosvg~=2.7.1
mkdocs-material~=9.4.14 mkdocs-material~=9.5.18
mkdocs-minify-plugin~=0.7.1 mkdocs-minify-plugin~=0.7.1
pillow~=10.1.0 pillow~=10.1.0

View File

@@ -0,0 +1,136 @@
# Running headscale on Sealos
!!! warning "Community documentation"
This page is not actively maintained by the headscale authors and is
written by community members. It is _not_ verified by `headscale` developers.
**It might be outdated and it might miss necessary steps**.
## Goal
This documentation has the goal of showing a user how-to run `headscale` on Sealos.
## Running headscale server
1. Click the following prebuilt template(version [0.23.0-alpha2](https://github.com/juanfont/headscale/releases/tag/v0.23.0-alpha2)):
[![](https://cdn.jsdelivr.net/gh/labring-actions/templates@main/Deploy-on-Sealos.svg)](https://cloud.sealos.io/?openapp=system-template%3FtemplateName%3Dheadscale)
2. Click "Deploy Application" on the template page to start deployment. Upon completion, two applications appear: Headscale, and its [visual interface](https://github.com/GoodiesHQ/headscale-admin).
3. Once deployment concludes, click 'Details' on the Headscale application page to navigate to the application's details.
4. Wait for the application's status to switch to running. For accessing the headscale server, the Public Address associated with port 8080 is the address of the headscale server. To access the Headscale console, simply append `/admin/` to the Headscale public URL.
![](./images/headscale-sealos-url.png)
5. Click on 'Terminal' button on the right side of the details to access the Terminal of the headscale application. then create a user ([tailnet](https://tailscale.com/kb/1136/tailnet/)):
```bash
headscale users create myfirstuser
```
### Register a machine (normal login)
On a client machine, execute the `tailscale` login command:
```bash
# replace <YOUR_HEADSCALE_URL> with the public domain provided by Sealos
tailscale up --login-server YOUR_HEADSCALE_URL
```
To register a machine when running headscale in [Sealos](https://sealos.io), click on 'Terminal' button on the right side of the headscale application's detail page to access the Terminal of the headscale application, then take the headscale command:
```bash
headscale --user myfirstuser nodes register --key <YOU_+MACHINE_KEY>
```
### Register machine using a pre authenticated key
click on 'Terminal' button on the right side of the headscale application's detail page to access the Terminal of the headscale application, then generate a key using the command line:
```bash
headscale --user myfirstuser preauthkeys create --reusable --expiration 24h
```
This will return a pre-authenticated key that can be used to connect a node to `headscale` during the `tailscale` command:
```bash
tailscale up --login-server <YOUR_HEADSCALE_URL> --authkey <YOUR_AUTH_KEY>
```
## Controlling headscale with remote CLI
This documentation has the goal of showing a user how-to set control a headscale instance from a remote machine with the headscale command line binary.
### Create an API key
We need to create an API key to authenticate our remote headscale when using it from our workstation.
To create a API key, click on 'Terminal' button on the right side of the headscale application's detail page to access the Terminal of the headscale application, then generate a key:
```bash
headscale apikeys create --expiration 90d
```
Copy the output of the command and save it for later. Please note that you can not retrieve a key again, if the key is lost, expire the old one, and create a new key.
To list the keys currently assosicated with the server:
```bash
headscale apikeys list
```
and to expire a key:
```bash
headscale apikeys expire --prefix "<PREFIX>"
```
### Download and configure `headscale` client
1. Download the latest [`headscale` binary from GitHub's release page](https://github.com/juanfont/headscale/releases):
2. Put the binary somewhere in your `PATH`, e.g. `/usr/local/bin/headscale`
3. Make `headscale` executable:
```shell
chmod +x /usr/local/bin/headscale
```
4. Configure the CLI through Environment Variables
```shell
export HEADSCALE_CLI_ADDRESS="<HEADSCALE ADDRESS>:443"
export HEADSCALE_CLI_API_KEY="<API KEY FROM PREVIOUS STAGE>"
```
In the headscale application's detail page, The Public Address corresponding to port 50443 corresponds to the value of <HEADSCALE ADDRESS>.
![](./images/headscale-sealos-grpc-url.png)
for example:
```shell
export HEADSCALE_CLI_ADDRESS="pwnjnnly.cloud.sealos.io:443"
export HEADSCALE_CLI_API_KEY="abcde12345"
```
This will tell the `headscale` binary to connect to a remote instance, instead of looking
for a local instance.
The API key is needed to make sure that your are allowed to access the server. The key is _not_
needed when running directly on the server, as the connection is local.
1. Test the connection
Let us run the headscale command to verify that we can connect by listing our nodes:
```shell
headscale nodes list
```
You should now be able to see a list of your nodes from your workstation, and you can
now control the `headscale` server from your workstation.
> Reference: [Headscale Deployment and Usage Guide: Mastering Tailscale's Self-Hosting Basics](https://icloudnative.io/en/posts/how-to-set-up-or-migrate-headscale/)

View File

@@ -5,10 +5,11 @@
This page contains community contributions. The projects listed here are not This page contains community contributions. The projects listed here are not
maintained by the Headscale authors and are written by community members. maintained by the Headscale authors and are written by community members.
| Name | Repository Link | Description | Status | | Name | Repository Link | Description | Status |
| --------------- | ------------------------------------------------------- | ------------------------------------------------------------------------- | ------ | | --------------- | ------------------------------------------------------- | --------------------------------------------------------------------------- | ------ |
| headscale-webui | [Github](https://github.com/ifargle/headscale-webui) | A simple Headscale web UI for small-scale deployments. | Alpha | | headscale-webui | [Github](https://github.com/ifargle/headscale-webui) | A simple Headscale web UI for small-scale deployments. | Alpha |
| headscale-ui | [Github](https://github.com/gurucomputing/headscale-ui) | A web frontend for the headscale Tailscale-compatible coordination server | Alpha | | headscale-ui | [Github](https://github.com/gurucomputing/headscale-ui) | A web frontend for the headscale Tailscale-compatible coordination server | Alpha |
| HeadscaleUi | [GitHub](https://github.com/simcu/headscale-ui) | A static headscale admin ui, no backend enviroment required | Alpha | | HeadscaleUi | [GitHub](https://github.com/simcu/headscale-ui) | A static headscale admin ui, no backend enviroment required | Alpha |
| headscale-admin | [Github](https://github.com/GoodiesHQ/headscale-admin) | Headscale-Admin is meant to be a simple, modern web interface for Headscale | Beta |
You can ask for support on our dedicated [Discord channel](https://discord.com/channels/896711691637780480/1105842846386356294). You can ask for support on our dedicated [Discord channel](https://discord.com/channels/896711691637780480/1105842846386356294).

View File

@@ -225,7 +225,7 @@ func (h *Headscale) deleteExpireEphemeralNodes(milliSeconds int64) {
for range ticker.C { for range ticker.C {
var removed []types.NodeID var removed []types.NodeID
var changed []types.NodeID var changed []types.NodeID
if err := h.db.DB.Transaction(func(tx *gorm.DB) error { if err := h.db.Write(func(tx *gorm.DB) error {
removed, changed = db.DeleteExpiredEphemeralNodes(tx, h.cfg.EphemeralNodeInactivityTimeout) removed, changed = db.DeleteExpiredEphemeralNodes(tx, h.cfg.EphemeralNodeInactivityTimeout)
return nil return nil
@@ -263,7 +263,7 @@ func (h *Headscale) expireExpiredMachines(intervalMs int64) {
var changed bool var changed bool
for range ticker.C { for range ticker.C {
if err := h.db.DB.Transaction(func(tx *gorm.DB) error { if err := h.db.Write(func(tx *gorm.DB) error {
lastCheck, update, changed = db.ExpireExpiredNodes(tx, lastCheck) lastCheck, update, changed = db.ExpireExpiredNodes(tx, lastCheck)
return nil return nil
@@ -452,7 +452,7 @@ func (h *Headscale) ensureUnixSocketIsAbsent() error {
func (h *Headscale) createRouter(grpcMux *grpcRuntime.ServeMux) *mux.Router { func (h *Headscale) createRouter(grpcMux *grpcRuntime.ServeMux) *mux.Router {
router := mux.NewRouter() router := mux.NewRouter()
router.PathPrefix("/debug/pprof/").Handler(http.DefaultServeMux) router.Use(prometheusMiddleware)
router.HandleFunc(ts2021UpgradePath, h.NoiseUpgradeHandler).Methods(http.MethodPost) router.HandleFunc(ts2021UpgradePath, h.NoiseUpgradeHandler).Methods(http.MethodPost)
@@ -508,7 +508,7 @@ func (h *Headscale) Serve() error {
// Fetch an initial DERP Map before we start serving // Fetch an initial DERP Map before we start serving
h.DERPMap = derp.GetDERPMap(h.cfg.DERP) h.DERPMap = derp.GetDERPMap(h.cfg.DERP)
h.mapper = mapper.NewMapper(h.db, h.cfg, h.DERPMap, h.nodeNotifier.ConnectedMap()) h.mapper = mapper.NewMapper(h.db, h.cfg, h.DERPMap, h.nodeNotifier)
if h.cfg.DERP.ServerEnabled { if h.cfg.DERP.ServerEnabled {
// When embedded DERP is enabled we always need a STUN server // When embedded DERP is enabled we always need a STUN server
@@ -680,7 +680,7 @@ func (h *Headscale) Serve() error {
// HTTP setup // HTTP setup
// //
// This is the regular router that we expose // This is the regular router that we expose
// over our main Addr. It also serves the legacy Tailcale API // over our main Addr
router := h.createRouter(grpcGatewayMux) router := h.createRouter(grpcGatewayMux)
httpServer := &http.Server{ httpServer := &http.Server{
@@ -710,11 +710,10 @@ func (h *Headscale) Serve() error {
Msgf("listening and serving HTTP on: %s", h.cfg.Addr) Msgf("listening and serving HTTP on: %s", h.cfg.Addr)
debugMux := http.NewServeMux() debugMux := http.NewServeMux()
debugMux.Handle("/debug/pprof/", http.DefaultServeMux)
debugMux.HandleFunc("/debug/notifier", func(w http.ResponseWriter, r *http.Request) { debugMux.HandleFunc("/debug/notifier", func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK) w.WriteHeader(http.StatusOK)
w.Write([]byte(h.nodeNotifier.String())) w.Write([]byte(h.nodeNotifier.String()))
return
}) })
debugMux.HandleFunc("/debug/mapresp", func(w http.ResponseWriter, r *http.Request) { debugMux.HandleFunc("/debug/mapresp", func(w http.ResponseWriter, r *http.Request) {
h.mapSessionMu.Lock() h.mapSessionMu.Lock()
@@ -728,8 +727,6 @@ func (h *Headscale) Serve() error {
w.WriteHeader(http.StatusOK) w.WriteHeader(http.StatusOK)
w.Write([]byte(b.String())) w.Write([]byte(b.String()))
return
}) })
debugMux.Handle("/metrics", promhttp.Handler()) debugMux.Handle("/metrics", promhttp.Handler())

View File

@@ -273,8 +273,6 @@ func (h *Headscale) handleAuthKey(
Err(err). Err(err).
Msg("Cannot encode message") Msg("Cannot encode message")
http.Error(writer, "Internal server error", http.StatusInternalServerError) http.Error(writer, "Internal server error", http.StatusInternalServerError)
nodeRegistrations.WithLabelValues("new", util.RegisterMethodAuthKey, "error", pak.User.Name).
Inc()
return return
} }
@@ -294,13 +292,6 @@ func (h *Headscale) handleAuthKey(
Str("node", registerRequest.Hostinfo.Hostname). Str("node", registerRequest.Hostinfo.Hostname).
Msg("Failed authentication via AuthKey") Msg("Failed authentication via AuthKey")
if pak != nil {
nodeRegistrations.WithLabelValues("new", util.RegisterMethodAuthKey, "error", pak.User.Name).
Inc()
} else {
nodeRegistrations.WithLabelValues("new", util.RegisterMethodAuthKey, "error", "unknown").Inc()
}
return return
} }
@@ -404,15 +395,13 @@ func (h *Headscale) handleAuthKey(
Caller(). Caller().
Err(err). Err(err).
Msg("could not register node") Msg("could not register node")
nodeRegistrations.WithLabelValues("new", util.RegisterMethodAuthKey, "error", pak.User.Name).
Inc()
http.Error(writer, "Internal server error", http.StatusInternalServerError) http.Error(writer, "Internal server error", http.StatusInternalServerError)
return return
} }
} }
err = h.db.DB.Transaction(func(tx *gorm.DB) error { h.db.Write(func(tx *gorm.DB) error {
return db.UsePreAuthKey(tx, pak) return db.UsePreAuthKey(tx, pak)
}) })
if err != nil { if err != nil {
@@ -420,8 +409,6 @@ func (h *Headscale) handleAuthKey(
Caller(). Caller().
Err(err). Err(err).
Msg("Failed to use pre-auth key") Msg("Failed to use pre-auth key")
nodeRegistrations.WithLabelValues("new", util.RegisterMethodAuthKey, "error", pak.User.Name).
Inc()
http.Error(writer, "Internal server error", http.StatusInternalServerError) http.Error(writer, "Internal server error", http.StatusInternalServerError)
return return
@@ -440,14 +427,10 @@ func (h *Headscale) handleAuthKey(
Str("node", registerRequest.Hostinfo.Hostname). Str("node", registerRequest.Hostinfo.Hostname).
Err(err). Err(err).
Msg("Cannot encode message") Msg("Cannot encode message")
nodeRegistrations.WithLabelValues("new", util.RegisterMethodAuthKey, "error", pak.User.Name).
Inc()
http.Error(writer, "Internal server error", http.StatusInternalServerError) http.Error(writer, "Internal server error", http.StatusInternalServerError)
return return
} }
nodeRegistrations.WithLabelValues("new", util.RegisterMethodAuthKey, "success", pak.User.Name).
Inc()
writer.Header().Set("Content-Type", "application/json; charset=utf-8") writer.Header().Set("Content-Type", "application/json; charset=utf-8")
writer.WriteHeader(http.StatusOK) writer.WriteHeader(http.StatusOK)
_, err = writer.Write(respBody) _, err = writer.Write(respBody)
@@ -563,7 +546,7 @@ func (h *Headscale) handleNodeLogOut(
} }
if node.IsEphemeral() { if node.IsEphemeral() {
changedNodes, err := h.db.DeleteNode(&node, h.nodeNotifier.ConnectedMap()) changedNodes, err := h.db.DeleteNode(&node, h.nodeNotifier.LikelyConnectedMap())
if err != nil { if err != nil {
log.Error(). log.Error().
Err(err). Err(err).
@@ -616,14 +599,10 @@ func (h *Headscale) handleNodeWithValidRegistration(
Caller(). Caller().
Err(err). Err(err).
Msg("Cannot encode message") Msg("Cannot encode message")
nodeRegistrations.WithLabelValues("update", "web", "error", node.User.Name).
Inc()
http.Error(writer, "Internal server error", http.StatusInternalServerError) http.Error(writer, "Internal server error", http.StatusInternalServerError)
return return
} }
nodeRegistrations.WithLabelValues("update", "web", "success", node.User.Name).
Inc()
writer.Header().Set("Content-Type", "application/json; charset=utf-8") writer.Header().Set("Content-Type", "application/json; charset=utf-8")
writer.WriteHeader(http.StatusOK) writer.WriteHeader(http.StatusOK)
@@ -654,7 +633,7 @@ func (h *Headscale) handleNodeKeyRefresh(
Str("node", node.Hostname). Str("node", node.Hostname).
Msg("We have the OldNodeKey in the database. This is a key refresh") Msg("We have the OldNodeKey in the database. This is a key refresh")
err := h.db.DB.Transaction(func(tx *gorm.DB) error { err := h.db.Write(func(tx *gorm.DB) error {
return db.NodeSetNodeKey(tx, &node, registerRequest.NodeKey) return db.NodeSetNodeKey(tx, &node, registerRequest.NodeKey)
}) })
if err != nil { if err != nil {
@@ -737,14 +716,10 @@ func (h *Headscale) handleNodeExpiredOrLoggedOut(
Caller(). Caller().
Err(err). Err(err).
Msg("Cannot encode message") Msg("Cannot encode message")
nodeRegistrations.WithLabelValues("reauth", "web", "error", node.User.Name).
Inc()
http.Error(writer, "Internal server error", http.StatusInternalServerError) http.Error(writer, "Internal server error", http.StatusInternalServerError)
return return
} }
nodeRegistrations.WithLabelValues("reauth", "web", "success", node.User.Name).
Inc()
writer.Header().Set("Content-Type", "application/json; charset=utf-8") writer.Header().Set("Content-Type", "application/json; charset=utf-8")
writer.WriteHeader(http.StatusOK) writer.WriteHeader(http.StatusOK)

View File

@@ -33,7 +33,6 @@ func (ns *noiseServer) NoiseRegistrationHandler(
Caller(). Caller().
Err(err). Err(err).
Msg("Cannot parse RegisterRequest") Msg("Cannot parse RegisterRequest")
nodeRegistrations.WithLabelValues("unknown", "web", "error", "unknown").Inc()
http.Error(writer, "Internal error", http.StatusInternalServerError) http.Error(writer, "Internal error", http.StatusInternalServerError)
return return

View File

@@ -356,7 +356,7 @@ func NewHeadscaleDatabase(
addrs := strings.Split(node.Addresses, ",") addrs := strings.Split(node.Addresses, ",")
if len(addrs) == 0 { if len(addrs) == 0 {
fmt.Errorf("no addresses found for node(%d)", node.ID) return fmt.Errorf("no addresses found for node(%d)", node.ID)
} }
var v4 *netip.Addr var v4 *netip.Addr
@@ -377,9 +377,18 @@ func NewHeadscaleDatabase(
} }
} }
err = tx.Save(&types.Node{ID: types.NodeID(node.ID), IPv4: v4, IPv6: v6}).Error if v4 != nil {
if err != nil { err = tx.Model(&types.Node{}).Where("id = ?", node.ID).Update("ipv4", v4.String()).Error
return fmt.Errorf("saving ip addresses to new columns: %w", err) if err != nil {
return fmt.Errorf("saving ip addresses to new columns: %w", err)
}
}
if v6 != nil {
err = tx.Model(&types.Node{}).Where("id = ?", node.ID).Update("ipv6", v6.String()).Error
if err != nil {
return fmt.Errorf("saving ip addresses to new columns: %w", err)
}
} }
} }

View File

@@ -10,6 +10,7 @@ import (
"github.com/juanfont/headscale/hscontrol/types" "github.com/juanfont/headscale/hscontrol/types"
"github.com/juanfont/headscale/hscontrol/util" "github.com/juanfont/headscale/hscontrol/util"
"github.com/patrickmn/go-cache" "github.com/patrickmn/go-cache"
"github.com/puzpuzpuz/xsync/v3"
"github.com/rs/zerolog/log" "github.com/rs/zerolog/log"
"gorm.io/gorm" "gorm.io/gorm"
"tailscale.com/tailcfg" "tailscale.com/tailcfg"
@@ -206,6 +207,11 @@ func SetTags(
tags []string, tags []string,
) error { ) error {
if len(tags) == 0 { if len(tags) == 0 {
// if no tags are provided, we remove all forced tags
if err := tx.Model(&types.Node{}).Where("id = ?", nodeID).Update("forced_tags", types.StringList{}).Error; err != nil {
return fmt.Errorf("failed to remove tags for node in the database: %w", err)
}
return nil return nil
} }
@@ -255,9 +261,9 @@ func NodeSetExpiry(tx *gorm.DB,
return tx.Model(&types.Node{}).Where("id = ?", nodeID).Update("expiry", expiry).Error return tx.Model(&types.Node{}).Where("id = ?", nodeID).Update("expiry", expiry).Error
} }
func (hsdb *HSDatabase) DeleteNode(node *types.Node, isConnected types.NodeConnectedMap) ([]types.NodeID, error) { func (hsdb *HSDatabase) DeleteNode(node *types.Node, isLikelyConnected *xsync.MapOf[types.NodeID, bool]) ([]types.NodeID, error) {
return Write(hsdb.DB, func(tx *gorm.DB) ([]types.NodeID, error) { return Write(hsdb.DB, func(tx *gorm.DB) ([]types.NodeID, error) {
return DeleteNode(tx, node, isConnected) return DeleteNode(tx, node, isLikelyConnected)
}) })
} }
@@ -265,9 +271,9 @@ func (hsdb *HSDatabase) DeleteNode(node *types.Node, isConnected types.NodeConne
// Caller is responsible for notifying all of change. // Caller is responsible for notifying all of change.
func DeleteNode(tx *gorm.DB, func DeleteNode(tx *gorm.DB,
node *types.Node, node *types.Node,
isConnected types.NodeConnectedMap, isLikelyConnected *xsync.MapOf[types.NodeID, bool],
) ([]types.NodeID, error) { ) ([]types.NodeID, error) {
changed, err := deleteNodeRoutes(tx, node, isConnected) changed, err := deleteNodeRoutes(tx, node, isLikelyConnected)
if err != nil { if err != nil {
return changed, err return changed, err
} }

View File

@@ -11,6 +11,7 @@ import (
"github.com/juanfont/headscale/hscontrol/policy" "github.com/juanfont/headscale/hscontrol/policy"
"github.com/juanfont/headscale/hscontrol/types" "github.com/juanfont/headscale/hscontrol/types"
"github.com/juanfont/headscale/hscontrol/util" "github.com/juanfont/headscale/hscontrol/util"
"github.com/puzpuzpuz/xsync/v3"
"gopkg.in/check.v1" "gopkg.in/check.v1"
"tailscale.com/tailcfg" "tailscale.com/tailcfg"
"tailscale.com/types/key" "tailscale.com/types/key"
@@ -120,7 +121,7 @@ func (s *Suite) TestHardDeleteNode(c *check.C) {
} }
db.DB.Save(&node) db.DB.Save(&node)
_, err = db.DeleteNode(&node, types.NodeConnectedMap{}) _, err = db.DeleteNode(&node, xsync.NewMapOf[types.NodeID, bool]())
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
_, err = db.getNode(user.Name, "testnode3") _, err = db.getNode(user.Name, "testnode3")
@@ -386,6 +387,13 @@ func (s *Suite) TestSetTags(c *check.C) {
check.DeepEquals, check.DeepEquals,
types.StringList([]string{"tag:bar", "tag:test", "tag:unknown"}), types.StringList([]string{"tag:bar", "tag:test", "tag:unknown"}),
) )
// test removing tags
err = db.SetTags(node.ID, []string{})
c.Assert(err, check.IsNil)
node, err = db.getNode("test", "testnode")
c.Assert(err, check.IsNil)
c.Assert(node.ForcedTags, check.DeepEquals, types.StringList([]string{}))
} }
func TestHeadscale_generateGivenName(t *testing.T) { func TestHeadscale_generateGivenName(t *testing.T) {

View File

@@ -147,7 +147,7 @@ func (*Suite) TestEphemeralKeyReusable(c *check.C) {
_, err = db.getNode("test7", "testest") _, err = db.getNode("test7", "testest")
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
db.DB.Transaction(func(tx *gorm.DB) error { db.Write(func(tx *gorm.DB) error {
DeleteExpiredEphemeralNodes(tx, time.Second*20) DeleteExpiredEphemeralNodes(tx, time.Second*20)
return nil return nil
}) })
@@ -181,7 +181,7 @@ func (*Suite) TestEphemeralKeyNotReusable(c *check.C) {
_, err = db.getNode("test7", "testest") _, err = db.getNode("test7", "testest")
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
db.DB.Transaction(func(tx *gorm.DB) error { db.Write(func(tx *gorm.DB) error {
DeleteExpiredEphemeralNodes(tx, time.Second*20) DeleteExpiredEphemeralNodes(tx, time.Second*20)
return nil return nil
}) })

View File

@@ -8,6 +8,7 @@ import (
"github.com/juanfont/headscale/hscontrol/policy" "github.com/juanfont/headscale/hscontrol/policy"
"github.com/juanfont/headscale/hscontrol/types" "github.com/juanfont/headscale/hscontrol/types"
"github.com/puzpuzpuz/xsync/v3"
"github.com/rs/zerolog/log" "github.com/rs/zerolog/log"
"gorm.io/gorm" "gorm.io/gorm"
"tailscale.com/util/set" "tailscale.com/util/set"
@@ -126,7 +127,7 @@ func EnableRoute(tx *gorm.DB, id uint64) (*types.StateUpdate, error) {
func DisableRoute(tx *gorm.DB, func DisableRoute(tx *gorm.DB,
id uint64, id uint64,
isConnected types.NodeConnectedMap, isLikelyConnected *xsync.MapOf[types.NodeID, bool],
) ([]types.NodeID, error) { ) ([]types.NodeID, error) {
route, err := GetRoute(tx, id) route, err := GetRoute(tx, id)
if err != nil { if err != nil {
@@ -147,7 +148,7 @@ func DisableRoute(tx *gorm.DB,
return nil, err return nil, err
} }
update, err = failoverRouteTx(tx, isConnected, route) update, err = failoverRouteTx(tx, isLikelyConnected, route)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -182,17 +183,17 @@ func DisableRoute(tx *gorm.DB,
func (hsdb *HSDatabase) DeleteRoute( func (hsdb *HSDatabase) DeleteRoute(
id uint64, id uint64,
isConnected types.NodeConnectedMap, isLikelyConnected *xsync.MapOf[types.NodeID, bool],
) ([]types.NodeID, error) { ) ([]types.NodeID, error) {
return Write(hsdb.DB, func(tx *gorm.DB) ([]types.NodeID, error) { return Write(hsdb.DB, func(tx *gorm.DB) ([]types.NodeID, error) {
return DeleteRoute(tx, id, isConnected) return DeleteRoute(tx, id, isLikelyConnected)
}) })
} }
func DeleteRoute( func DeleteRoute(
tx *gorm.DB, tx *gorm.DB,
id uint64, id uint64,
isConnected types.NodeConnectedMap, isLikelyConnected *xsync.MapOf[types.NodeID, bool],
) ([]types.NodeID, error) { ) ([]types.NodeID, error) {
route, err := GetRoute(tx, id) route, err := GetRoute(tx, id)
if err != nil { if err != nil {
@@ -207,7 +208,7 @@ func DeleteRoute(
// https://github.com/juanfont/headscale/issues/804#issuecomment-1399314002 // https://github.com/juanfont/headscale/issues/804#issuecomment-1399314002
var update []types.NodeID var update []types.NodeID
if !route.IsExitRoute() { if !route.IsExitRoute() {
update, err = failoverRouteTx(tx, isConnected, route) update, err = failoverRouteTx(tx, isLikelyConnected, route)
if err != nil { if err != nil {
return nil, nil return nil, nil
} }
@@ -252,7 +253,7 @@ func DeleteRoute(
return update, nil return update, nil
} }
func deleteNodeRoutes(tx *gorm.DB, node *types.Node, isConnected types.NodeConnectedMap) ([]types.NodeID, error) { func deleteNodeRoutes(tx *gorm.DB, node *types.Node, isLikelyConnected *xsync.MapOf[types.NodeID, bool]) ([]types.NodeID, error) {
routes, err := GetNodeRoutes(tx, node) routes, err := GetNodeRoutes(tx, node)
if err != nil { if err != nil {
return nil, fmt.Errorf("getting node routes: %w", err) return nil, fmt.Errorf("getting node routes: %w", err)
@@ -266,7 +267,7 @@ func deleteNodeRoutes(tx *gorm.DB, node *types.Node, isConnected types.NodeConne
// TODO(kradalby): This is a bit too aggressive, we could probably // TODO(kradalby): This is a bit too aggressive, we could probably
// figure out which routes needs to be failed over rather than all. // figure out which routes needs to be failed over rather than all.
chn, err := failoverRouteTx(tx, isConnected, &routes[i]) chn, err := failoverRouteTx(tx, isLikelyConnected, &routes[i])
if err != nil { if err != nil {
return changed, fmt.Errorf("failing over route after delete: %w", err) return changed, fmt.Errorf("failing over route after delete: %w", err)
} }
@@ -409,7 +410,7 @@ func SaveNodeRoutes(tx *gorm.DB, node *types.Node) (bool, error) {
// If needed, the failover will be attempted. // If needed, the failover will be attempted.
func FailoverNodeRoutesIfNeccessary( func FailoverNodeRoutesIfNeccessary(
tx *gorm.DB, tx *gorm.DB,
isConnected types.NodeConnectedMap, isLikelyConnected *xsync.MapOf[types.NodeID, bool],
node *types.Node, node *types.Node,
) (*types.StateUpdate, error) { ) (*types.StateUpdate, error) {
nodeRoutes, err := GetNodeRoutes(tx, node) nodeRoutes, err := GetNodeRoutes(tx, node)
@@ -430,12 +431,12 @@ nodeRouteLoop:
if route.IsPrimary { if route.IsPrimary {
// if we have a primary route, and the node is connected // if we have a primary route, and the node is connected
// nothing needs to be done. // nothing needs to be done.
if conn, ok := isConnected[route.Node.ID]; conn && ok { if val, ok := isLikelyConnected.Load(route.Node.ID); ok && val {
continue nodeRouteLoop continue nodeRouteLoop
} }
// if not, we need to failover the route // if not, we need to failover the route
failover := failoverRoute(isConnected, &route, routes) failover := failoverRoute(isLikelyConnected, &route, routes)
if failover != nil { if failover != nil {
err := failover.save(tx) err := failover.save(tx)
if err != nil { if err != nil {
@@ -477,7 +478,7 @@ nodeRouteLoop:
// If the given route was not primary, it returns early. // If the given route was not primary, it returns early.
func failoverRouteTx( func failoverRouteTx(
tx *gorm.DB, tx *gorm.DB,
isConnected types.NodeConnectedMap, isLikelyConnected *xsync.MapOf[types.NodeID, bool],
r *types.Route, r *types.Route,
) ([]types.NodeID, error) { ) ([]types.NodeID, error) {
if r == nil { if r == nil {
@@ -500,7 +501,7 @@ func failoverRouteTx(
return nil, fmt.Errorf("getting routes by prefix: %w", err) return nil, fmt.Errorf("getting routes by prefix: %w", err)
} }
fo := failoverRoute(isConnected, r, routes) fo := failoverRoute(isLikelyConnected, r, routes)
if fo == nil { if fo == nil {
return nil, nil return nil, nil
} }
@@ -538,7 +539,7 @@ func (f *failover) save(tx *gorm.DB) error {
} }
func failoverRoute( func failoverRoute(
isConnected types.NodeConnectedMap, isLikelyConnected *xsync.MapOf[types.NodeID, bool],
routeToReplace *types.Route, routeToReplace *types.Route,
altRoutes types.Routes, altRoutes types.Routes,
@@ -570,9 +571,11 @@ func failoverRoute(
continue continue
} }
if isConnected != nil && isConnected[route.Node.ID] { if isLikelyConnected != nil {
newPrimary = &altRoutes[idx] if val, ok := isLikelyConnected.Load(route.Node.ID); ok && val {
break newPrimary = &altRoutes[idx]
break
}
} }
} }

View File

@@ -10,11 +10,22 @@ import (
"github.com/google/go-cmp/cmp/cmpopts" "github.com/google/go-cmp/cmp/cmpopts"
"github.com/juanfont/headscale/hscontrol/types" "github.com/juanfont/headscale/hscontrol/types"
"github.com/juanfont/headscale/hscontrol/util" "github.com/juanfont/headscale/hscontrol/util"
"github.com/puzpuzpuz/xsync/v3"
"gopkg.in/check.v1" "gopkg.in/check.v1"
"gorm.io/gorm" "gorm.io/gorm"
"tailscale.com/tailcfg" "tailscale.com/tailcfg"
) )
var smap = func(m map[types.NodeID]bool) *xsync.MapOf[types.NodeID, bool] {
s := xsync.NewMapOf[types.NodeID, bool]()
for k, v := range m {
s.Store(k, v)
}
return s
}
func (s *Suite) TestGetRoutes(c *check.C) { func (s *Suite) TestGetRoutes(c *check.C) {
user, err := db.CreateUser("test") user, err := db.CreateUser("test")
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
@@ -331,7 +342,7 @@ func TestFailoverNodeRoutesIfNeccessary(t *testing.T) {
name string name string
nodes types.Nodes nodes types.Nodes
routes types.Routes routes types.Routes
isConnected []types.NodeConnectedMap isConnected []map[types.NodeID]bool
want []*types.StateUpdate want []*types.StateUpdate
wantErr bool wantErr bool
}{ }{
@@ -346,7 +357,7 @@ func TestFailoverNodeRoutesIfNeccessary(t *testing.T) {
r(1, 1, ipp("10.0.0.0/24"), true, true), r(1, 1, ipp("10.0.0.0/24"), true, true),
r(2, 2, ipp("10.0.0.0/24"), true, false), r(2, 2, ipp("10.0.0.0/24"), true, false),
}, },
isConnected: []types.NodeConnectedMap{ isConnected: []map[types.NodeID]bool{
// n1 goes down // n1 goes down
{ {
1: false, 1: false,
@@ -384,7 +395,7 @@ func TestFailoverNodeRoutesIfNeccessary(t *testing.T) {
r(1, 1, ipp("10.0.0.0/24"), true, true), r(1, 1, ipp("10.0.0.0/24"), true, true),
r(2, 2, ipp("10.0.0.0/24"), true, false), r(2, 2, ipp("10.0.0.0/24"), true, false),
}, },
isConnected: []types.NodeConnectedMap{ isConnected: []map[types.NodeID]bool{
// n1 up recon = noop // n1 up recon = noop
{ {
1: true, 1: true,
@@ -428,7 +439,7 @@ func TestFailoverNodeRoutesIfNeccessary(t *testing.T) {
r(2, 2, ipp("10.0.0.0/24"), true, false), r(2, 2, ipp("10.0.0.0/24"), true, false),
r(3, 3, ipp("10.0.0.0/24"), true, false), r(3, 3, ipp("10.0.0.0/24"), true, false),
}, },
isConnected: []types.NodeConnectedMap{ isConnected: []map[types.NodeID]bool{
// n1 goes down // n1 goes down
{ {
1: false, 1: false,
@@ -486,7 +497,7 @@ func TestFailoverNodeRoutesIfNeccessary(t *testing.T) {
r(2, 2, ipp("10.0.0.0/24"), false, false), r(2, 2, ipp("10.0.0.0/24"), false, false),
r(3, 3, ipp("10.0.0.0/24"), true, false), r(3, 3, ipp("10.0.0.0/24"), true, false),
}, },
isConnected: []types.NodeConnectedMap{ isConnected: []map[types.NodeID]bool{
// n1 goes down // n1 goes down
{ {
1: false, 1: false,
@@ -516,7 +527,7 @@ func TestFailoverNodeRoutesIfNeccessary(t *testing.T) {
r(2, 2, ipp("10.0.0.0/24"), true, false), r(2, 2, ipp("10.0.0.0/24"), true, false),
r(3, 3, ipp("10.1.0.0/24"), true, false), r(3, 3, ipp("10.1.0.0/24"), true, false),
}, },
isConnected: []types.NodeConnectedMap{ isConnected: []map[types.NodeID]bool{
// n1 goes down // n1 goes down
{ {
1: false, 1: false,
@@ -539,7 +550,7 @@ func TestFailoverNodeRoutesIfNeccessary(t *testing.T) {
r(2, 2, ipp("10.0.0.0/24"), true, false), r(2, 2, ipp("10.0.0.0/24"), true, false),
r(3, 3, ipp("10.1.0.0/24"), false, false), r(3, 3, ipp("10.1.0.0/24"), false, false),
}, },
isConnected: []types.NodeConnectedMap{ isConnected: []map[types.NodeID]bool{
// n1 goes down // n1 goes down
{ {
1: false, 1: false,
@@ -562,7 +573,7 @@ func TestFailoverNodeRoutesIfNeccessary(t *testing.T) {
r(2, 2, ipp("10.0.0.0/24"), true, false), r(2, 2, ipp("10.0.0.0/24"), true, false),
r(3, 3, ipp("10.1.0.0/24"), true, false), r(3, 3, ipp("10.1.0.0/24"), true, false),
}, },
isConnected: []types.NodeConnectedMap{ isConnected: []map[types.NodeID]bool{
// n1 goes down // n1 goes down
{ {
1: false, 1: false,
@@ -585,7 +596,7 @@ func TestFailoverNodeRoutesIfNeccessary(t *testing.T) {
r(2, 2, ipp("10.0.0.0/24"), true, true), r(2, 2, ipp("10.0.0.0/24"), true, true),
r(3, 3, ipp("10.1.0.0/24"), true, false), r(3, 3, ipp("10.1.0.0/24"), true, false),
}, },
isConnected: []types.NodeConnectedMap{ isConnected: []map[types.NodeID]bool{
// n1 goes down // n1 goes down
{ {
1: true, 1: true,
@@ -618,7 +629,7 @@ func TestFailoverNodeRoutesIfNeccessary(t *testing.T) {
want := tt.want[step] want := tt.want[step]
got, err := Write(db.DB, func(tx *gorm.DB) (*types.StateUpdate, error) { got, err := Write(db.DB, func(tx *gorm.DB) (*types.StateUpdate, error) {
return FailoverNodeRoutesIfNeccessary(tx, isConnected, node) return FailoverNodeRoutesIfNeccessary(tx, smap(isConnected), node)
}) })
if (err != nil) != tt.wantErr { if (err != nil) != tt.wantErr {
@@ -640,7 +651,7 @@ func TestFailoverRouteTx(t *testing.T) {
name string name string
failingRoute types.Route failingRoute types.Route
routes types.Routes routes types.Routes
isConnected types.NodeConnectedMap isConnected map[types.NodeID]bool
want []types.NodeID want []types.NodeID
wantErr bool wantErr bool
}{ }{
@@ -743,7 +754,7 @@ func TestFailoverRouteTx(t *testing.T) {
Enabled: true, Enabled: true,
}, },
}, },
isConnected: types.NodeConnectedMap{ isConnected: map[types.NodeID]bool{
1: false, 1: false,
2: true, 2: true,
}, },
@@ -841,7 +852,7 @@ func TestFailoverRouteTx(t *testing.T) {
Enabled: true, Enabled: true,
}, },
}, },
isConnected: types.NodeConnectedMap{ isConnected: map[types.NodeID]bool{
1: true, 1: true,
2: true, 2: true,
3: true, 3: true,
@@ -889,7 +900,7 @@ func TestFailoverRouteTx(t *testing.T) {
Enabled: true, Enabled: true,
}, },
}, },
isConnected: types.NodeConnectedMap{ isConnected: map[types.NodeID]bool{
1: true, 1: true,
4: false, 4: false,
}, },
@@ -945,7 +956,7 @@ func TestFailoverRouteTx(t *testing.T) {
Enabled: true, Enabled: true,
}, },
}, },
isConnected: types.NodeConnectedMap{ isConnected: map[types.NodeID]bool{
1: false, 1: false,
2: true, 2: true,
4: false, 4: false,
@@ -1010,7 +1021,7 @@ func TestFailoverRouteTx(t *testing.T) {
} }
got, err := Write(db.DB, func(tx *gorm.DB) ([]types.NodeID, error) { got, err := Write(db.DB, func(tx *gorm.DB) ([]types.NodeID, error) {
return failoverRouteTx(tx, tt.isConnected, &tt.failingRoute) return failoverRouteTx(tx, smap(tt.isConnected), &tt.failingRoute)
}) })
if (err != nil) != tt.wantErr { if (err != nil) != tt.wantErr {
@@ -1048,7 +1059,7 @@ func TestFailoverRoute(t *testing.T) {
name string name string
failingRoute types.Route failingRoute types.Route
routes types.Routes routes types.Routes
isConnected types.NodeConnectedMap isConnected map[types.NodeID]bool
want *failover want *failover
}{ }{
{ {
@@ -1085,7 +1096,7 @@ func TestFailoverRoute(t *testing.T) {
r(1, 1, ipp("10.0.0.0/24"), true, true), r(1, 1, ipp("10.0.0.0/24"), true, true),
r(2, 2, ipp("10.0.0.0/24"), true, false), r(2, 2, ipp("10.0.0.0/24"), true, false),
}, },
isConnected: types.NodeConnectedMap{ isConnected: map[types.NodeID]bool{
1: false, 1: false,
2: true, 2: true,
}, },
@@ -1111,7 +1122,7 @@ func TestFailoverRoute(t *testing.T) {
r(2, 2, ipp("10.0.0.0/24"), true, true), r(2, 2, ipp("10.0.0.0/24"), true, true),
r(3, 3, ipp("10.0.0.0/24"), true, false), r(3, 3, ipp("10.0.0.0/24"), true, false),
}, },
isConnected: types.NodeConnectedMap{ isConnected: map[types.NodeID]bool{
1: true, 1: true,
2: true, 2: true,
3: true, 3: true,
@@ -1128,7 +1139,7 @@ func TestFailoverRoute(t *testing.T) {
r(1, 1, ipp("10.0.0.0/24"), true, true), r(1, 1, ipp("10.0.0.0/24"), true, true),
r(2, 4, ipp("10.0.0.0/24"), true, false), r(2, 4, ipp("10.0.0.0/24"), true, false),
}, },
isConnected: types.NodeConnectedMap{ isConnected: map[types.NodeID]bool{
1: true, 1: true,
4: false, 4: false,
}, },
@@ -1142,7 +1153,7 @@ func TestFailoverRoute(t *testing.T) {
r(2, 4, ipp("10.0.0.0/24"), true, false), r(2, 4, ipp("10.0.0.0/24"), true, false),
r(3, 2, ipp("10.0.0.0/24"), true, false), r(3, 2, ipp("10.0.0.0/24"), true, false),
}, },
isConnected: types.NodeConnectedMap{ isConnected: map[types.NodeID]bool{
1: false, 1: false,
2: true, 2: true,
4: false, 4: false,
@@ -1172,7 +1183,7 @@ func TestFailoverRoute(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
gotf := failoverRoute(tt.isConnected, &tt.failingRoute, tt.routes) gotf := failoverRoute(smap(tt.isConnected), &tt.failingRoute, tt.routes)
if tt.want == nil && gotf != nil { if tt.want == nil && gotf != nil {
t.Fatalf("expected nil, got %+v", gotf) t.Fatalf("expected nil, got %+v", gotf)

View File

@@ -4,7 +4,6 @@ package hscontrol
import ( import (
"context" "context"
"errors" "errors"
"fmt"
"sort" "sort"
"strings" "strings"
"time" "time"
@@ -145,7 +144,7 @@ func (api headscaleV1APIServer) ExpirePreAuthKey(
ctx context.Context, ctx context.Context,
request *v1.ExpirePreAuthKeyRequest, request *v1.ExpirePreAuthKeyRequest,
) (*v1.ExpirePreAuthKeyResponse, error) { ) (*v1.ExpirePreAuthKeyResponse, error) {
err := api.h.db.DB.Transaction(func(tx *gorm.DB) error { err := api.h.db.Write(func(tx *gorm.DB) error {
preAuthKey, err := db.GetPreAuthKey(tx, request.GetUser(), request.Key) preAuthKey, err := db.GetPreAuthKey(tx, request.GetUser(), request.Key)
if err != nil { if err != nil {
return err return err
@@ -279,13 +278,13 @@ func (api headscaleV1APIServer) SetTags(
func validateTag(tag string) error { func validateTag(tag string) error {
if strings.Index(tag, "tag:") != 0 { if strings.Index(tag, "tag:") != 0 {
return fmt.Errorf("tag must start with the string 'tag:'") return errors.New("tag must start with the string 'tag:'")
} }
if strings.ToLower(tag) != tag { if strings.ToLower(tag) != tag {
return fmt.Errorf("tag should be lowercase") return errors.New("tag should be lowercase")
} }
if len(strings.Fields(tag)) > 1 { if len(strings.Fields(tag)) > 1 {
return fmt.Errorf("tag should not contains space") return errors.New("tag should not contains space")
} }
return nil return nil
} }
@@ -301,7 +300,7 @@ func (api headscaleV1APIServer) DeleteNode(
changedNodes, err := api.h.db.DeleteNode( changedNodes, err := api.h.db.DeleteNode(
node, node,
api.h.nodeNotifier.ConnectedMap(), api.h.nodeNotifier.LikelyConnectedMap(),
) )
if err != nil { if err != nil {
return nil, err return nil, err
@@ -343,7 +342,7 @@ func (api headscaleV1APIServer) ExpireNode(
} }
ctx = types.NotifyCtx(ctx, "cli-expirenode-self", node.Hostname) ctx = types.NotifyCtx(ctx, "cli-expirenode-self", node.Hostname)
api.h.nodeNotifier.NotifyByMachineKey( api.h.nodeNotifier.NotifyByNodeID(
ctx, ctx,
types.StateUpdate{ types.StateUpdate{
Type: types.StateSelfUpdate, Type: types.StateSelfUpdate,
@@ -401,7 +400,7 @@ func (api headscaleV1APIServer) ListNodes(
ctx context.Context, ctx context.Context,
request *v1.ListNodesRequest, request *v1.ListNodesRequest,
) (*v1.ListNodesResponse, error) { ) (*v1.ListNodesResponse, error) {
isConnected := api.h.nodeNotifier.ConnectedMap() isLikelyConnected := api.h.nodeNotifier.LikelyConnectedMap()
if request.GetUser() != "" { if request.GetUser() != "" {
nodes, err := db.Read(api.h.db.DB, func(rx *gorm.DB) (types.Nodes, error) { nodes, err := db.Read(api.h.db.DB, func(rx *gorm.DB) (types.Nodes, error) {
return db.ListNodesByUser(rx, request.GetUser()) return db.ListNodesByUser(rx, request.GetUser())
@@ -416,7 +415,9 @@ func (api headscaleV1APIServer) ListNodes(
// Populate the online field based on // Populate the online field based on
// currently connected nodes. // currently connected nodes.
resp.Online = isConnected[node.ID] if val, ok := isLikelyConnected.Load(node.ID); ok && val {
resp.Online = true
}
response[index] = resp response[index] = resp
} }
@@ -439,7 +440,9 @@ func (api headscaleV1APIServer) ListNodes(
// Populate the online field based on // Populate the online field based on
// currently connected nodes. // currently connected nodes.
resp.Online = isConnected[node.ID] if val, ok := isLikelyConnected.Load(node.ID); ok && val {
resp.Online = true
}
validTags, invalidTags := api.h.ACLPolicy.TagsOfNode( validTags, invalidTags := api.h.ACLPolicy.TagsOfNode(
node, node,
@@ -528,7 +531,7 @@ func (api headscaleV1APIServer) DisableRoute(
request *v1.DisableRouteRequest, request *v1.DisableRouteRequest,
) (*v1.DisableRouteResponse, error) { ) (*v1.DisableRouteResponse, error) {
update, err := db.Write(api.h.db.DB, func(tx *gorm.DB) ([]types.NodeID, error) { update, err := db.Write(api.h.db.DB, func(tx *gorm.DB) ([]types.NodeID, error) {
return db.DisableRoute(tx, request.GetRouteId(), api.h.nodeNotifier.ConnectedMap()) return db.DisableRoute(tx, request.GetRouteId(), api.h.nodeNotifier.LikelyConnectedMap())
}) })
if err != nil { if err != nil {
return nil, err return nil, err
@@ -568,7 +571,7 @@ func (api headscaleV1APIServer) DeleteRoute(
ctx context.Context, ctx context.Context,
request *v1.DeleteRouteRequest, request *v1.DeleteRouteRequest,
) (*v1.DeleteRouteResponse, error) { ) (*v1.DeleteRouteResponse, error) {
isConnected := api.h.nodeNotifier.ConnectedMap() isConnected := api.h.nodeNotifier.LikelyConnectedMap()
update, err := db.Write(api.h.db.DB, func(tx *gorm.DB) ([]types.NodeID, error) { update, err := db.Write(api.h.db.DB, func(tx *gorm.DB) ([]types.NodeID, error) {
return db.DeleteRoute(tx, request.GetRouteId(), isConnected) return db.DeleteRoute(tx, request.GetRouteId(), isConnected)
}) })

View File

@@ -17,6 +17,7 @@ import (
mapset "github.com/deckarep/golang-set/v2" mapset "github.com/deckarep/golang-set/v2"
"github.com/juanfont/headscale/hscontrol/db" "github.com/juanfont/headscale/hscontrol/db"
"github.com/juanfont/headscale/hscontrol/notifier"
"github.com/juanfont/headscale/hscontrol/policy" "github.com/juanfont/headscale/hscontrol/policy"
"github.com/juanfont/headscale/hscontrol/types" "github.com/juanfont/headscale/hscontrol/types"
"github.com/juanfont/headscale/hscontrol/util" "github.com/juanfont/headscale/hscontrol/util"
@@ -51,10 +52,10 @@ var debugDumpMapResponsePath = envknob.String("HEADSCALE_DEBUG_DUMP_MAPRESPONSE_
type Mapper struct { type Mapper struct {
// Configuration // Configuration
// TODO(kradalby): figure out if this is the format we want this in // TODO(kradalby): figure out if this is the format we want this in
db *db.HSDatabase db *db.HSDatabase
cfg *types.Config cfg *types.Config
derpMap *tailcfg.DERPMap derpMap *tailcfg.DERPMap
isLikelyConnected types.NodeConnectedMap notif *notifier.Notifier
uid string uid string
created time.Time created time.Time
@@ -70,15 +71,15 @@ func NewMapper(
db *db.HSDatabase, db *db.HSDatabase,
cfg *types.Config, cfg *types.Config,
derpMap *tailcfg.DERPMap, derpMap *tailcfg.DERPMap,
isLikelyConnected types.NodeConnectedMap, notif *notifier.Notifier,
) *Mapper { ) *Mapper {
uid, _ := util.GenerateRandomStringDNSSafe(mapperIDLength) uid, _ := util.GenerateRandomStringDNSSafe(mapperIDLength)
return &Mapper{ return &Mapper{
db: db, db: db,
cfg: cfg, cfg: cfg,
derpMap: derpMap, derpMap: derpMap,
isLikelyConnected: isLikelyConnected, notif: notif,
uid: uid, uid: uid,
created: time.Now(), created: time.Now(),
@@ -517,7 +518,7 @@ func (m *Mapper) ListPeers(nodeID types.NodeID) (types.Nodes, error) {
} }
for _, peer := range peers { for _, peer := range peers {
online := m.isLikelyConnected[peer.ID] online := m.notif.IsLikelyConnected(peer.ID)
peer.IsOnline = &online peer.IsOnline = &online
} }

View File

@@ -1,6 +1,10 @@
package hscontrol package hscontrol
import ( import (
"net/http"
"strconv"
"github.com/gorilla/mux"
"github.com/prometheus/client_golang/prometheus" "github.com/prometheus/client_golang/prometheus"
"github.com/prometheus/client_golang/prometheus/promauto" "github.com/prometheus/client_golang/prometheus/promauto"
) )
@@ -8,18 +12,94 @@ import (
const prometheusNamespace = "headscale" const prometheusNamespace = "headscale"
var ( var (
// This is a high cardinality metric (user x node), we might want to make this mapResponseSent = promauto.NewCounterVec(prometheus.CounterOpts{
// configurable/opt-in in the future.
nodeRegistrations = promauto.NewCounterVec(prometheus.CounterOpts{
Namespace: prometheusNamespace, Namespace: prometheusNamespace,
Name: "node_registrations_total", Name: "mapresponse_sent_total",
Help: "The total amount of registered node attempts", Help: "total count of mapresponses sent to clients",
}, []string{"action", "auth", "status", "user"}) }, []string{"status", "type"})
mapResponseUpdateReceived = promauto.NewCounterVec(prometheus.CounterOpts{
updateRequestsSentToNode = promauto.NewCounterVec(prometheus.CounterOpts{
Namespace: prometheusNamespace, Namespace: prometheusNamespace,
Name: "update_request_sent_to_node_total", Name: "mapresponse_updates_received_total",
Help: "The number of calls/messages issued on a specific nodes update channel", Help: "total count of mapresponse updates received on update channel",
}, []string{"user", "node", "status"}) }, []string{"type"})
// TODO(kradalby): This is very debugging, we might want to remove it. mapResponseWriteUpdatesInStream = promauto.NewCounterVec(prometheus.CounterOpts{
Namespace: prometheusNamespace,
Name: "mapresponse_write_updates_in_stream_total",
Help: "total count of writes that occured in a stream session, pre-68 nodes",
}, []string{"status"})
mapResponseEndpointUpdates = promauto.NewCounterVec(prometheus.CounterOpts{
Namespace: prometheusNamespace,
Name: "mapresponse_endpoint_updates_total",
Help: "total count of endpoint updates received",
}, []string{"status"})
mapResponseReadOnly = promauto.NewCounterVec(prometheus.CounterOpts{
Namespace: prometheusNamespace,
Name: "mapresponse_readonly_requests_total",
Help: "total count of readonly requests received",
}, []string{"status"})
mapResponseSessions = promauto.NewGauge(prometheus.GaugeOpts{
Namespace: prometheusNamespace,
Name: "mapresponse_current_sessions_total",
Help: "total count open map response sessions",
})
mapResponseRejected = promauto.NewCounterVec(prometheus.CounterOpts{
Namespace: prometheusNamespace,
Name: "mapresponse_rejected_new_sessions_total",
Help: "total count of new mapsessions rejected",
}, []string{"reason"})
httpDuration = promauto.NewHistogramVec(prometheus.HistogramOpts{
Namespace: prometheusNamespace,
Name: "http_duration_seconds",
Help: "Duration of HTTP requests.",
}, []string{"path"})
httpCounter = promauto.NewCounterVec(prometheus.CounterOpts{
Namespace: prometheusNamespace,
Name: "http_requests_total",
Help: "Total number of http requests processed",
}, []string{"code", "method", "path"},
)
) )
// prometheusMiddleware implements mux.MiddlewareFunc.
func prometheusMiddleware(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
route := mux.CurrentRoute(r)
path, _ := route.GetPathTemplate()
// Ignore streaming and noise sessions
// it has its own router further down.
if path == "/ts2021" || path == "/machine/map" || path == "/derp" || path == "/derp/probe" || path == "/bootstrap-dns" {
next.ServeHTTP(w, r)
return
}
rw := &respWriterProm{ResponseWriter: w}
timer := prometheus.NewTimer(httpDuration.WithLabelValues(path))
next.ServeHTTP(rw, r)
timer.ObserveDuration()
httpCounter.WithLabelValues(strconv.Itoa(rw.status), r.Method, path).Inc()
})
}
type respWriterProm struct {
http.ResponseWriter
status int
written int64
wroteHeader bool
}
func (r *respWriterProm) WriteHeader(code int) {
r.status = code
r.wroteHeader = true
r.ResponseWriter.WriteHeader(code)
}
func (r *respWriterProm) Write(b []byte) (int, error) {
if !r.wroteHeader {
r.WriteHeader(http.StatusOK)
}
n, err := r.ResponseWriter.Write(b)
r.written += int64(n)
return n, err
}

View File

@@ -95,6 +95,7 @@ func (h *Headscale) NoiseUpgradeHandler(
// The HTTP2 server that exposes this router is created for // The HTTP2 server that exposes this router is created for
// a single hijacked connection from /ts2021, using netutil.NewOneConnListener // a single hijacked connection from /ts2021, using netutil.NewOneConnListener
router := mux.NewRouter() router := mux.NewRouter()
router.Use(prometheusMiddleware)
router.HandleFunc("/machine/register", noiseServer.NoiseRegistrationHandler). router.HandleFunc("/machine/register", noiseServer.NoiseRegistrationHandler).
Methods(http.MethodPost) Methods(http.MethodPost)
@@ -267,10 +268,12 @@ func (ns *noiseServer) NoisePollNetMapHandler(
defer ns.headscale.mapSessionMu.Unlock() defer ns.headscale.mapSessionMu.Unlock()
sess.infof("node has an open stream(%p), rejecting new stream", sess) sess.infof("node has an open stream(%p), rejecting new stream", sess)
mapResponseRejected.WithLabelValues("exists").Inc()
return return
} }
ns.headscale.mapSessions[node.ID] = sess ns.headscale.mapSessions[node.ID] = sess
mapResponseSessions.Inc()
ns.headscale.mapSessionMu.Unlock() ns.headscale.mapSessionMu.Unlock()
sess.tracef("releasing lock to check stream") sess.tracef("releasing lock to check stream")
} }
@@ -283,6 +286,7 @@ func (ns *noiseServer) NoisePollNetMapHandler(
defer ns.headscale.mapSessionMu.Unlock() defer ns.headscale.mapSessionMu.Unlock()
delete(ns.headscale.mapSessions, node.ID) delete(ns.headscale.mapSessions, node.ID)
mapResponseSessions.Dec()
sess.tracef("releasing lock to remove stream") sess.tracef("releasing lock to remove stream")
} }

View File

@@ -0,0 +1,27 @@
package notifier
import (
"github.com/prometheus/client_golang/prometheus"
"github.com/prometheus/client_golang/prometheus/promauto"
)
const prometheusNamespace = "headscale"
var (
notifierWaitForLock = promauto.NewHistogramVec(prometheus.HistogramOpts{
Namespace: prometheusNamespace,
Name: "notifier_wait_for_lock_seconds",
Help: "histogram of time spent waiting for the notifier lock",
Buckets: []float64{0.001, 0.01, 0.1, 0.3, 0.5, 1, 3, 5, 10},
}, []string{"action"})
notifierUpdateSent = promauto.NewCounterVec(prometheus.CounterOpts{
Namespace: prometheusNamespace,
Name: "notifier_update_sent_total",
Help: "total count of update sent on nodes channel",
}, []string{"status", "type"})
notifierNodeUpdateChans = promauto.NewGauge(prometheus.GaugeOpts{
Namespace: prometheusNamespace,
Name: "notifier_open_channels_total",
Help: "total count open channels in notifier",
})
)

View File

@@ -6,21 +6,23 @@ import (
"slices" "slices"
"strings" "strings"
"sync" "sync"
"time"
"github.com/juanfont/headscale/hscontrol/types" "github.com/juanfont/headscale/hscontrol/types"
"github.com/puzpuzpuz/xsync/v3"
"github.com/rs/zerolog/log" "github.com/rs/zerolog/log"
) )
type Notifier struct { type Notifier struct {
l sync.RWMutex l sync.RWMutex
nodes map[types.NodeID]chan<- types.StateUpdate nodes map[types.NodeID]chan<- types.StateUpdate
connected types.NodeConnectedMap connected *xsync.MapOf[types.NodeID, bool]
} }
func NewNotifier() *Notifier { func NewNotifier() *Notifier {
return &Notifier{ return &Notifier{
nodes: make(map[types.NodeID]chan<- types.StateUpdate), nodes: make(map[types.NodeID]chan<- types.StateUpdate),
connected: make(types.NodeConnectedMap), connected: xsync.NewMapOf[types.NodeID, bool](),
} }
} }
@@ -31,16 +33,19 @@ func (n *Notifier) AddNode(nodeID types.NodeID, c chan<- types.StateUpdate) {
Uint64("node.id", nodeID.Uint64()). Uint64("node.id", nodeID.Uint64()).
Msg("releasing lock to add node") Msg("releasing lock to add node")
start := time.Now()
n.l.Lock() n.l.Lock()
defer n.l.Unlock() defer n.l.Unlock()
notifierWaitForLock.WithLabelValues("add").Observe(time.Since(start).Seconds())
n.nodes[nodeID] = c n.nodes[nodeID] = c
n.connected[nodeID] = true n.connected.Store(nodeID, true)
log.Trace(). log.Trace().
Uint64("node.id", nodeID.Uint64()). Uint64("node.id", nodeID.Uint64()).
Int("open_chans", len(n.nodes)). Int("open_chans", len(n.nodes)).
Msg("Added new channel") Msg("Added new channel")
notifierNodeUpdateChans.Inc()
} }
func (n *Notifier) RemoveNode(nodeID types.NodeID) { func (n *Notifier) RemoveNode(nodeID types.NodeID) {
@@ -50,20 +55,23 @@ func (n *Notifier) RemoveNode(nodeID types.NodeID) {
Uint64("node.id", nodeID.Uint64()). Uint64("node.id", nodeID.Uint64()).
Msg("releasing lock to remove node") Msg("releasing lock to remove node")
start := time.Now()
n.l.Lock() n.l.Lock()
defer n.l.Unlock() defer n.l.Unlock()
notifierWaitForLock.WithLabelValues("remove").Observe(time.Since(start).Seconds())
if len(n.nodes) == 0 { if len(n.nodes) == 0 {
return return
} }
delete(n.nodes, nodeID) delete(n.nodes, nodeID)
n.connected[nodeID] = false n.connected.Store(nodeID, false)
log.Trace(). log.Trace().
Uint64("node.id", nodeID.Uint64()). Uint64("node.id", nodeID.Uint64()).
Int("open_chans", len(n.nodes)). Int("open_chans", len(n.nodes)).
Msg("Removed channel") Msg("Removed channel")
notifierNodeUpdateChans.Dec()
} }
// IsConnected reports if a node is connected to headscale and has a // IsConnected reports if a node is connected to headscale and has a
@@ -72,17 +80,22 @@ func (n *Notifier) IsConnected(nodeID types.NodeID) bool {
n.l.RLock() n.l.RLock()
defer n.l.RUnlock() defer n.l.RUnlock()
return n.connected[nodeID] if val, ok := n.connected.Load(nodeID); ok {
return val
}
return false
} }
// IsLikelyConnected reports if a node is connected to headscale and has a // IsLikelyConnected reports if a node is connected to headscale and has a
// poll session open, but doesnt lock, so might be wrong. // poll session open, but doesnt lock, so might be wrong.
func (n *Notifier) IsLikelyConnected(nodeID types.NodeID) bool { func (n *Notifier) IsLikelyConnected(nodeID types.NodeID) bool {
return n.connected[nodeID] if val, ok := n.connected.Load(nodeID); ok {
return val
}
return false
} }
// TODO(kradalby): This returns a pointer and can be dangerous. func (n *Notifier) LikelyConnectedMap() *xsync.MapOf[types.NodeID, bool] {
func (n *Notifier) ConnectedMap() types.NodeConnectedMap {
return n.connected return n.connected
} }
@@ -95,45 +108,16 @@ func (n *Notifier) NotifyWithIgnore(
update types.StateUpdate, update types.StateUpdate,
ignoreNodeIDs ...types.NodeID, ignoreNodeIDs ...types.NodeID,
) { ) {
log.Trace().Caller().Str("type", update.Type.String()).Msg("acquiring lock to notify") for nodeID := range n.nodes {
defer log.Trace().
Caller().
Str("type", update.Type.String()).
Msg("releasing lock, finished notifying")
n.l.RLock()
defer n.l.RUnlock()
if update.Type == types.StatePeerChangedPatch {
log.Trace().Interface("update", update).Interface("online", n.connected).Msg("PATCH UPDATE SENT")
}
for nodeID, c := range n.nodes {
if slices.Contains(ignoreNodeIDs, nodeID) { if slices.Contains(ignoreNodeIDs, nodeID) {
continue continue
} }
select { n.NotifyByNodeID(ctx, update, nodeID)
case <-ctx.Done():
log.Error().
Err(ctx.Err()).
Uint64("node.id", nodeID.Uint64()).
Any("origin", ctx.Value("origin")).
Any("origin-hostname", ctx.Value("hostname")).
Msgf("update not sent, context cancelled")
return
case c <- update:
log.Trace().
Uint64("node.id", nodeID.Uint64()).
Any("origin", ctx.Value("origin")).
Any("origin-hostname", ctx.Value("hostname")).
Msgf("update successfully sent on chan")
}
} }
} }
func (n *Notifier) NotifyByMachineKey( func (n *Notifier) NotifyByNodeID(
ctx context.Context, ctx context.Context,
update types.StateUpdate, update types.StateUpdate,
nodeID types.NodeID, nodeID types.NodeID,
@@ -144,8 +128,10 @@ func (n *Notifier) NotifyByMachineKey(
Str("type", update.Type.String()). Str("type", update.Type.String()).
Msg("releasing lock, finished notifying") Msg("releasing lock, finished notifying")
start := time.Now()
n.l.RLock() n.l.RLock()
defer n.l.RUnlock() defer n.l.RUnlock()
notifierWaitForLock.WithLabelValues("notify").Observe(time.Since(start).Seconds())
if c, ok := n.nodes[nodeID]; ok { if c, ok := n.nodes[nodeID]; ok {
select { select {
@@ -156,6 +142,7 @@ func (n *Notifier) NotifyByMachineKey(
Any("origin", ctx.Value("origin")). Any("origin", ctx.Value("origin")).
Any("origin-hostname", ctx.Value("hostname")). Any("origin-hostname", ctx.Value("hostname")).
Msgf("update not sent, context cancelled") Msgf("update not sent, context cancelled")
notifierUpdateSent.WithLabelValues("cancelled", update.Type.String()).Inc()
return return
case c <- update: case c <- update:
@@ -164,6 +151,7 @@ func (n *Notifier) NotifyByMachineKey(
Any("origin", ctx.Value("origin")). Any("origin", ctx.Value("origin")).
Any("origin-hostname", ctx.Value("hostname")). Any("origin-hostname", ctx.Value("hostname")).
Msgf("update successfully sent on chan") Msgf("update successfully sent on chan")
notifierUpdateSent.WithLabelValues("ok", update.Type.String()).Inc()
} }
} }
} }
@@ -182,9 +170,10 @@ func (n *Notifier) String() string {
b.WriteString("\n") b.WriteString("\n")
b.WriteString("connected:\n") b.WriteString("connected:\n")
for k, v := range n.connected { n.connected.Range(func(k types.NodeID, v bool) bool {
fmt.Fprintf(&b, "\t%d: %t\n", k, v) fmt.Fprintf(&b, "\t%d: %t\n", k, v)
} return true
})
return b.String() return b.String()
} }

View File

@@ -602,7 +602,7 @@ func (h *Headscale) registerNodeForOIDCCallback(
return err return err
} }
if err := h.db.DB.Transaction(func(tx *gorm.DB) error { if err := h.db.Write(func(tx *gorm.DB) error {
if _, err := db.RegisterNodeFromAuthCallback( if _, err := db.RegisterNodeFromAuthCallback(
// TODO(kradalby): find a better way to use the cache across modules // TODO(kradalby): find a better way to use the cache across modules
tx, tx,

View File

@@ -64,7 +64,7 @@ func (h *Headscale) newMapSession(
w http.ResponseWriter, w http.ResponseWriter,
node *types.Node, node *types.Node,
) *mapSession { ) *mapSession {
warnf, tracef, infof, errf := logPollFunc(req, node) warnf, infof, tracef, errf := logPollFunc(req, node)
// Use a buffered channel in case a node is not fully ready // Use a buffered channel in case a node is not fully ready
// to receive a message to make sure we dont block the entire // to receive a message to make sure we dont block the entire
@@ -196,8 +196,10 @@ func (m *mapSession) serve() {
// return // return
err := m.handleSaveNode() err := m.handleSaveNode()
if err != nil { if err != nil {
mapResponseWriteUpdatesInStream.WithLabelValues("error").Inc()
return return
} }
mapResponseWriteUpdatesInStream.WithLabelValues("ok").Inc()
} }
// Set up the client stream // Set up the client stream
@@ -284,6 +286,7 @@ func (m *mapSession) serve() {
patches = filteredPatches patches = filteredPatches
} }
updateType := "full"
// When deciding what update to send, the following is considered, // When deciding what update to send, the following is considered,
// Full is a superset of all updates, when a full update is requested, // Full is a superset of all updates, when a full update is requested,
// send only that and move on, all other updates will be present in // send only that and move on, all other updates will be present in
@@ -303,12 +306,15 @@ func (m *mapSession) serve() {
} else if changed != nil { } else if changed != nil {
m.tracef(fmt.Sprintf("Sending Changed MapResponse: %v", lastMessage)) m.tracef(fmt.Sprintf("Sending Changed MapResponse: %v", lastMessage))
data, err = m.mapper.PeerChangedResponse(m.req, m.node, changed, patches, m.h.ACLPolicy, lastMessage) data, err = m.mapper.PeerChangedResponse(m.req, m.node, changed, patches, m.h.ACLPolicy, lastMessage)
updateType = "change"
} else if patches != nil { } else if patches != nil {
m.tracef(fmt.Sprintf("Sending Changed Patch MapResponse: %v", lastMessage)) m.tracef(fmt.Sprintf("Sending Changed Patch MapResponse: %v", lastMessage))
data, err = m.mapper.PeerChangedPatchResponse(m.req, m.node, patches, m.h.ACLPolicy) data, err = m.mapper.PeerChangedPatchResponse(m.req, m.node, patches, m.h.ACLPolicy)
updateType = "patch"
} else if derp { } else if derp {
m.tracef("Sending DERPUpdate MapResponse") m.tracef("Sending DERPUpdate MapResponse")
data, err = m.mapper.DERPMapResponse(m.req, m.node, m.h.DERPMap) data, err = m.mapper.DERPMapResponse(m.req, m.node, m.h.DERPMap)
updateType = "derp"
} }
if err != nil { if err != nil {
@@ -324,19 +330,22 @@ func (m *mapSession) serve() {
startWrite := time.Now() startWrite := time.Now()
_, err = m.w.Write(data) _, err = m.w.Write(data)
if err != nil { if err != nil {
mapResponseSent.WithLabelValues("error", updateType).Inc()
m.errf(err, "Could not write the map response, for mapSession: %p", m) m.errf(err, "Could not write the map response, for mapSession: %p", m)
return return
} }
err = rc.Flush() err = rc.Flush()
if err != nil { if err != nil {
mapResponseSent.WithLabelValues("error", updateType).Inc()
m.errf(err, "flushing the map response to client, for mapSession: %p", m) m.errf(err, "flushing the map response to client, for mapSession: %p", m)
return return
} }
log.Trace().Str("node", m.node.Hostname).TimeDiff("timeSpent", time.Now(), startWrite).Str("mkey", m.node.MachineKey.String()).Msg("finished writing mapresp to node") log.Trace().Str("node", m.node.Hostname).TimeDiff("timeSpent", time.Now(), startWrite).Str("mkey", m.node.MachineKey.String()).Msg("finished writing mapresp to node")
m.infof("update sent") mapResponseSent.WithLabelValues("ok", updateType).Inc()
m.tracef("update sent")
} }
// reset // reset
@@ -364,7 +373,8 @@ func (m *mapSession) serve() {
// Consume all updates sent to node // Consume all updates sent to node
case update := <-m.ch: case update := <-m.ch:
m.tracef("received stream update: %d %s", update.Type, update.Message) m.tracef("received stream update: %s %s", update.Type.String(), update.Message)
mapResponseUpdateReceived.WithLabelValues(update.Type.String()).Inc()
switch update.Type { switch update.Type {
case types.StateFullUpdate: case types.StateFullUpdate:
@@ -404,27 +414,30 @@ func (m *mapSession) serve() {
data, err := m.mapper.KeepAliveResponse(m.req, m.node) data, err := m.mapper.KeepAliveResponse(m.req, m.node)
if err != nil { if err != nil {
m.errf(err, "Error generating the keep alive msg") m.errf(err, "Error generating the keep alive msg")
mapResponseSent.WithLabelValues("error", "keepalive").Inc()
return return
} }
_, err = m.w.Write(data) _, err = m.w.Write(data)
if err != nil { if err != nil {
m.errf(err, "Cannot write keep alive message") m.errf(err, "Cannot write keep alive message")
mapResponseSent.WithLabelValues("error", "keepalive").Inc()
return return
} }
err = rc.Flush() err = rc.Flush()
if err != nil { if err != nil {
m.errf(err, "flushing keep alive to client, for mapSession: %p", m) m.errf(err, "flushing keep alive to client, for mapSession: %p", m)
mapResponseSent.WithLabelValues("error", "keepalive").Inc()
return return
} }
mapResponseSent.WithLabelValues("ok", "keepalive").Inc()
} }
} }
} }
func (m *mapSession) pollFailoverRoutes(where string, node *types.Node) { func (m *mapSession) pollFailoverRoutes(where string, node *types.Node) {
update, err := db.Write(m.h.db.DB, func(tx *gorm.DB) (*types.StateUpdate, error) { update, err := db.Write(m.h.db.DB, func(tx *gorm.DB) (*types.StateUpdate, error) {
return db.FailoverNodeRoutesIfNeccessary(tx, m.h.nodeNotifier.ConnectedMap(), node) return db.FailoverNodeRoutesIfNeccessary(tx, m.h.nodeNotifier.LikelyConnectedMap(), node)
}) })
if err != nil { if err != nil {
m.errf(err, fmt.Sprintf("failed to ensure failover routes, %s", where)) m.errf(err, fmt.Sprintf("failed to ensure failover routes, %s", where))
@@ -454,7 +467,7 @@ func (h *Headscale) updateNodeOnlineStatus(online bool, node *types.Node) {
node.LastSeen = &now node.LastSeen = &now
change.LastSeen = &now change.LastSeen = &now
err := h.db.DB.Transaction(func(tx *gorm.DB) error { err := h.db.Write(func(tx *gorm.DB) error {
return db.SetLastSeen(tx, node.ID, *node.LastSeen) return db.SetLastSeen(tx, node.ID, *node.LastSeen)
}) })
if err != nil { if err != nil {
@@ -501,6 +514,7 @@ func (m *mapSession) handleEndpointUpdate() {
// If there is no changes and nothing to save, // If there is no changes and nothing to save,
// return early. // return early.
if peerChangeEmpty(change) && !sendUpdate { if peerChangeEmpty(change) && !sendUpdate {
mapResponseEndpointUpdates.WithLabelValues("noop").Inc()
return return
} }
@@ -518,6 +532,7 @@ func (m *mapSession) handleEndpointUpdate() {
if err != nil { if err != nil {
m.errf(err, "Error processing node routes") m.errf(err, "Error processing node routes")
http.Error(m.w, "", http.StatusInternalServerError) http.Error(m.w, "", http.StatusInternalServerError)
mapResponseEndpointUpdates.WithLabelValues("error").Inc()
return return
} }
@@ -527,6 +542,7 @@ func (m *mapSession) handleEndpointUpdate() {
err := m.h.db.EnableAutoApprovedRoutes(m.h.ACLPolicy, m.node) err := m.h.db.EnableAutoApprovedRoutes(m.h.ACLPolicy, m.node)
if err != nil { if err != nil {
m.errf(err, "Error running auto approved routes") m.errf(err, "Error running auto approved routes")
mapResponseEndpointUpdates.WithLabelValues("error").Inc()
} }
} }
@@ -534,19 +550,19 @@ func (m *mapSession) handleEndpointUpdate() {
// has an updated packetfilter allowing the new route // has an updated packetfilter allowing the new route
// if it is defined in the ACL. // if it is defined in the ACL.
ctx := types.NotifyCtx(context.Background(), "poll-nodeupdate-self-hostinfochange", m.node.Hostname) ctx := types.NotifyCtx(context.Background(), "poll-nodeupdate-self-hostinfochange", m.node.Hostname)
m.h.nodeNotifier.NotifyByMachineKey( m.h.nodeNotifier.NotifyByNodeID(
ctx, ctx,
types.StateUpdate{ types.StateUpdate{
Type: types.StateSelfUpdate, Type: types.StateSelfUpdate,
ChangeNodes: []types.NodeID{m.node.ID}, ChangeNodes: []types.NodeID{m.node.ID},
}, },
m.node.ID) m.node.ID)
} }
if err := m.h.db.DB.Save(m.node).Error; err != nil { if err := m.h.db.DB.Save(m.node).Error; err != nil {
m.errf(err, "Failed to persist/update node in the database") m.errf(err, "Failed to persist/update node in the database")
http.Error(m.w, "", http.StatusInternalServerError) http.Error(m.w, "", http.StatusInternalServerError)
mapResponseEndpointUpdates.WithLabelValues("error").Inc()
return return
} }
@@ -562,6 +578,7 @@ func (m *mapSession) handleEndpointUpdate() {
m.node.ID) m.node.ID)
m.w.WriteHeader(http.StatusOK) m.w.WriteHeader(http.StatusOK)
mapResponseEndpointUpdates.WithLabelValues("ok").Inc()
return return
} }
@@ -639,7 +656,7 @@ func (m *mapSession) handleReadOnlyRequest() {
if err != nil { if err != nil {
m.errf(err, "Failed to create MapResponse") m.errf(err, "Failed to create MapResponse")
http.Error(m.w, "", http.StatusInternalServerError) http.Error(m.w, "", http.StatusInternalServerError)
mapResponseReadOnly.WithLabelValues("error").Inc()
return return
} }
@@ -648,9 +665,12 @@ func (m *mapSession) handleReadOnlyRequest() {
_, err = m.w.Write(mapResp) _, err = m.w.Write(mapResp)
if err != nil { if err != nil {
m.errf(err, "Failed to write response") m.errf(err, "Failed to write response")
mapResponseReadOnly.WithLabelValues("error").Inc()
return
} }
m.w.WriteHeader(http.StatusOK) m.w.WriteHeader(http.StatusOK)
mapResponseReadOnly.WithLabelValues("ok").Inc()
return return
} }

View File

@@ -46,7 +46,6 @@ type Config struct {
GRPCAddr string GRPCAddr string
GRPCAllowInsecure bool GRPCAllowInsecure bool
EphemeralNodeInactivityTimeout time.Duration EphemeralNodeInactivityTimeout time.Duration
NodeUpdateCheckInterval time.Duration
PrefixV4 *netip.Prefix PrefixV4 *netip.Prefix
PrefixV6 *netip.Prefix PrefixV6 *netip.Prefix
IPAllocation IPAllocationStrategy IPAllocation IPAllocationStrategy
@@ -233,12 +232,10 @@ func LoadConfig(path string, isFile bool) error {
viper.SetDefault("ephemeral_node_inactivity_timeout", "120s") viper.SetDefault("ephemeral_node_inactivity_timeout", "120s")
viper.SetDefault("node_update_check_interval", "10s")
viper.SetDefault("tuning.batch_change_delay", "800ms") viper.SetDefault("tuning.batch_change_delay", "800ms")
viper.SetDefault("tuning.node_mapsession_buffered_chan_size", 30) viper.SetDefault("tuning.node_mapsession_buffered_chan_size", 30)
viper.SetDefault("prefixes.allocation", IPAllocationStrategySequential) viper.SetDefault("prefixes.allocation", string(IPAllocationStrategySequential))
if IsCLIConfigured() { if IsCLIConfigured() {
return nil return nil
@@ -290,15 +287,6 @@ func LoadConfig(path string, isFile bool) error {
) )
} }
maxNodeUpdateCheckInterval, _ := time.ParseDuration("60s")
if viper.GetDuration("node_update_check_interval") > maxNodeUpdateCheckInterval {
errorText += fmt.Sprintf(
"Fatal config error: node_update_check_interval (%s) is set too high, must be less than %s",
viper.GetString("node_update_check_interval"),
maxNodeUpdateCheckInterval,
)
}
if errorText != "" { if errorText != "" {
// nolint // nolint
return errors.New(strings.TrimSuffix(errorText, "\n")) return errors.New(strings.TrimSuffix(errorText, "\n"))
@@ -714,10 +702,6 @@ func GetHeadscaleConfig() (*Config, error) {
"ephemeral_node_inactivity_timeout", "ephemeral_node_inactivity_timeout",
), ),
NodeUpdateCheckInterval: viper.GetDuration(
"node_update_check_interval",
),
Database: GetDatabaseConfig(), Database: GetDatabaseConfig(),
TLS: GetTLSConfig(), TLS: GetTLSConfig(),

View File

@@ -28,7 +28,8 @@ var (
) )
type NodeID uint64 type NodeID uint64
type NodeConnectedMap map[NodeID]bool
// type NodeConnectedMap *xsync.MapOf[NodeID, bool]
func (id NodeID) StableID() tailcfg.StableNodeID { func (id NodeID) StableID() tailcfg.StableNodeID {
return tailcfg.StableNodeID(strconv.FormatUint(uint64(id), util.Base10)) return tailcfg.StableNodeID(strconv.FormatUint(uint64(id), util.Base10))

View File

@@ -51,7 +51,7 @@ func aclScenario(
clientsPerUser int, clientsPerUser int,
) *Scenario { ) *Scenario {
t.Helper() t.Helper()
scenario, err := NewScenario() scenario, err := NewScenario(dockertestMaxWait())
assertNoErr(t, err) assertNoErr(t, err)
spec := map[string]int{ spec := map[string]int{
@@ -264,7 +264,7 @@ func TestACLHostsInNetMapTable(t *testing.T) {
for name, testCase := range tests { for name, testCase := range tests {
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
scenario, err := NewScenario() scenario, err := NewScenario(dockertestMaxWait())
assertNoErr(t, err) assertNoErr(t, err)
spec := testCase.users spec := testCase.users

View File

@@ -42,7 +42,7 @@ func TestOIDCAuthenticationPingAll(t *testing.T) {
IntegrationSkip(t) IntegrationSkip(t)
t.Parallel() t.Parallel()
baseScenario, err := NewScenario() baseScenario, err := NewScenario(dockertestMaxWait())
assertNoErr(t, err) assertNoErr(t, err)
scenario := AuthOIDCScenario{ scenario := AuthOIDCScenario{
@@ -100,7 +100,7 @@ func TestOIDCExpireNodesBasedOnTokenExpiry(t *testing.T) {
shortAccessTTL := 5 * time.Minute shortAccessTTL := 5 * time.Minute
baseScenario, err := NewScenario() baseScenario, err := NewScenario(dockertestMaxWait())
assertNoErr(t, err) assertNoErr(t, err)
baseScenario.pool.MaxWait = 5 * time.Minute baseScenario.pool.MaxWait = 5 * time.Minute

View File

@@ -26,7 +26,7 @@ func TestAuthWebFlowAuthenticationPingAll(t *testing.T) {
IntegrationSkip(t) IntegrationSkip(t)
t.Parallel() t.Parallel()
baseScenario, err := NewScenario() baseScenario, err := NewScenario(dockertestMaxWait())
if err != nil { if err != nil {
t.Fatalf("failed to create scenario: %s", err) t.Fatalf("failed to create scenario: %s", err)
} }
@@ -67,7 +67,7 @@ func TestAuthWebFlowLogoutAndRelogin(t *testing.T) {
IntegrationSkip(t) IntegrationSkip(t)
t.Parallel() t.Parallel()
baseScenario, err := NewScenario() baseScenario, err := NewScenario(dockertestMaxWait())
assertNoErr(t, err) assertNoErr(t, err)
scenario := AuthWebFlowScenario{ scenario := AuthWebFlowScenario{

View File

@@ -32,7 +32,7 @@ func TestUserCommand(t *testing.T) {
IntegrationSkip(t) IntegrationSkip(t)
t.Parallel() t.Parallel()
scenario, err := NewScenario() scenario, err := NewScenario(dockertestMaxWait())
assertNoErr(t, err) assertNoErr(t, err)
defer scenario.Shutdown() defer scenario.Shutdown()
@@ -112,7 +112,7 @@ func TestPreAuthKeyCommand(t *testing.T) {
user := "preauthkeyspace" user := "preauthkeyspace"
count := 3 count := 3
scenario, err := NewScenario() scenario, err := NewScenario(dockertestMaxWait())
assertNoErr(t, err) assertNoErr(t, err)
defer scenario.Shutdown() defer scenario.Shutdown()
@@ -254,7 +254,7 @@ func TestPreAuthKeyCommandWithoutExpiry(t *testing.T) {
user := "pre-auth-key-without-exp-user" user := "pre-auth-key-without-exp-user"
scenario, err := NewScenario() scenario, err := NewScenario(dockertestMaxWait())
assertNoErr(t, err) assertNoErr(t, err)
defer scenario.Shutdown() defer scenario.Shutdown()
@@ -317,7 +317,7 @@ func TestPreAuthKeyCommandReusableEphemeral(t *testing.T) {
user := "pre-auth-key-reus-ephm-user" user := "pre-auth-key-reus-ephm-user"
scenario, err := NewScenario() scenario, err := NewScenario(dockertestMaxWait())
assertNoErr(t, err) assertNoErr(t, err)
defer scenario.Shutdown() defer scenario.Shutdown()
@@ -394,7 +394,7 @@ func TestApiKeyCommand(t *testing.T) {
count := 5 count := 5
scenario, err := NewScenario() scenario, err := NewScenario(dockertestMaxWait())
assertNoErr(t, err) assertNoErr(t, err)
defer scenario.Shutdown() defer scenario.Shutdown()
@@ -562,7 +562,7 @@ func TestNodeTagCommand(t *testing.T) {
IntegrationSkip(t) IntegrationSkip(t)
t.Parallel() t.Parallel()
scenario, err := NewScenario() scenario, err := NewScenario(dockertestMaxWait())
assertNoErr(t, err) assertNoErr(t, err)
defer scenario.Shutdown() defer scenario.Shutdown()
@@ -695,7 +695,7 @@ func TestNodeAdvertiseTagNoACLCommand(t *testing.T) {
IntegrationSkip(t) IntegrationSkip(t)
t.Parallel() t.Parallel()
scenario, err := NewScenario() scenario, err := NewScenario(dockertestMaxWait())
assertNoErr(t, err) assertNoErr(t, err)
defer scenario.Shutdown() defer scenario.Shutdown()
@@ -745,7 +745,7 @@ func TestNodeAdvertiseTagWithACLCommand(t *testing.T) {
IntegrationSkip(t) IntegrationSkip(t)
t.Parallel() t.Parallel()
scenario, err := NewScenario() scenario, err := NewScenario(dockertestMaxWait())
assertNoErr(t, err) assertNoErr(t, err)
defer scenario.Shutdown() defer scenario.Shutdown()
@@ -808,7 +808,7 @@ func TestNodeCommand(t *testing.T) {
IntegrationSkip(t) IntegrationSkip(t)
t.Parallel() t.Parallel()
scenario, err := NewScenario() scenario, err := NewScenario(dockertestMaxWait())
assertNoErr(t, err) assertNoErr(t, err)
defer scenario.Shutdown() defer scenario.Shutdown()
@@ -1049,7 +1049,7 @@ func TestNodeExpireCommand(t *testing.T) {
IntegrationSkip(t) IntegrationSkip(t)
t.Parallel() t.Parallel()
scenario, err := NewScenario() scenario, err := NewScenario(dockertestMaxWait())
assertNoErr(t, err) assertNoErr(t, err)
defer scenario.Shutdown() defer scenario.Shutdown()
@@ -1176,7 +1176,7 @@ func TestNodeRenameCommand(t *testing.T) {
IntegrationSkip(t) IntegrationSkip(t)
t.Parallel() t.Parallel()
scenario, err := NewScenario() scenario, err := NewScenario(dockertestMaxWait())
assertNoErr(t, err) assertNoErr(t, err)
defer scenario.Shutdown() defer scenario.Shutdown()
@@ -1343,7 +1343,7 @@ func TestNodeMoveCommand(t *testing.T) {
IntegrationSkip(t) IntegrationSkip(t)
t.Parallel() t.Parallel()
scenario, err := NewScenario() scenario, err := NewScenario(dockertestMaxWait())
assertNoErr(t, err) assertNoErr(t, err)
defer scenario.Shutdown() defer scenario.Shutdown()

View File

@@ -23,7 +23,7 @@ func TestDERPServerScenario(t *testing.T) {
IntegrationSkip(t) IntegrationSkip(t)
// t.Parallel() // t.Parallel()
baseScenario, err := NewScenario() baseScenario, err := NewScenario(dockertestMaxWait())
assertNoErr(t, err) assertNoErr(t, err)
scenario := EmbeddedDERPServerScenario{ scenario := EmbeddedDERPServerScenario{

View File

@@ -23,7 +23,7 @@ func TestPingAllByIP(t *testing.T) {
IntegrationSkip(t) IntegrationSkip(t)
t.Parallel() t.Parallel()
scenario, err := NewScenario() scenario, err := NewScenario(dockertestMaxWait())
assertNoErr(t, err) assertNoErr(t, err)
defer scenario.Shutdown() defer scenario.Shutdown()
@@ -67,7 +67,7 @@ func TestPingAllByIPPublicDERP(t *testing.T) {
IntegrationSkip(t) IntegrationSkip(t)
t.Parallel() t.Parallel()
scenario, err := NewScenario() scenario, err := NewScenario(dockertestMaxWait())
assertNoErr(t, err) assertNoErr(t, err)
defer scenario.Shutdown() defer scenario.Shutdown()
@@ -105,7 +105,7 @@ func TestAuthKeyLogoutAndRelogin(t *testing.T) {
IntegrationSkip(t) IntegrationSkip(t)
t.Parallel() t.Parallel()
scenario, err := NewScenario() scenario, err := NewScenario(dockertestMaxWait())
assertNoErr(t, err) assertNoErr(t, err)
defer scenario.Shutdown() defer scenario.Shutdown()
@@ -216,7 +216,7 @@ func TestEphemeral(t *testing.T) {
IntegrationSkip(t) IntegrationSkip(t)
t.Parallel() t.Parallel()
scenario, err := NewScenario() scenario, err := NewScenario(dockertestMaxWait())
assertNoErr(t, err) assertNoErr(t, err)
defer scenario.Shutdown() defer scenario.Shutdown()
@@ -299,7 +299,7 @@ func TestPingAllByHostname(t *testing.T) {
IntegrationSkip(t) IntegrationSkip(t)
t.Parallel() t.Parallel()
scenario, err := NewScenario() scenario, err := NewScenario(dockertestMaxWait())
assertNoErr(t, err) assertNoErr(t, err)
defer scenario.Shutdown() defer scenario.Shutdown()
@@ -348,7 +348,7 @@ func TestTaildrop(t *testing.T) {
return err return err
} }
scenario, err := NewScenario() scenario, err := NewScenario(dockertestMaxWait())
assertNoErr(t, err) assertNoErr(t, err)
defer scenario.Shutdown() defer scenario.Shutdown()
@@ -509,7 +509,7 @@ func TestResolveMagicDNS(t *testing.T) {
IntegrationSkip(t) IntegrationSkip(t)
t.Parallel() t.Parallel()
scenario, err := NewScenario() scenario, err := NewScenario(dockertestMaxWait())
assertNoErr(t, err) assertNoErr(t, err)
defer scenario.Shutdown() defer scenario.Shutdown()
@@ -577,7 +577,7 @@ func TestExpireNode(t *testing.T) {
IntegrationSkip(t) IntegrationSkip(t)
t.Parallel() t.Parallel()
scenario, err := NewScenario() scenario, err := NewScenario(dockertestMaxWait())
assertNoErr(t, err) assertNoErr(t, err)
defer scenario.Shutdown() defer scenario.Shutdown()
@@ -703,7 +703,7 @@ func TestNodeOnlineStatus(t *testing.T) {
IntegrationSkip(t) IntegrationSkip(t)
t.Parallel() t.Parallel()
scenario, err := NewScenario() scenario, err := NewScenario(dockertestMaxWait())
assertNoErr(t, err) assertNoErr(t, err)
defer scenario.Shutdown() defer scenario.Shutdown()
@@ -818,7 +818,7 @@ func TestPingAllByIPManyUpDown(t *testing.T) {
IntegrationSkip(t) IntegrationSkip(t)
t.Parallel() t.Parallel()
scenario, err := NewScenario() scenario, err := NewScenario(dockertestMaxWait())
assertNoErr(t, err) assertNoErr(t, err)
defer scenario.Shutdown() defer scenario.Shutdown()

View File

@@ -73,7 +73,6 @@ database:
type: sqlite3 type: sqlite3
sqlite.path: /tmp/integration_test_db.sqlite3 sqlite.path: /tmp/integration_test_db.sqlite3
ephemeral_node_inactivity_timeout: 30m ephemeral_node_inactivity_timeout: 30m
node_update_check_interval: 10s
prefixes: prefixes:
v6: fd7a:115c:a1e0::/48 v6: fd7a:115c:a1e0::/48
v4: 100.64.0.0/10 v4: 100.64.0.0/10
@@ -116,7 +115,6 @@ func DefaultConfigEnv() map[string]string {
"HEADSCALE_DATABASE_TYPE": "sqlite", "HEADSCALE_DATABASE_TYPE": "sqlite",
"HEADSCALE_DATABASE_SQLITE_PATH": "/tmp/integration_test_db.sqlite3", "HEADSCALE_DATABASE_SQLITE_PATH": "/tmp/integration_test_db.sqlite3",
"HEADSCALE_EPHEMERAL_NODE_INACTIVITY_TIMEOUT": "30m", "HEADSCALE_EPHEMERAL_NODE_INACTIVITY_TIMEOUT": "30m",
"HEADSCALE_NODE_UPDATE_CHECK_INTERVAL": "10s",
"HEADSCALE_PREFIXES_V4": "100.64.0.0/10", "HEADSCALE_PREFIXES_V4": "100.64.0.0/10",
"HEADSCALE_PREFIXES_V6": "fd7a:115c:a1e0::/48", "HEADSCALE_PREFIXES_V6": "fd7a:115c:a1e0::/48",
"HEADSCALE_DNS_CONFIG_BASE_DOMAIN": "headscale.net", "HEADSCALE_DNS_CONFIG_BASE_DOMAIN": "headscale.net",

View File

@@ -18,6 +18,7 @@ import (
"net/url" "net/url"
"os" "os"
"path" "path"
"strconv"
"strings" "strings"
"time" "time"
@@ -201,6 +202,14 @@ func WithEmbeddedDERPServerOnly() Option {
} }
} }
// WithTuning allows changing the tuning settings easily.
func WithTuning(batchTimeout time.Duration, mapSessionChanSize int) Option {
return func(hsic *HeadscaleInContainer) {
hsic.env["HEADSCALE_TUNING_BATCH_CHANGE_DELAY"] = batchTimeout.String()
hsic.env["HEADSCALE_TUNING_NODE_MAPSESSION_BUFFERED_CHAN_SIZE"] = strconv.Itoa(mapSessionChanSize)
}
}
// New returns a new HeadscaleInContainer instance. // New returns a new HeadscaleInContainer instance.
func New( func New(
pool *dockertest.Pool, pool *dockertest.Pool,

View File

@@ -28,7 +28,7 @@ func TestEnablingRoutes(t *testing.T) {
user := "enable-routing" user := "enable-routing"
scenario, err := NewScenario() scenario, err := NewScenario(dockertestMaxWait())
assertNoErrf(t, "failed to create scenario: %s", err) assertNoErrf(t, "failed to create scenario: %s", err)
defer scenario.Shutdown() defer scenario.Shutdown()
@@ -250,7 +250,7 @@ func TestHASubnetRouterFailover(t *testing.T) {
user := "enable-routing" user := "enable-routing"
scenario, err := NewScenario() scenario, err := NewScenario(dockertestMaxWait())
assertNoErrf(t, "failed to create scenario: %s", err) assertNoErrf(t, "failed to create scenario: %s", err)
// defer scenario.Shutdown() // defer scenario.Shutdown()
@@ -822,7 +822,7 @@ func TestEnableDisableAutoApprovedRoute(t *testing.T) {
user := "enable-disable-routing" user := "enable-disable-routing"
scenario, err := NewScenario() scenario, err := NewScenario(dockertestMaxWait())
assertNoErrf(t, "failed to create scenario: %s", err) assertNoErrf(t, "failed to create scenario: %s", err)
defer scenario.Shutdown() defer scenario.Shutdown()
@@ -966,7 +966,7 @@ func TestSubnetRouteACL(t *testing.T) {
user := "subnet-route-acl" user := "subnet-route-acl"
scenario, err := NewScenario() scenario, err := NewScenario(dockertestMaxWait())
assertNoErrf(t, "failed to create scenario: %s", err) assertNoErrf(t, "failed to create scenario: %s", err)
defer scenario.Shutdown() defer scenario.Shutdown()

View File

@@ -8,6 +8,7 @@ import (
"os" "os"
"sort" "sort"
"sync" "sync"
"time"
v1 "github.com/juanfont/headscale/gen/go/headscale/v1" v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
"github.com/juanfont/headscale/hscontrol/util" "github.com/juanfont/headscale/hscontrol/util"
@@ -141,7 +142,7 @@ type Scenario struct {
// NewScenario creates a test Scenario which can be used to bootstraps a ControlServer with // NewScenario creates a test Scenario which can be used to bootstraps a ControlServer with
// a set of Users and TailscaleClients. // a set of Users and TailscaleClients.
func NewScenario() (*Scenario, error) { func NewScenario(maxWait time.Duration) (*Scenario, error) {
hash, err := util.GenerateRandomStringDNSSafe(scenarioHashLength) hash, err := util.GenerateRandomStringDNSSafe(scenarioHashLength)
if err != nil { if err != nil {
return nil, err return nil, err
@@ -152,7 +153,7 @@ func NewScenario() (*Scenario, error) {
return nil, fmt.Errorf("could not connect to docker: %w", err) return nil, fmt.Errorf("could not connect to docker: %w", err)
} }
pool.MaxWait = dockertestMaxWait() pool.MaxWait = maxWait
networkName := fmt.Sprintf("hs-%s", hash) networkName := fmt.Sprintf("hs-%s", hash)
if overrideNetworkName := os.Getenv("HEADSCALE_TEST_NETWORK_NAME"); overrideNetworkName != "" { if overrideNetworkName := os.Getenv("HEADSCALE_TEST_NETWORK_NAME"); overrideNetworkName != "" {
@@ -510,7 +511,7 @@ func (s *Scenario) GetIPs(user string) ([]netip.Addr, error) {
return ips, fmt.Errorf("failed to get ips: %w", errNoUserAvailable) return ips, fmt.Errorf("failed to get ips: %w", errNoUserAvailable)
} }
// GetIPs returns all TailscaleClients associated with a User in a Scenario. // GetClients returns all TailscaleClients associated with a User in a Scenario.
func (s *Scenario) GetClients(user string) ([]TailscaleClient, error) { func (s *Scenario) GetClients(user string) ([]TailscaleClient, error) {
var clients []TailscaleClient var clients []TailscaleClient
if ns, ok := s.users[user]; ok { if ns, ok := s.users[user]; ok {
@@ -586,7 +587,7 @@ func (s *Scenario) ListTailscaleClientsIPs(users ...string) ([]netip.Addr, error
return allIps, nil return allIps, nil
} }
// ListTailscaleClientsIPs returns a list of FQDN based on Users // ListTailscaleClientsFQDNs returns a list of FQDN based on Users
// passed as parameters. // passed as parameters.
func (s *Scenario) ListTailscaleClientsFQDNs(users ...string) ([]string, error) { func (s *Scenario) ListTailscaleClientsFQDNs(users ...string) ([]string, error) {
allFQDNs := make([]string, 0) allFQDNs := make([]string, 0)

View File

@@ -33,7 +33,7 @@ func TestHeadscale(t *testing.T) {
user := "test-space" user := "test-space"
scenario, err := NewScenario() scenario, err := NewScenario(dockertestMaxWait())
assertNoErr(t, err) assertNoErr(t, err)
defer scenario.Shutdown() defer scenario.Shutdown()
@@ -78,7 +78,7 @@ func TestCreateTailscale(t *testing.T) {
user := "only-create-containers" user := "only-create-containers"
scenario, err := NewScenario() scenario, err := NewScenario(dockertestMaxWait())
assertNoErr(t, err) assertNoErr(t, err)
defer scenario.Shutdown() defer scenario.Shutdown()
@@ -114,7 +114,7 @@ func TestTailscaleNodesJoiningHeadcale(t *testing.T) {
count := 1 count := 1
scenario, err := NewScenario() scenario, err := NewScenario(dockertestMaxWait())
assertNoErr(t, err) assertNoErr(t, err)
defer scenario.Shutdown() defer scenario.Shutdown()

View File

@@ -44,7 +44,7 @@ var retry = func(times int, sleepInterval time.Duration,
func sshScenario(t *testing.T, policy *policy.ACLPolicy, clientsPerUser int) *Scenario { func sshScenario(t *testing.T, policy *policy.ACLPolicy, clientsPerUser int) *Scenario {
t.Helper() t.Helper()
scenario, err := NewScenario() scenario, err := NewScenario(dockertestMaxWait())
assertNoErr(t, err) assertNoErr(t, err)
spec := map[string]int{ spec := map[string]int{