Compare commits

..

2 Commits

Author SHA1 Message Date
Juan Font Alonso
5539ef1f8f Reduce the number of containers 2022-08-04 21:36:38 +02:00
Juan Font Alonso
100f7190f3 Temporary fix integration tests with dedicated Dockerfile 2022-07-31 10:44:09 +02:00
81 changed files with 1850 additions and 3404 deletions

View File

@@ -26,7 +26,7 @@ jobs:
if: steps.changed-files.outputs.any_changed == 'true'
uses: golangci/golangci-lint-action@v2
with:
version: v1.49.0
version: v1.46.1
# Only block PRs on new problems.
# If this is not enabled, we will end up having PRs
@@ -70,7 +70,7 @@ jobs:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v2
- uses: bufbuild/buf-setup-action@v1.7.0
- uses: bufbuild/buf-setup-action@v0.7.0
- uses: bufbuild/buf-lint-action@v1
with:
input: "proto"

View File

@@ -18,7 +18,7 @@ jobs:
- name: Set up Go
uses: actions/setup-go@v2
with:
go-version: 1.19.0
go-version: 1.18.0
- name: Install dependencies
run: |

View File

@@ -11,11 +11,6 @@ jobs:
with:
fetch-depth: 2
- name: Set Swap Space
uses: pierotofy/set-swap-space@master
with:
swap-size-gb: 10
- name: Get changed files
id: changed-files
uses: tj-actions/changed-files@v14.1
@@ -30,29 +25,11 @@ jobs:
- uses: cachix/install-nix-action@v16
if: steps.changed-files.outputs.any_changed == 'true'
- name: Run CLI integration tests
- name: Run Integration tests
if: steps.changed-files.outputs.any_changed == 'true'
uses: nick-fields/retry@v2
with:
timeout_minutes: 240
max_attempts: 5
retry_on: error
command: nix develop --command -- make test_integration_cli
- name: Run Embedded DERP server integration tests
if: steps.changed-files.outputs.any_changed == 'true'
uses: nick-fields/retry@v2
with:
timeout_minutes: 240
max_attempts: 5
retry_on: error
command: nix develop --command -- make test_integration_derp
- name: Run general integration tests
if: steps.changed-files.outputs.any_changed == 'true'
uses: nick-fields/retry@v2
with:
timeout_minutes: 240
max_attempts: 5
retry_on: error
command: nix develop --command -- make test_integration_general
command: nix develop --command -- make test_integration

View File

@@ -1,7 +1,7 @@
---
before:
hooks:
- go mod tidy -compat=1.19
- go mod tidy -compat=1.18
release:
prerelease: auto

View File

@@ -1,48 +1,6 @@
# CHANGELOG
## 0.17.0 (2022-XX-XX)
### BREAKING
- Log level option `log_level` was moved to a distinct `log` config section and renamed to `level` [#768](https://github.com/juanfont/headscale/pull/768)
### Changes
- Added support for Tailscale TS2021 protocol [#738](https://github.com/juanfont/headscale/pull/738)
- Add ability to specify config location via env var `HEADSCALE_CONFIG` [#674](https://github.com/juanfont/headscale/issues/674)
- Target Go 1.19 for Headscale [#778](https://github.com/juanfont/headscale/pull/778)
- Target Tailscale v1.30.0 to build Headscale [#780](https://github.com/juanfont/headscale/pull/780)
- Give a warning when running Headscale with reverse proxy improperly configured for WebSockets [#788](https://github.com/juanfont/headscale/pull/788)
- Fix subnet routers with Primary Routes [#811](https://github.com/juanfont/headscale/pull/811)
- Added support for JSON logs [#653](https://github.com/juanfont/headscale/issues/653)
## 0.16.4 (2022-08-21)
### Changes
- Add ability to connect to PostgreSQL over TLS/SSL [#745](https://github.com/juanfont/headscale/pull/745)
- Fix CLI registration of expired machines [#754](https://github.com/juanfont/headscale/pull/754)
## 0.16.3 (2022-08-17)
### Changes
- Fix issue with OIDC authentication [#747](https://github.com/juanfont/headscale/pull/747)
## 0.16.2 (2022-08-14)
### Changes
- Fixed bugs in the client registration process after migration to NodeKey [#735](https://github.com/juanfont/headscale/pull/735)
## 0.16.1 (2022-08-12)
### Changes
- Updated dependencies (including the library that lacked armhf support) [#722](https://github.com/juanfont/headscale/pull/722)
- Fix missing group expansion in function `excludeCorretlyTaggedNodes` [#563](https://github.com/juanfont/headscale/issues/563)
- Improve registration protocol implementation and switch to NodeKey as main identifier [#725](https://github.com/juanfont/headscale/pull/725)
- Add ability to connect to PostgreSQL via unix socket [#734](https://github.com/juanfont/headscale/pull/734)
## 0.17.0 (2022-xx-xx)
## 0.16.0 (2022-07-25)
@@ -152,7 +110,7 @@ This is a part of aligning `headscale`'s behaviour with Tailscale's upstream beh
- OpenID Connect users will be mapped per namespaces
- Each user will get its own namespace, created if it does not exist
- `oidc.domain_map` option has been removed
- `strip_email_domain` option has been added (see [config-example.yaml](./config-example.yaml))
- `strip_email_domain` option has been added (see [config-example.yaml](./config_example.yaml))
### Changes

View File

@@ -1,134 +0,0 @@
# Contributor Covenant Code of Conduct
## Our Pledge
We as members, contributors, and leaders pledge to make participation
in our community a harassment-free experience for everyone, regardless
of age, body size, visible or invisible disability, ethnicity, sex
characteristics, gender identity and expression, level of experience,
education, socio-economic status, nationality, personal appearance,
race, religion, or sexual identity and orientation.
We pledge to act and interact in ways that contribute to an open,
welcoming, diverse, inclusive, and healthy community.
## Our Standards
Examples of behavior that contributes to a positive environment for
our community include:
- Demonstrating empathy and kindness toward other people
- Being respectful of differing opinions, viewpoints, and experiences
- Giving and gracefully accepting constructive feedback
- Accepting responsibility and apologizing to those affected by our
mistakes, and learning from the experience
- Focusing on what is best not just for us as individuals, but for the
overall community
Examples of unacceptable behavior include:
- The use of sexualized language or imagery, and sexual attention or
advances of any kind
- Trolling, insulting or derogatory comments, and personal or
political attacks
- Public or private harassment
- Publishing others' private information, such as a physical or email
address, without their explicit permission
- Other conduct which could reasonably be considered inappropriate in
a professional setting
## Enforcement Responsibilities
Community leaders are responsible for clarifying and enforcing our
standards of acceptable behavior and will take appropriate and fair
corrective action in response to any behavior that they deem
inappropriate, threatening, offensive, or harmful.
Community leaders have the right and responsibility to remove, edit,
or reject comments, commits, code, wiki edits, issues, and other
contributions that are not aligned to this Code of Conduct, and will
communicate reasons for moderation decisions when appropriate.
## Scope
This Code of Conduct applies within all community spaces, and also
applies when an individual is officially representing the community in
public spaces. Examples of representing our community include using an
official e-mail address, posting via an official social media account,
or acting as an appointed representative at an online or offline
event.
## Enforcement
Instances of abusive, harassing, or otherwise unacceptable behavior
may be reported to the community leaders responsible for enforcement
at our Discord channel. All complaints
will be reviewed and investigated promptly and fairly.
All community leaders are obligated to respect the privacy and
security of the reporter of any incident.
## Enforcement Guidelines
Community leaders will follow these Community Impact Guidelines in
determining the consequences for any action they deem in violation of
this Code of Conduct:
### 1. Correction
**Community Impact**: Use of inappropriate language or other behavior
deemed unprofessional or unwelcome in the community.
**Consequence**: A private, written warning from community leaders,
providing clarity around the nature of the violation and an
explanation of why the behavior was inappropriate. A public apology
may be requested.
### 2. Warning
**Community Impact**: A violation through a single incident or series
of actions.
**Consequence**: A warning with consequences for continued
behavior. No interaction with the people involved, including
unsolicited interaction with those enforcing the Code of Conduct, for
a specified period of time. This includes avoiding interactions in
community spaces as well as external channels like social
media. Violating these terms may lead to a temporary or permanent ban.
### 3. Temporary Ban
**Community Impact**: A serious violation of community standards,
including sustained inappropriate behavior.
**Consequence**: A temporary ban from any sort of interaction or
public communication with the community for a specified period of
time. No public or private interaction with the people involved,
including unsolicited interaction with those enforcing the Code of
Conduct, is allowed during this period. Violating these terms may lead
to a permanent ban.
### 4. Permanent Ban
**Community Impact**: Demonstrating a pattern of violation of
community standards, including sustained inappropriate behavior,
harassment of an individual, or aggression toward or disparagement of
classes of individuals.
**Consequence**: A permanent ban from any sort of public interaction
within the community.
## Attribution
This Code of Conduct is adapted from the [Contributor
Covenant][homepage], version 2.0, available at
https://www.contributor-covenant.org/version/2/0/code_of_conduct.html.
Community Impact Guidelines were inspired by [Mozilla's code of
conduct enforcement ladder](https://github.com/mozilla/diversity).
[homepage]: https://www.contributor-covenant.org
For answers to common questions about this code of conduct, see the
FAQ at https://www.contributor-covenant.org/faq. Translations are
available at https://www.contributor-covenant.org/translations.

View File

@@ -1,5 +1,5 @@
# Builder image
FROM docker.io/golang:1.19.0-bullseye AS build
FROM --platform=$BUILDPLATFORM docker.io/golang:1.18.0-bullseye AS build
ARG VERSION=dev
ENV GOPATH /go
WORKDIR /go/src/headscale
@@ -8,9 +8,8 @@ COPY go.mod go.sum /go/src/headscale/
RUN go mod download
COPY . .
RUN CGO_ENABLED=0 GOOS=linux go install -ldflags="-s -w -X github.com/juanfont/headscale/cmd/headscale/cli.Version=$VERSION" -a ./cmd/headscale
RUN strip /go/bin/headscale
ARG TARGETOS TARGETARCH
RUN CGO_ENABLED=0 GOOS=$TARGETOS GOARCH=$TARGETARCH go build -o /go/bin/headscale -ldflags="-s -w -X github.com/juanfont/headscale/cmd/headscale/cli.Version=$VERSION" -a ./cmd/headscale
RUN test -e /go/bin/headscale
# Production image

View File

@@ -1,5 +1,5 @@
# Builder image
FROM docker.io/golang:1.19.0-alpine AS build
FROM --platform=$BUILDPLATFORM docker.io/golang:1.18.0-alpine AS build
ARG VERSION=dev
ENV GOPATH /go
WORKDIR /go/src/headscale
@@ -10,8 +10,8 @@ RUN go mod download
COPY . .
RUN CGO_ENABLED=0 GOOS=linux go install -ldflags="-s -w -X github.com/juanfont/headscale/cmd/headscale/cli.Version=$VERSION" -a ./cmd/headscale
RUN strip /go/bin/headscale
ARG TARGETOS TARGETARCH
RUN CGO_ENABLED=0 GOOS=$TARGETOS GOARCH=$TARGETARCH go build -o /go/bin/headscale -ldflags="-s -w -X github.com/juanfont/headscale/cmd/headscale/cli.Version=$VERSION" -a ./cmd/headscale
RUN test -e /go/bin/headscale
# Production image

View File

@@ -1,5 +1,5 @@
# Builder image
FROM docker.io/golang:1.19.0-bullseye AS build
FROM --platform=$BUILDPLATFORM docker.io/golang:1.18.0-bullseye AS build
ARG VERSION=dev
ENV GOPATH /go
WORKDIR /go/src/headscale
@@ -9,6 +9,7 @@ RUN go mod download
COPY . .
ARG TARGETOS TARGETARCH
RUN CGO_ENABLED=0 GOOS=linux go install -ldflags="-s -w -X github.com/juanfont/headscale/cmd/headscale/cli.Version=$VERSION" -a ./cmd/headscale
RUN test -e /go/bin/headscale

View File

@@ -7,9 +7,7 @@ RUN apt-get update \
RUN git clone https://github.com/tailscale/tailscale.git
WORKDIR /go/tailscale
RUN git checkout main
WORKDIR tailscale
RUN sh build_dist.sh tailscale.com/cmd/tailscale
RUN sh build_dist.sh tailscale.com/cmd/tailscaled

View File

@@ -0,0 +1,21 @@
# Builder image
FROM docker.io/golang:1.18.0-bullseye AS build
ARG VERSION=dev
ENV GOPATH /go
WORKDIR /go/src/headscale
COPY go.mod go.sum /go/src/headscale/
RUN go mod download
COPY . .
RUN CGO_ENABLED=0 go build -o /go/bin/headscale -ldflags="-s -w -X github.com/juanfont/headscale/cmd/headscale/cli.Version=$VERSION" -a ./cmd/headscale
RUN test -e /go/bin/headscale
# Production image
FROM gcr.io/distroless/base-debian11
COPY --from=build /go/bin/headscale /bin/headscale
ENV TZ UTC
EXPOSE 8080/tcp
CMD ["headscale"]

View File

@@ -24,16 +24,14 @@ dev: lint test build
test:
@go test -coverprofile=coverage.out ./...
test_integration: test_integration_cli test_integration_derp test_integration_general
test_integration:
go test -failfast -tags integration -timeout 30m -count=1 ./...
test_integration_cli:
go test -failfast -tags integration_cli,integration -timeout 30m -count=1 ./...
go test -tags integration -v integration_cli_test.go integration_common_test.go
test_integration_derp:
go test -failfast -tags integration_derp,integration -timeout 30m -count=1 ./...
test_integration_general:
go test -failfast -tags integration_general,integration -timeout 30m -count=1 ./...
go test -tags integration -v integration_embedded_derp_test.go integration_common_test.go
coverprofile_func:
go tool cover -func=coverage.out

122
README.md
View File

@@ -1,4 +1,4 @@
![headscale logo](./docs/logo/headscale3_header_stacked_left.png)
# headscale
![ci](https://github.com/juanfont/headscale/actions/workflows/test.yml/badge.svg)
@@ -67,15 +67,15 @@ one of the maintainers.
## Client OS support
| OS | Supports headscale |
| ------- | --------------------------------------------------------- |
| Linux | Yes |
| OpenBSD | Yes |
| FreeBSD | Yes |
| macOS | Yes (see `/apple` on your headscale for more information) |
| Windows | Yes [docs](./docs/windows-client.md) |
| Android | Yes [docs](./docs/android-client.md) |
| iOS | Not yet |
| OS | Supports headscale |
| ------- | ----------------------------------------------------------------------------------------------------------------- |
| Linux | Yes |
| OpenBSD | Yes |
| FreeBSD | Yes |
| macOS | Yes (see `/apple` on your headscale for more information) |
| Windows | Yes [docs](./docs/windows-client.md) |
| Android | [You need to compile the client yourself](https://github.com/juanfont/headscale/issues/58#issuecomment-885255270) |
| iOS | Not yet |
## Running headscale
@@ -232,13 +232,6 @@ make build
<sub style="font-size:14px"><b>unreality</b></sub>
</a>
</td>
<td align="center" style="word-wrap: break-word; width: 150.0; height: 150.0">
<a href=https://github.com/ohdearaugustin>
<img src=https://avatars.githubusercontent.com/u/14001491?v=4 width="100;" style="border-radius:50%;align-items:center;justify-content:center;overflow:hidden;padding-top:10px" alt=ohdearaugustin/>
<br />
<sub style="font-size:14px"><b>ohdearaugustin</b></sub>
</a>
</td>
<td align="center" style="word-wrap: break-word; width: 150.0; height: 150.0">
<a href=https://github.com/mpldr>
<img src=https://avatars.githubusercontent.com/u/33086936?v=4 width="100;" style="border-radius:50%;align-items:center;justify-content:center;overflow:hidden;padding-top:10px" alt=Moritz Poldrack/>
@@ -246,15 +239,15 @@ make build
<sub style="font-size:14px"><b>Moritz Poldrack</b></sub>
</a>
</td>
</tr>
<tr>
<td align="center" style="word-wrap: break-word; width: 150.0; height: 150.0">
<a href=https://github.com/GrigoriyMikhalkin>
<img src=https://avatars.githubusercontent.com/u/3637857?v=4 width="100;" style="border-radius:50%;align-items:center;justify-content:center;overflow:hidden;padding-top:10px" alt=GrigoriyMikhalkin/>
<a href=https://github.com/ohdearaugustin>
<img src=https://avatars.githubusercontent.com/u/14001491?v=4 width="100;" style="border-radius:50%;align-items:center;justify-content:center;overflow:hidden;padding-top:10px" alt=ohdearaugustin/>
<br />
<sub style="font-size:14px"><b>GrigoriyMikhalkin</b></sub>
<sub style="font-size:14px"><b>ohdearaugustin</b></sub>
</a>
</td>
</tr>
<tr>
<td align="center" style="word-wrap: break-word; width: 150.0; height: 150.0">
<a href=https://github.com/Niek>
<img src=https://avatars.githubusercontent.com/u/213140?v=4 width="100;" style="border-radius:50%;align-items:center;justify-content:center;overflow:hidden;padding-top:10px" alt=Niek van der Maas/>
@@ -270,10 +263,10 @@ make build
</a>
</td>
<td align="center" style="word-wrap: break-word; width: 150.0; height: 150.0">
<a href=https://github.com/617a7a>
<img src=https://avatars.githubusercontent.com/u/67651251?v=4 width="100;" style="border-radius:50%;align-items:center;justify-content:center;overflow:hidden;padding-top:10px" alt=Azz/>
<a href=https://github.com/qbit>
<img src=https://avatars.githubusercontent.com/u/68368?v=4 width="100;" style="border-radius:50%;align-items:center;justify-content:center;overflow:hidden;padding-top:10px" alt=Aaron Bieber/>
<br />
<sub style="font-size:14px"><b>Azz</b></sub>
<sub style="font-size:14px"><b>Aaron Bieber</b></sub>
</a>
</td>
<td align="center" style="word-wrap: break-word; width: 150.0; height: 150.0">
@@ -283,22 +276,6 @@ make build
<sub style="font-size:14px"><b>Anton Schubert</b></sub>
</a>
</td>
<td align="center" style="word-wrap: break-word; width: 150.0; height: 150.0">
<a href=https://github.com/qbit>
<img src=https://avatars.githubusercontent.com/u/68368?v=4 width="100;" style="border-radius:50%;align-items:center;justify-content:center;overflow:hidden;padding-top:10px" alt=Aaron Bieber/>
<br />
<sub style="font-size:14px"><b>Aaron Bieber</b></sub>
</a>
</td>
</tr>
<tr>
<td align="center" style="word-wrap: break-word; width: 150.0; height: 150.0">
<a href=https://github.com/Aluxima>
<img src=https://avatars.githubusercontent.com/u/16262531?v=4 width="100;" style="border-radius:50%;align-items:center;justify-content:center;overflow:hidden;padding-top:10px" alt=Laurent Marchaud/>
<br />
<sub style="font-size:14px"><b>Laurent Marchaud</b></sub>
</a>
</td>
<td align="center" style="word-wrap: break-word; width: 150.0; height: 150.0">
<a href=https://github.com/fdelucchijr>
<img src=https://avatars.githubusercontent.com/u/69133647?v=4 width="100;" style="border-radius:50%;align-items:center;justify-content:center;overflow:hidden;padding-top:10px" alt=Fernando De Lucchi/>
@@ -306,6 +283,15 @@ make build
<sub style="font-size:14px"><b>Fernando De Lucchi</b></sub>
</a>
</td>
<td align="center" style="word-wrap: break-word; width: 150.0; height: 150.0">
<a href=https://github.com/GrigoriyMikhalkin>
<img src=https://avatars.githubusercontent.com/u/3637857?v=4 width="100;" style="border-radius:50%;align-items:center;justify-content:center;overflow:hidden;padding-top:10px" alt=GrigoriyMikhalkin/>
<br />
<sub style="font-size:14px"><b>GrigoriyMikhalkin</b></sub>
</a>
</td>
</tr>
<tr>
<td align="center" style="word-wrap: break-word; width: 150.0; height: 150.0">
<a href=https://github.com/hdhoang>
<img src=https://avatars.githubusercontent.com/u/12537?v=4 width="100;" style="border-radius:50%;align-items:center;justify-content:center;overflow:hidden;padding-top:10px" alt=Hoàng Đức Hiếu/>
@@ -334,8 +320,6 @@ make build
<sub style="font-size:14px"><b>ChibangLW</b></sub>
</a>
</td>
</tr>
<tr>
<td align="center" style="word-wrap: break-word; width: 150.0; height: 150.0">
<a href=https://github.com/mevansam>
<img src=https://avatars.githubusercontent.com/u/403630?v=4 width="100;" style="border-radius:50%;align-items:center;justify-content:center;overflow:hidden;padding-top:10px" alt=Mevan Samaratunga/>
@@ -350,6 +334,8 @@ make build
<sub style="font-size:14px"><b>Michael G.</b></sub>
</a>
</td>
</tr>
<tr>
<td align="center" style="word-wrap: break-word; width: 150.0; height: 150.0">
<a href=https://github.com/ptman>
<img src=https://avatars.githubusercontent.com/u/24669?v=4 width="100;" style="border-radius:50%;align-items:center;justify-content:center;overflow:hidden;padding-top:10px" alt=Paul Tötterman/>
@@ -378,8 +364,6 @@ make build
<sub style="font-size:14px"><b>Artem Klevtsov</b></sub>
</a>
</td>
</tr>
<tr>
<td align="center" style="word-wrap: break-word; width: 150.0; height: 150.0">
<a href=https://github.com/cmars>
<img src=https://avatars.githubusercontent.com/u/23741?v=4 width="100;" style="border-radius:50%;align-items:center;justify-content:center;overflow:hidden;padding-top:10px" alt=Casey Marshall/>
@@ -394,6 +378,8 @@ make build
<sub style="font-size:14px"><b>Pavlos Vinieratos</b></sub>
</a>
</td>
</tr>
<tr>
<td align="center" style="word-wrap: break-word; width: 150.0; height: 150.0">
<a href=https://github.com/SilverBut>
<img src=https://avatars.githubusercontent.com/u/6560655?v=4 width="100;" style="border-radius:50%;align-items:center;justify-content:center;overflow:hidden;padding-top:10px" alt=Silver Bullet/>
@@ -401,13 +387,6 @@ make build
<sub style="font-size:14px"><b>Silver Bullet</b></sub>
</a>
</td>
<td align="center" style="word-wrap: break-word; width: 150.0; height: 150.0">
<a href=https://github.com/vtrf>
<img src=https://avatars.githubusercontent.com/u/25647735?v=4 width="100;" style="border-radius:50%;align-items:center;justify-content:center;overflow:hidden;padding-top:10px" alt=Victor Freire/>
<br />
<sub style="font-size:14px"><b>Victor Freire</b></sub>
</a>
</td>
<td align="center" style="word-wrap: break-word; width: 150.0; height: 150.0">
<a href=https://github.com/lachy2849>
<img src=https://avatars.githubusercontent.com/u/98844035?v=4 width="100;" style="border-radius:50%;align-items:center;justify-content:center;overflow:hidden;padding-top:10px" alt=lachy2849/>
@@ -422,8 +401,6 @@ make build
<sub style="font-size:14px"><b>thomas</b></sub>
</a>
</td>
</tr>
<tr>
<td align="center" style="word-wrap: break-word; width: 150.0; height: 150.0">
<a href=https://github.com/aberoham>
<img src=https://avatars.githubusercontent.com/u/586805?v=4 width="100;" style="border-radius:50%;align-items:center;justify-content:center;overflow:hidden;padding-top:10px" alt=Abraham Ingersoll/>
@@ -445,6 +422,8 @@ make build
<sub style="font-size:14px"><b>Aofei Sheng</b></sub>
</a>
</td>
</tr>
<tr>
<td align="center" style="word-wrap: break-word; width: 150.0; height: 150.0">
<a href=https://github.com/awoimbee>
<img src=https://avatars.githubusercontent.com/u/22431493?v=4 width="100;" style="border-radius:50%;align-items:center;justify-content:center;overflow:hidden;padding-top:10px" alt=Arthur Woimbée/>
@@ -466,8 +445,6 @@ make build
<sub style="font-size:14px"><b> Carson Yang</b></sub>
</a>
</td>
</tr>
<tr>
<td align="center" style="word-wrap: break-word; width: 150.0; height: 150.0">
<a href=https://github.com/kundel>
<img src=https://avatars.githubusercontent.com/u/10158899?v=4 width="100;" style="border-radius:50%;align-items:center;justify-content:center;overflow:hidden;padding-top:10px" alt=kundel/>
@@ -489,6 +466,8 @@ make build
<sub style="font-size:14px"><b>Felix Yan</b></sub>
</a>
</td>
</tr>
<tr>
<td align="center" style="word-wrap: break-word; width: 150.0; height: 150.0">
<a href=https://github.com/JJGadgets>
<img src=https://avatars.githubusercontent.com/u/5709019?v=4 width="100;" style="border-radius:50%;align-items:center;justify-content:center;overflow:hidden;padding-top:10px" alt=JJGadgets/>
@@ -510,8 +489,6 @@ make build
<sub style="font-size:14px"><b>Jim Tittsler</b></sub>
</a>
</td>
</tr>
<tr>
<td align="center" style="word-wrap: break-word; width: 150.0; height: 150.0">
<a href=https://github.com/piec>
<img src=https://avatars.githubusercontent.com/u/781471?v=4 width="100;" style="border-radius:50%;align-items:center;justify-content:center;overflow:hidden;padding-top:10px" alt=Pierre Carru/>
@@ -519,13 +496,6 @@ make build
<sub style="font-size:14px"><b>Pierre Carru</b></sub>
</a>
</td>
<td align="center" style="word-wrap: break-word; width: 150.0; height: 150.0">
<a href=https://github.com/nnsee>
<img src=https://avatars.githubusercontent.com/u/36747857?v=4 width="100;" style="border-radius:50%;align-items:center;justify-content:center;overflow:hidden;padding-top:10px" alt=Rasmus Moorats/>
<br />
<sub style="font-size:14px"><b>Rasmus Moorats</b></sub>
</a>
</td>
<td align="center" style="word-wrap: break-word; width: 150.0; height: 150.0">
<a href=https://github.com/rcursaru>
<img src=https://avatars.githubusercontent.com/u/16259641?v=4 width="100;" style="border-radius:50%;align-items:center;justify-content:center;overflow:hidden;padding-top:10px" alt=rcursaru/>
@@ -540,6 +510,8 @@ make build
<sub style="font-size:14px"><b>WhiteSource Renovate</b></sub>
</a>
</td>
</tr>
<tr>
<td align="center" style="word-wrap: break-word; width: 150.0; height: 150.0">
<a href=https://github.com/ryanfowler>
<img src=https://avatars.githubusercontent.com/u/2668821?v=4 width="100;" style="border-radius:50%;align-items:center;justify-content:center;overflow:hidden;padding-top:10px" alt=Ryan Fowler/>
@@ -554,15 +526,6 @@ make build
<sub style="font-size:14px"><b>Shaanan Cohney</b></sub>
</a>
</td>
</tr>
<tr>
<td align="center" style="word-wrap: break-word; width: 150.0; height: 150.0">
<a href=https://github.com/sophware>
<img src=https://avatars.githubusercontent.com/u/41669?v=4 width="100;" style="border-radius:50%;align-items:center;justify-content:center;overflow:hidden;padding-top:10px" alt=sophware/>
<br />
<sub style="font-size:14px"><b>sophware</b></sub>
</a>
</td>
<td align="center" style="word-wrap: break-word; width: 150.0; height: 150.0">
<a href=https://github.com/m-tanner-dev0>
<img src=https://avatars.githubusercontent.com/u/97977342?v=4 width="100;" style="border-radius:50%;align-items:center;justify-content:center;overflow:hidden;padding-top:10px" alt=Tanner/>
@@ -591,6 +554,8 @@ make build
<sub style="font-size:14px"><b>Tianon Gravi</b></sub>
</a>
</td>
</tr>
<tr>
<td align="center" style="word-wrap: break-word; width: 150.0; height: 150.0">
<a href=https://github.com/woudsma>
<img src=https://avatars.githubusercontent.com/u/6162978?v=4 width="100;" style="border-radius:50%;align-items:center;justify-content:center;overflow:hidden;padding-top:10px" alt=Tjerk Woudsma/>
@@ -598,8 +563,6 @@ make build
<sub style="font-size:14px"><b>Tjerk Woudsma</b></sub>
</a>
</td>
</tr>
<tr>
<td align="center" style="word-wrap: break-word; width: 150.0; height: 150.0">
<a href=https://github.com/y0ngb1n>
<img src=https://avatars.githubusercontent.com/u/25719408?v=4 width="100;" style="border-radius:50%;align-items:center;justify-content:center;overflow:hidden;padding-top:10px" alt=Yang Bin/>
@@ -607,13 +570,6 @@ make build
<sub style="font-size:14px"><b>Yang Bin</b></sub>
</a>
</td>
<td align="center" style="word-wrap: break-word; width: 150.0; height: 150.0">
<a href=https://github.com/gozssky>
<img src=https://avatars.githubusercontent.com/u/17199941?v=4 width="100;" style="border-radius:50%;align-items:center;justify-content:center;overflow:hidden;padding-top:10px" alt=Yujie Xia/>
<br />
<sub style="font-size:14px"><b>Yujie Xia</b></sub>
</a>
</td>
<td align="center" style="word-wrap: break-word; width: 150.0; height: 150.0">
<a href=https://github.com/zekker6>
<img src=https://avatars.githubusercontent.com/u/1367798?v=4 width="100;" style="border-radius:50%;align-items:center;justify-content:center;overflow:hidden;padding-top:10px" alt=Zakhar Bessarab/>

31
acls.go
View File

@@ -5,7 +5,6 @@ import (
"errors"
"fmt"
"io"
"net/netip"
"os"
"path/filepath"
"strconv"
@@ -14,6 +13,7 @@ import (
"github.com/rs/zerolog/log"
"github.com/tailscale/hujson"
"gopkg.in/yaml.v3"
"inet.af/netaddr"
"tailscale.com/tailcfg"
)
@@ -162,12 +162,7 @@ func (h *Headscale) generateACLRules() ([]tailcfg.FilterRule, error) {
destPorts := []tailcfg.NetPortRange{}
for innerIndex, dest := range acl.Destinations {
dests, err := h.generateACLPolicyDest(
machines,
*h.aclPolicy,
dest,
needsWildcard,
)
dests, err := h.generateACLPolicyDest(machines, *h.aclPolicy, dest, needsWildcard)
if err != nil {
log.Error().
Msgf("Error parsing ACL %d, Destination %d", index, innerIndex)
@@ -260,12 +255,7 @@ func (h *Headscale) generateACLPolicyDest(
func parseProtocol(protocol string) ([]int, bool, error) {
switch protocol {
case "":
return []int{
protocolICMP,
protocolIPv6ICMP,
protocolTCP,
protocolUDP,
}, false, nil
return []int{protocolICMP, protocolIPv6ICMP, protocolTCP, protocolUDP}, false, nil
case "igmp":
return []int{protocolIGMP}, true, nil
case "ipv4", "ip-in-ip":
@@ -294,9 +284,7 @@ func parseProtocol(protocol string) ([]int, bool, error) {
if err != nil {
return nil, false, err
}
needsWildcard := protocolNumber != protocolTCP &&
protocolNumber != protocolUDP &&
protocolNumber != protocolSCTP
needsWildcard := protocolNumber != protocolTCP && protocolNumber != protocolUDP && protocolNumber != protocolSCTP
return []int{protocolNumber}, needsWildcard, nil
}
@@ -379,7 +367,7 @@ func expandAlias(
// if alias is a namespace
nodes := filterMachinesByNamespace(machines, alias)
nodes = excludeCorrectlyTaggedNodes(aclPolicy, nodes, alias, stripEmailDomain)
nodes = excludeCorrectlyTaggedNodes(aclPolicy, nodes, alias)
for _, n := range nodes {
ips = append(ips, n.IPAddresses.ToStringSlice()...)
@@ -394,13 +382,13 @@ func expandAlias(
}
// if alias is an IP
ip, err := netip.ParseAddr(alias)
ip, err := netaddr.ParseIP(alias)
if err == nil {
return []string{ip.String()}, nil
}
// if alias is an CIDR
cidr, err := netip.ParsePrefix(alias)
cidr, err := netaddr.ParseIPPrefix(alias)
if err == nil {
return []string{cidr.String()}, nil
}
@@ -417,13 +405,10 @@ func excludeCorrectlyTaggedNodes(
aclPolicy ACLPolicy,
nodes []Machine,
namespace string,
stripEmailDomain bool,
) []Machine {
out := []Machine{}
tags := []string{}
for tag := range aclPolicy.TagOwners {
owners, _ := expandTagOwners(aclPolicy, namespace, stripEmailDomain)
ns := append(owners, namespace)
for tag, ns := range aclPolicy.TagOwners {
if contains(ns, namespace) {
tags = append(tags, tag)
}

View File

@@ -2,11 +2,11 @@ package headscale
import (
"errors"
"net/netip"
"reflect"
"testing"
"gopkg.in/check.v1"
"inet.af/netaddr"
"tailscale.com/tailcfg"
)
@@ -62,11 +62,7 @@ func (s *Suite) TestBasicRule(c *check.C) {
func (s *Suite) TestInvalidAction(c *check.C) {
app.aclPolicy = &ACLPolicy{
ACLs: []ACL{
{
Action: "invalidAction",
Sources: []string{"*"},
Destinations: []string{"*:*"},
},
{Action: "invalidAction", Sources: []string{"*"}, Destinations: []string{"*:*"}},
},
}
err := app.UpdateACLRules()
@@ -81,11 +77,7 @@ func (s *Suite) TestInvalidGroupInGroup(c *check.C) {
"group:error": []string{"foo", "group:test"},
},
ACLs: []ACL{
{
Action: "accept",
Sources: []string{"group:error"},
Destinations: []string{"*:*"},
},
{Action: "accept", Sources: []string{"group:error"}, Destinations: []string{"*:*"}},
},
}
err := app.UpdateACLRules()
@@ -96,11 +88,7 @@ func (s *Suite) TestInvalidTagOwners(c *check.C) {
// this ACL is wrong because no tagOwners own the requested tag for the server
app.aclPolicy = &ACLPolicy{
ACLs: []ACL{
{
Action: "accept",
Sources: []string{"tag:foo"},
Destinations: []string{"*:*"},
},
{Action: "accept", Sources: []string{"tag:foo"}, Destinations: []string{"*:*"}},
},
}
err := app.UpdateACLRules()
@@ -131,7 +119,7 @@ func (s *Suite) TestValidExpandTagOwnersInSources(c *check.C) {
NodeKey: "bar",
DiscoKey: "faa",
Hostname: "testmachine",
IPAddresses: MachineAddresses{netip.MustParseAddr("100.64.0.1")},
IPAddresses: MachineAddresses{netaddr.MustParseIP("100.64.0.1")},
NamespaceID: namespace.ID,
RegisterMethod: RegisterMethodAuthKey,
AuthKeyID: uint(pak.ID),
@@ -143,11 +131,7 @@ func (s *Suite) TestValidExpandTagOwnersInSources(c *check.C) {
Groups: Groups{"group:test": []string{"user1", "user2"}},
TagOwners: TagOwners{"tag:test": []string{"user3", "group:test"}},
ACLs: []ACL{
{
Action: "accept",
Sources: []string{"tag:test"},
Destinations: []string{"*:*"},
},
{Action: "accept", Sources: []string{"tag:test"}, Destinations: []string{"*:*"}},
},
}
err = app.UpdateACLRules()
@@ -181,7 +165,7 @@ func (s *Suite) TestValidExpandTagOwnersInDestinations(c *check.C) {
NodeKey: "bar",
DiscoKey: "faa",
Hostname: "testmachine",
IPAddresses: MachineAddresses{netip.MustParseAddr("100.64.0.1")},
IPAddresses: MachineAddresses{netaddr.MustParseIP("100.64.0.1")},
NamespaceID: namespace.ID,
RegisterMethod: RegisterMethodAuthKey,
AuthKeyID: uint(pak.ID),
@@ -193,11 +177,7 @@ func (s *Suite) TestValidExpandTagOwnersInDestinations(c *check.C) {
Groups: Groups{"group:test": []string{"user1", "user2"}},
TagOwners: TagOwners{"tag:test": []string{"user3", "group:test"}},
ACLs: []ACL{
{
Action: "accept",
Sources: []string{"*"},
Destinations: []string{"tag:test:*"},
},
{Action: "accept", Sources: []string{"*"}, Destinations: []string{"tag:test:*"}},
},
}
err = app.UpdateACLRules()
@@ -231,7 +211,7 @@ func (s *Suite) TestInvalidTagValidNamespace(c *check.C) {
NodeKey: "bar",
DiscoKey: "faa",
Hostname: "testmachine",
IPAddresses: MachineAddresses{netip.MustParseAddr("100.64.0.1")},
IPAddresses: MachineAddresses{netaddr.MustParseIP("100.64.0.1")},
NamespaceID: namespace.ID,
RegisterMethod: RegisterMethodAuthKey,
AuthKeyID: uint(pak.ID),
@@ -242,11 +222,7 @@ func (s *Suite) TestInvalidTagValidNamespace(c *check.C) {
app.aclPolicy = &ACLPolicy{
TagOwners: TagOwners{"tag:test": []string{"user1"}},
ACLs: []ACL{
{
Action: "accept",
Sources: []string{"user1"},
Destinations: []string{"*:*"},
},
{Action: "accept", Sources: []string{"user1"}, Destinations: []string{"*:*"}},
},
}
err = app.UpdateACLRules()
@@ -280,7 +256,7 @@ func (s *Suite) TestValidTagInvalidNamespace(c *check.C) {
NodeKey: "bar",
DiscoKey: "faa",
Hostname: "webserver",
IPAddresses: MachineAddresses{netip.MustParseAddr("100.64.0.1")},
IPAddresses: MachineAddresses{netaddr.MustParseIP("100.64.0.1")},
NamespaceID: namespace.ID,
RegisterMethod: RegisterMethodAuthKey,
AuthKeyID: uint(pak.ID),
@@ -299,7 +275,7 @@ func (s *Suite) TestValidTagInvalidNamespace(c *check.C) {
NodeKey: "bar2",
DiscoKey: "faab",
Hostname: "user",
IPAddresses: MachineAddresses{netip.MustParseAddr("100.64.0.2")},
IPAddresses: MachineAddresses{netaddr.MustParseIP("100.64.0.2")},
NamespaceID: namespace.ID,
RegisterMethod: RegisterMethodAuthKey,
AuthKeyID: uint(pak.ID),
@@ -825,6 +801,7 @@ func Test_listMachinesInNamespace(t *testing.T) {
}
}
// nolint
func Test_expandAlias(t *testing.T) {
type args struct {
machines []Machine
@@ -843,10 +820,10 @@ func Test_expandAlias(t *testing.T) {
args: args{
alias: "*",
machines: []Machine{
{IPAddresses: MachineAddresses{netip.MustParseAddr("100.64.0.1")}},
{IPAddresses: MachineAddresses{netaddr.MustParseIP("100.64.0.1")}},
{
IPAddresses: MachineAddresses{
netip.MustParseAddr("100.78.84.227"),
netaddr.MustParseIP("100.78.84.227"),
},
},
},
@@ -863,25 +840,25 @@ func Test_expandAlias(t *testing.T) {
machines: []Machine{
{
IPAddresses: MachineAddresses{
netip.MustParseAddr("100.64.0.1"),
netaddr.MustParseIP("100.64.0.1"),
},
Namespace: Namespace{Name: "joe"},
},
{
IPAddresses: MachineAddresses{
netip.MustParseAddr("100.64.0.2"),
netaddr.MustParseIP("100.64.0.2"),
},
Namespace: Namespace{Name: "joe"},
},
{
IPAddresses: MachineAddresses{
netip.MustParseAddr("100.64.0.3"),
netaddr.MustParseIP("100.64.0.3"),
},
Namespace: Namespace{Name: "marc"},
},
{
IPAddresses: MachineAddresses{
netip.MustParseAddr("100.64.0.4"),
netaddr.MustParseIP("100.64.0.4"),
},
Namespace: Namespace{Name: "mickael"},
},
@@ -901,25 +878,25 @@ func Test_expandAlias(t *testing.T) {
machines: []Machine{
{
IPAddresses: MachineAddresses{
netip.MustParseAddr("100.64.0.1"),
netaddr.MustParseIP("100.64.0.1"),
},
Namespace: Namespace{Name: "joe"},
},
{
IPAddresses: MachineAddresses{
netip.MustParseAddr("100.64.0.2"),
netaddr.MustParseIP("100.64.0.2"),
},
Namespace: Namespace{Name: "joe"},
},
{
IPAddresses: MachineAddresses{
netip.MustParseAddr("100.64.0.3"),
netaddr.MustParseIP("100.64.0.3"),
},
Namespace: Namespace{Name: "marc"},
},
{
IPAddresses: MachineAddresses{
netip.MustParseAddr("100.64.0.4"),
netaddr.MustParseIP("100.64.0.4"),
},
Namespace: Namespace{Name: "mickael"},
},
@@ -950,7 +927,7 @@ func Test_expandAlias(t *testing.T) {
machines: []Machine{},
aclPolicy: ACLPolicy{
Hosts: Hosts{
"homeNetwork": netip.MustParsePrefix("192.168.1.0/24"),
"homeNetwork": netaddr.MustParseIPPrefix("192.168.1.0/24"),
},
},
stripEmailDomain: true,
@@ -987,7 +964,7 @@ func Test_expandAlias(t *testing.T) {
machines: []Machine{
{
IPAddresses: MachineAddresses{
netip.MustParseAddr("100.64.0.1"),
netaddr.MustParseIP("100.64.0.1"),
},
Namespace: Namespace{Name: "joe"},
HostInfo: HostInfo{
@@ -998,7 +975,7 @@ func Test_expandAlias(t *testing.T) {
},
{
IPAddresses: MachineAddresses{
netip.MustParseAddr("100.64.0.2"),
netaddr.MustParseIP("100.64.0.2"),
},
Namespace: Namespace{Name: "joe"},
HostInfo: HostInfo{
@@ -1009,13 +986,13 @@ func Test_expandAlias(t *testing.T) {
},
{
IPAddresses: MachineAddresses{
netip.MustParseAddr("100.64.0.3"),
netaddr.MustParseIP("100.64.0.3"),
},
Namespace: Namespace{Name: "marc"},
},
{
IPAddresses: MachineAddresses{
netip.MustParseAddr("100.64.0.4"),
netaddr.MustParseIP("100.64.0.4"),
},
Namespace: Namespace{Name: "joe"},
},
@@ -1035,25 +1012,25 @@ func Test_expandAlias(t *testing.T) {
machines: []Machine{
{
IPAddresses: MachineAddresses{
netip.MustParseAddr("100.64.0.1"),
netaddr.MustParseIP("100.64.0.1"),
},
Namespace: Namespace{Name: "joe"},
},
{
IPAddresses: MachineAddresses{
netip.MustParseAddr("100.64.0.2"),
netaddr.MustParseIP("100.64.0.2"),
},
Namespace: Namespace{Name: "joe"},
},
{
IPAddresses: MachineAddresses{
netip.MustParseAddr("100.64.0.3"),
netaddr.MustParseIP("100.64.0.3"),
},
Namespace: Namespace{Name: "marc"},
},
{
IPAddresses: MachineAddresses{
netip.MustParseAddr("100.64.0.4"),
netaddr.MustParseIP("100.64.0.4"),
},
Namespace: Namespace{Name: "mickael"},
},
@@ -1076,27 +1053,27 @@ func Test_expandAlias(t *testing.T) {
machines: []Machine{
{
IPAddresses: MachineAddresses{
netip.MustParseAddr("100.64.0.1"),
netaddr.MustParseIP("100.64.0.1"),
},
Namespace: Namespace{Name: "joe"},
ForcedTags: []string{"tag:hr-webserver"},
},
{
IPAddresses: MachineAddresses{
netip.MustParseAddr("100.64.0.2"),
netaddr.MustParseIP("100.64.0.2"),
},
Namespace: Namespace{Name: "joe"},
ForcedTags: []string{"tag:hr-webserver"},
},
{
IPAddresses: MachineAddresses{
netip.MustParseAddr("100.64.0.3"),
netaddr.MustParseIP("100.64.0.3"),
},
Namespace: Namespace{Name: "marc"},
},
{
IPAddresses: MachineAddresses{
netip.MustParseAddr("100.64.0.4"),
netaddr.MustParseIP("100.64.0.4"),
},
Namespace: Namespace{Name: "mickael"},
},
@@ -1114,14 +1091,14 @@ func Test_expandAlias(t *testing.T) {
machines: []Machine{
{
IPAddresses: MachineAddresses{
netip.MustParseAddr("100.64.0.1"),
netaddr.MustParseIP("100.64.0.1"),
},
Namespace: Namespace{Name: "joe"},
ForcedTags: []string{"tag:hr-webserver"},
},
{
IPAddresses: MachineAddresses{
netip.MustParseAddr("100.64.0.2"),
netaddr.MustParseIP("100.64.0.2"),
},
Namespace: Namespace{Name: "joe"},
HostInfo: HostInfo{
@@ -1132,13 +1109,13 @@ func Test_expandAlias(t *testing.T) {
},
{
IPAddresses: MachineAddresses{
netip.MustParseAddr("100.64.0.3"),
netaddr.MustParseIP("100.64.0.3"),
},
Namespace: Namespace{Name: "marc"},
},
{
IPAddresses: MachineAddresses{
netip.MustParseAddr("100.64.0.4"),
netaddr.MustParseIP("100.64.0.4"),
},
Namespace: Namespace{Name: "mickael"},
},
@@ -1160,7 +1137,7 @@ func Test_expandAlias(t *testing.T) {
machines: []Machine{
{
IPAddresses: MachineAddresses{
netip.MustParseAddr("100.64.0.1"),
netaddr.MustParseIP("100.64.0.1"),
},
Namespace: Namespace{Name: "joe"},
HostInfo: HostInfo{
@@ -1171,7 +1148,7 @@ func Test_expandAlias(t *testing.T) {
},
{
IPAddresses: MachineAddresses{
netip.MustParseAddr("100.64.0.2"),
netaddr.MustParseIP("100.64.0.2"),
},
Namespace: Namespace{Name: "joe"},
HostInfo: HostInfo{
@@ -1182,13 +1159,13 @@ func Test_expandAlias(t *testing.T) {
},
{
IPAddresses: MachineAddresses{
netip.MustParseAddr("100.64.0.3"),
netaddr.MustParseIP("100.64.0.3"),
},
Namespace: Namespace{Name: "marc"},
},
{
IPAddresses: MachineAddresses{
netip.MustParseAddr("100.64.0.4"),
netaddr.MustParseIP("100.64.0.4"),
},
Namespace: Namespace{Name: "joe"},
},
@@ -1224,10 +1201,9 @@ func Test_expandAlias(t *testing.T) {
func Test_excludeCorrectlyTaggedNodes(t *testing.T) {
type args struct {
aclPolicy ACLPolicy
nodes []Machine
namespace string
stripEmailDomain bool
aclPolicy ACLPolicy
nodes []Machine
namespace string
}
tests := []struct {
name string
@@ -1244,7 +1220,7 @@ func Test_excludeCorrectlyTaggedNodes(t *testing.T) {
nodes: []Machine{
{
IPAddresses: MachineAddresses{
netip.MustParseAddr("100.64.0.1"),
netaddr.MustParseIP("100.64.0.1"),
},
Namespace: Namespace{Name: "joe"},
HostInfo: HostInfo{
@@ -1255,7 +1231,7 @@ func Test_excludeCorrectlyTaggedNodes(t *testing.T) {
},
{
IPAddresses: MachineAddresses{
netip.MustParseAddr("100.64.0.2"),
netaddr.MustParseIP("100.64.0.2"),
},
Namespace: Namespace{Name: "joe"},
HostInfo: HostInfo{
@@ -1266,68 +1242,16 @@ func Test_excludeCorrectlyTaggedNodes(t *testing.T) {
},
{
IPAddresses: MachineAddresses{
netip.MustParseAddr("100.64.0.4"),
netaddr.MustParseIP("100.64.0.4"),
},
Namespace: Namespace{Name: "joe"},
},
},
namespace: "joe",
stripEmailDomain: true,
namespace: "joe",
},
want: []Machine{
{
IPAddresses: MachineAddresses{netip.MustParseAddr("100.64.0.4")},
Namespace: Namespace{Name: "joe"},
},
},
},
{
name: "exclude nodes with valid tags, and owner is in a group",
args: args{
aclPolicy: ACLPolicy{
Groups: Groups{
"group:accountant": []string{"joe", "bar"},
},
TagOwners: TagOwners{
"tag:accountant-webserver": []string{"group:accountant"},
},
},
nodes: []Machine{
{
IPAddresses: MachineAddresses{
netip.MustParseAddr("100.64.0.1"),
},
Namespace: Namespace{Name: "joe"},
HostInfo: HostInfo{
OS: "centos",
Hostname: "foo",
RequestTags: []string{"tag:accountant-webserver"},
},
},
{
IPAddresses: MachineAddresses{
netip.MustParseAddr("100.64.0.2"),
},
Namespace: Namespace{Name: "joe"},
HostInfo: HostInfo{
OS: "centos",
Hostname: "foo",
RequestTags: []string{"tag:accountant-webserver"},
},
},
{
IPAddresses: MachineAddresses{
netip.MustParseAddr("100.64.0.4"),
},
Namespace: Namespace{Name: "joe"},
},
},
namespace: "joe",
stripEmailDomain: true,
},
want: []Machine{
{
IPAddresses: MachineAddresses{netip.MustParseAddr("100.64.0.4")},
IPAddresses: MachineAddresses{netaddr.MustParseIP("100.64.0.4")},
Namespace: Namespace{Name: "joe"},
},
},
@@ -1341,7 +1265,7 @@ func Test_excludeCorrectlyTaggedNodes(t *testing.T) {
nodes: []Machine{
{
IPAddresses: MachineAddresses{
netip.MustParseAddr("100.64.0.1"),
netaddr.MustParseIP("100.64.0.1"),
},
Namespace: Namespace{Name: "joe"},
HostInfo: HostInfo{
@@ -1352,24 +1276,23 @@ func Test_excludeCorrectlyTaggedNodes(t *testing.T) {
},
{
IPAddresses: MachineAddresses{
netip.MustParseAddr("100.64.0.2"),
netaddr.MustParseIP("100.64.0.2"),
},
Namespace: Namespace{Name: "joe"},
ForcedTags: []string{"tag:accountant-webserver"},
},
{
IPAddresses: MachineAddresses{
netip.MustParseAddr("100.64.0.4"),
netaddr.MustParseIP("100.64.0.4"),
},
Namespace: Namespace{Name: "joe"},
},
},
namespace: "joe",
stripEmailDomain: true,
namespace: "joe",
},
want: []Machine{
{
IPAddresses: MachineAddresses{netip.MustParseAddr("100.64.0.4")},
IPAddresses: MachineAddresses{netaddr.MustParseIP("100.64.0.4")},
Namespace: Namespace{Name: "joe"},
},
},
@@ -1383,7 +1306,7 @@ func Test_excludeCorrectlyTaggedNodes(t *testing.T) {
nodes: []Machine{
{
IPAddresses: MachineAddresses{
netip.MustParseAddr("100.64.0.1"),
netaddr.MustParseIP("100.64.0.1"),
},
Namespace: Namespace{Name: "joe"},
HostInfo: HostInfo{
@@ -1394,7 +1317,7 @@ func Test_excludeCorrectlyTaggedNodes(t *testing.T) {
},
{
IPAddresses: MachineAddresses{
netip.MustParseAddr("100.64.0.2"),
netaddr.MustParseIP("100.64.0.2"),
},
Namespace: Namespace{Name: "joe"},
HostInfo: HostInfo{
@@ -1405,18 +1328,17 @@ func Test_excludeCorrectlyTaggedNodes(t *testing.T) {
},
{
IPAddresses: MachineAddresses{
netip.MustParseAddr("100.64.0.4"),
netaddr.MustParseIP("100.64.0.4"),
},
Namespace: Namespace{Name: "joe"},
},
},
namespace: "joe",
stripEmailDomain: true,
namespace: "joe",
},
want: []Machine{
{
IPAddresses: MachineAddresses{
netip.MustParseAddr("100.64.0.1"),
netaddr.MustParseIP("100.64.0.1"),
},
Namespace: Namespace{Name: "joe"},
HostInfo: HostInfo{
@@ -1427,7 +1349,7 @@ func Test_excludeCorrectlyTaggedNodes(t *testing.T) {
},
{
IPAddresses: MachineAddresses{
netip.MustParseAddr("100.64.0.2"),
netaddr.MustParseIP("100.64.0.2"),
},
Namespace: Namespace{Name: "joe"},
HostInfo: HostInfo{
@@ -1438,7 +1360,7 @@ func Test_excludeCorrectlyTaggedNodes(t *testing.T) {
},
{
IPAddresses: MachineAddresses{
netip.MustParseAddr("100.64.0.4"),
netaddr.MustParseIP("100.64.0.4"),
},
Namespace: Namespace{Name: "joe"},
},
@@ -1451,7 +1373,6 @@ func Test_excludeCorrectlyTaggedNodes(t *testing.T) {
test.args.aclPolicy,
test.args.nodes,
test.args.namespace,
test.args.stripEmailDomain,
)
if !reflect.DeepEqual(got, test.want) {
t.Errorf("excludeCorrectlyTaggedNodes() = %v, want %v", got, test.want)

View File

@@ -2,11 +2,11 @@ package headscale
import (
"encoding/json"
"net/netip"
"strings"
"github.com/tailscale/hujson"
"gopkg.in/yaml.v3"
"inet.af/netaddr"
)
// ACLPolicy represents a Tailscale ACL Policy.
@@ -21,28 +21,28 @@ type ACLPolicy struct {
// ACL is a basic rule for the ACL Policy.
type ACL struct {
Action string `json:"action" yaml:"action"`
Protocol string `json:"proto" yaml:"proto"`
Sources []string `json:"src" yaml:"src"`
Destinations []string `json:"dst" yaml:"dst"`
Protocol string `json:"proto" yaml:"proto"`
Sources []string `json:"src" yaml:"src"`
Destinations []string `json:"dst" yaml:"dst"`
}
// Groups references a series of alias in the ACL rules.
type Groups map[string][]string
// Hosts are alias for IP addresses or subnets.
type Hosts map[string]netip.Prefix
type Hosts map[string]netaddr.IPPrefix
// TagOwners specify what users (namespaces?) are allow to use certain tags.
type TagOwners map[string][]string
// ACLTest is not implemented, but should be use to check if a certain rule is allowed.
type ACLTest struct {
Source string `json:"src" yaml:"src"`
Accept []string `json:"accept" yaml:"accept"`
Source string `json:"src" yaml:"src"`
Accept []string `json:"accept" yaml:"accept"`
Deny []string `json:"deny,omitempty" yaml:"deny,omitempty"`
}
// UnmarshalJSON allows to parse the Hosts directly into netip objects.
// UnmarshalJSON allows to parse the Hosts directly into netaddr objects.
func (hosts *Hosts) UnmarshalJSON(data []byte) error {
newHosts := Hosts{}
hostIPPrefixMap := make(map[string]string)
@@ -60,7 +60,7 @@ func (hosts *Hosts) UnmarshalJSON(data []byte) error {
if !strings.Contains(prefixStr, "/") {
prefixStr += "/32"
}
prefix, err := netip.ParsePrefix(prefixStr)
prefix, err := netaddr.ParseIPPrefix(prefixStr)
if err != nil {
return err
}
@@ -71,7 +71,7 @@ func (hosts *Hosts) UnmarshalJSON(data []byte) error {
return nil
}
// UnmarshalYAML allows to parse the Hosts directly into netip objects.
// UnmarshalYAML allows to parse the Hosts directly into netaddr objects.
func (hosts *Hosts) UnmarshalYAML(data []byte) error {
newHosts := Hosts{}
hostIPPrefixMap := make(map[string]string)
@@ -81,7 +81,7 @@ func (hosts *Hosts) UnmarshalYAML(data []byte) error {
return err
}
for host, prefixStr := range hostIPPrefixMap {
prefix, err := netip.ParsePrefix(prefixStr)
prefix, err := netaddr.ParseIPPrefix(prefixStr)
if err != nil {
return err
}

747
api.go
View File

@@ -2,18 +2,25 @@ package headscale
import (
"bytes"
"encoding/binary"
"encoding/json"
"errors"
"fmt"
"html/template"
"io"
"net/http"
"strings"
"time"
"github.com/gorilla/mux"
"github.com/klauspost/compress/zstd"
"github.com/rs/zerolog/log"
"gorm.io/gorm"
"tailscale.com/tailcfg"
"tailscale.com/types/key"
)
const (
// TODO(juan): remove this once https://github.com/juanfont/headscale/issues/727 is fixed.
registrationHoldoff = time.Second * 5
reservedResponseHeaderSize = 4
RegisterMethodAuthKey = "authkey"
RegisterMethodOIDC = "oidc"
@@ -52,7 +59,7 @@ func (h *Headscale) HealthHandler(
}
}
if err := h.pingDB(req.Context()); err != nil {
if err := h.pingDB(); err != nil {
respond(err)
return
@@ -61,6 +68,23 @@ func (h *Headscale) HealthHandler(
respond(nil)
}
// KeyHandler provides the Headscale pub key
// Listens in /key.
func (h *Headscale) KeyHandler(
writer http.ResponseWriter,
req *http.Request,
) {
writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
writer.WriteHeader(http.StatusOK)
_, err := writer.Write([]byte(MachinePublicKeyStripPrefix(h.privateKey.Public())))
if err != nil {
log.Error().
Caller().
Err(err).
Msg("Failed to write response")
}
}
type registerWebAPITemplateConfig struct {
Key string
}
@@ -83,17 +107,13 @@ var registerWebAPITemplate = template.Must(
`))
// RegisterWebAPI shows a simple message in the browser to point to the CLI
// Listens in /register/:nkey.
//
// This is not part of the Tailscale control API, as we could send whatever URL
// in the RegisterResponse.AuthURL field.
// Listens in /register.
func (h *Headscale) RegisterWebAPI(
writer http.ResponseWriter,
req *http.Request,
) {
vars := mux.Vars(req)
nodeKeyStr, ok := vars["nkey"]
if !ok || nodeKeyStr == "" {
machineKeyStr := req.URL.Query().Get("key")
if machineKeyStr == "" {
writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
writer.WriteHeader(http.StatusBadRequest)
_, err := writer.Write([]byte("Wrong params"))
@@ -109,7 +129,7 @@ func (h *Headscale) RegisterWebAPI(
var content bytes.Buffer
if err := registerWebAPITemplate.Execute(&content, registerWebAPITemplateConfig{
Key: nodeKeyStr,
Key: machineKeyStr,
}); err != nil {
log.Error().
Str("func", "RegisterWebAPI").
@@ -138,3 +158,708 @@ func (h *Headscale) RegisterWebAPI(
Msg("Failed to write response")
}
}
// RegistrationHandler handles the actual registration process of a machine
// Endpoint /machine/:mkey.
func (h *Headscale) RegistrationHandler(
writer http.ResponseWriter,
req *http.Request,
) {
vars := mux.Vars(req)
machineKeyStr, ok := vars["mkey"]
if !ok || machineKeyStr == "" {
log.Error().
Str("handler", "RegistrationHandler").
Msg("No machine ID in request")
http.Error(writer, "No machine ID in request", http.StatusBadRequest)
return
}
body, _ := io.ReadAll(req.Body)
var machineKey key.MachinePublic
err := machineKey.UnmarshalText([]byte(MachinePublicKeyEnsurePrefix(machineKeyStr)))
if err != nil {
log.Error().
Caller().
Err(err).
Msg("Cannot parse machine key")
machineRegistrations.WithLabelValues("unknown", "web", "error", "unknown").Inc()
http.Error(writer, "Cannot parse machine key", http.StatusBadRequest)
return
}
registerRequest := tailcfg.RegisterRequest{}
err = decode(body, &registerRequest, &machineKey, h.privateKey)
if err != nil {
log.Error().
Caller().
Err(err).
Msg("Cannot decode message")
machineRegistrations.WithLabelValues("unknown", "web", "error", "unknown").Inc()
http.Error(writer, "Cannot decode message", http.StatusBadRequest)
return
}
now := time.Now().UTC()
machine, err := h.GetMachineByMachineKey(machineKey)
if errors.Is(err, gorm.ErrRecordNotFound) {
log.Info().Str("machine", registerRequest.Hostinfo.Hostname).Msg("New machine")
machineKeyStr := MachinePublicKeyStripPrefix(machineKey)
// If the machine has AuthKey set, handle registration via PreAuthKeys
if registerRequest.Auth.AuthKey != "" {
h.handleAuthKey(writer, req, machineKey, registerRequest)
return
}
givenName, err := h.GenerateGivenName(registerRequest.Hostinfo.Hostname)
if err != nil {
log.Error().
Caller().
Str("func", "RegistrationHandler").
Str("hostinfo.name", registerRequest.Hostinfo.Hostname).
Err(err)
return
}
// The machine did not have a key to authenticate, which means
// that we rely on a method that calls back some how (OpenID or CLI)
// We create the machine and then keep it around until a callback
// happens
newMachine := Machine{
MachineKey: machineKeyStr,
Hostname: registerRequest.Hostinfo.Hostname,
GivenName: givenName,
NodeKey: NodePublicKeyStripPrefix(registerRequest.NodeKey),
LastSeen: &now,
Expiry: &time.Time{},
}
if !registerRequest.Expiry.IsZero() {
log.Trace().
Caller().
Str("machine", registerRequest.Hostinfo.Hostname).
Time("expiry", registerRequest.Expiry).
Msg("Non-zero expiry time requested")
newMachine.Expiry = &registerRequest.Expiry
}
h.registrationCache.Set(
machineKeyStr,
newMachine,
registerCacheExpiration,
)
h.handleMachineRegistrationNew(writer, req, machineKey, registerRequest)
return
}
// The machine is already registered, so we need to pass through reauth or key update.
if machine != nil {
// If the NodeKey stored in headscale is the same as the key presented in a registration
// request, then we have a node that is either:
// - Trying to log out (sending a expiry in the past)
// - A valid, registered machine, looking for the node map
// - Expired machine wanting to reauthenticate
if machine.NodeKey == NodePublicKeyStripPrefix(registerRequest.NodeKey) {
// The client sends an Expiry in the past if the client is requesting to expire the key (aka logout)
// https://github.com/tailscale/tailscale/blob/main/tailcfg/tailcfg.go#L648
if !registerRequest.Expiry.IsZero() && registerRequest.Expiry.UTC().Before(now) {
h.handleMachineLogOut(writer, req, machineKey, *machine)
return
}
// If machine is not expired, and is register, we have a already accepted this machine,
// let it proceed with a valid registration
if !machine.isExpired() {
h.handleMachineValidRegistration(writer, req, machineKey, *machine)
return
}
}
// The NodeKey we have matches OldNodeKey, which means this is a refresh after a key expiration
if machine.NodeKey == NodePublicKeyStripPrefix(registerRequest.OldNodeKey) &&
!machine.isExpired() {
h.handleMachineRefreshKey(writer, req, machineKey, registerRequest, *machine)
return
}
// The machine has expired
h.handleMachineExpired(writer, req, machineKey, registerRequest, *machine)
return
}
}
func (h *Headscale) getMapResponse(
machineKey key.MachinePublic,
mapRequest tailcfg.MapRequest,
machine *Machine,
) ([]byte, error) {
log.Trace().
Str("func", "getMapResponse").
Str("machine", mapRequest.Hostinfo.Hostname).
Msg("Creating Map response")
node, err := machine.toNode(h.cfg.BaseDomain, h.cfg.DNSConfig, true)
if err != nil {
log.Error().
Caller().
Str("func", "getMapResponse").
Err(err).
Msg("Cannot convert to node")
return nil, err
}
peers, err := h.getValidPeers(machine)
if err != nil {
log.Error().
Caller().
Str("func", "getMapResponse").
Err(err).
Msg("Cannot fetch peers")
return nil, err
}
profiles := getMapResponseUserProfiles(*machine, peers)
nodePeers, err := peers.toNodes(h.cfg.BaseDomain, h.cfg.DNSConfig, true)
if err != nil {
log.Error().
Caller().
Str("func", "getMapResponse").
Err(err).
Msg("Failed to convert peers to Tailscale nodes")
return nil, err
}
dnsConfig := getMapResponseDNSConfig(
h.cfg.DNSConfig,
h.cfg.BaseDomain,
*machine,
peers,
)
resp := tailcfg.MapResponse{
KeepAlive: false,
Node: node,
Peers: nodePeers,
DNSConfig: dnsConfig,
Domain: h.cfg.BaseDomain,
PacketFilter: h.aclRules,
DERPMap: h.DERPMap,
UserProfiles: profiles,
Debug: &tailcfg.Debug{
DisableLogTail: !h.cfg.LogTail.Enabled,
RandomizeClientPort: h.cfg.RandomizeClientPort,
},
}
log.Trace().
Str("func", "getMapResponse").
Str("machine", mapRequest.Hostinfo.Hostname).
// Interface("payload", resp).
Msgf("Generated map response: %s", tailMapResponseToString(resp))
var respBody []byte
if mapRequest.Compress == "zstd" {
src, err := json.Marshal(resp)
if err != nil {
log.Error().
Caller().
Str("func", "getMapResponse").
Err(err).
Msg("Failed to marshal response for the client")
return nil, err
}
encoder, _ := zstd.NewWriter(nil)
srcCompressed := encoder.EncodeAll(src, nil)
respBody = h.privateKey.SealTo(machineKey, srcCompressed)
} else {
respBody, err = encode(resp, &machineKey, h.privateKey)
if err != nil {
return nil, err
}
}
// declare the incoming size on the first 4 bytes
data := make([]byte, reservedResponseHeaderSize)
binary.LittleEndian.PutUint32(data, uint32(len(respBody)))
data = append(data, respBody...)
return data, nil
}
func (h *Headscale) getMapKeepAliveResponse(
machineKey key.MachinePublic,
mapRequest tailcfg.MapRequest,
) ([]byte, error) {
mapResponse := tailcfg.MapResponse{
KeepAlive: true,
}
var respBody []byte
var err error
if mapRequest.Compress == "zstd" {
src, err := json.Marshal(mapResponse)
if err != nil {
log.Error().
Caller().
Str("func", "getMapKeepAliveResponse").
Err(err).
Msg("Failed to marshal keepalive response for the client")
return nil, err
}
encoder, _ := zstd.NewWriter(nil)
srcCompressed := encoder.EncodeAll(src, nil)
respBody = h.privateKey.SealTo(machineKey, srcCompressed)
} else {
respBody, err = encode(mapResponse, &machineKey, h.privateKey)
if err != nil {
return nil, err
}
}
data := make([]byte, reservedResponseHeaderSize)
binary.LittleEndian.PutUint32(data, uint32(len(respBody)))
data = append(data, respBody...)
return data, nil
}
func (h *Headscale) handleMachineLogOut(
writer http.ResponseWriter,
req *http.Request,
machineKey key.MachinePublic,
machine Machine,
) {
resp := tailcfg.RegisterResponse{}
log.Info().
Str("machine", machine.Hostname).
Msg("Client requested logout")
err := h.ExpireMachine(&machine)
if err != nil {
log.Error().
Caller().
Str("func", "handleMachineLogOut").
Err(err).
Msg("Failed to expire machine")
http.Error(writer, "Internal server error", http.StatusInternalServerError)
return
}
resp.AuthURL = ""
resp.MachineAuthorized = false
resp.User = *machine.Namespace.toUser()
respBody, err := encode(resp, &machineKey, h.privateKey)
if err != nil {
log.Error().
Caller().
Err(err).
Msg("Cannot encode message")
http.Error(writer, "Internal server error", http.StatusInternalServerError)
return
}
writer.Header().Set("Content-Type", "application/json; charset=utf-8")
writer.WriteHeader(http.StatusOK)
_, err = writer.Write(respBody)
if err != nil {
log.Error().
Caller().
Err(err).
Msg("Failed to write response")
}
}
func (h *Headscale) handleMachineValidRegistration(
writer http.ResponseWriter,
req *http.Request,
machineKey key.MachinePublic,
machine Machine,
) {
resp := tailcfg.RegisterResponse{}
// The machine registration is valid, respond with redirect to /map
log.Debug().
Str("machine", machine.Hostname).
Msg("Client is registered and we have the current NodeKey. All clear to /map")
resp.AuthURL = ""
resp.MachineAuthorized = true
resp.User = *machine.Namespace.toUser()
resp.Login = *machine.Namespace.toLogin()
respBody, err := encode(resp, &machineKey, h.privateKey)
if err != nil {
log.Error().
Caller().
Err(err).
Msg("Cannot encode message")
machineRegistrations.WithLabelValues("update", "web", "error", machine.Namespace.Name).
Inc()
http.Error(writer, "Internal server error", http.StatusInternalServerError)
return
}
machineRegistrations.WithLabelValues("update", "web", "success", machine.Namespace.Name).
Inc()
writer.Header().Set("Content-Type", "application/json; charset=utf-8")
writer.WriteHeader(http.StatusOK)
_, err = writer.Write(respBody)
if err != nil {
log.Error().
Caller().
Err(err).
Msg("Failed to write response")
}
}
func (h *Headscale) handleMachineExpired(
writer http.ResponseWriter,
req *http.Request,
machineKey key.MachinePublic,
registerRequest tailcfg.RegisterRequest,
machine Machine,
) {
resp := tailcfg.RegisterResponse{}
// The client has registered before, but has expired
log.Debug().
Str("machine", machine.Hostname).
Msg("Machine registration has expired. Sending a authurl to register")
if registerRequest.Auth.AuthKey != "" {
h.handleAuthKey(writer, req, machineKey, registerRequest)
return
}
if h.cfg.OIDC.Issuer != "" {
resp.AuthURL = fmt.Sprintf("%s/oidc/register/%s",
strings.TrimSuffix(h.cfg.ServerURL, "/"), machineKey.String())
} else {
resp.AuthURL = fmt.Sprintf("%s/register?key=%s",
strings.TrimSuffix(h.cfg.ServerURL, "/"), machineKey.String())
}
respBody, err := encode(resp, &machineKey, h.privateKey)
if err != nil {
log.Error().
Caller().
Err(err).
Msg("Cannot encode message")
machineRegistrations.WithLabelValues("reauth", "web", "error", machine.Namespace.Name).
Inc()
http.Error(writer, "Internal server error", http.StatusInternalServerError)
return
}
machineRegistrations.WithLabelValues("reauth", "web", "success", machine.Namespace.Name).
Inc()
writer.Header().Set("Content-Type", "application/json; charset=utf-8")
writer.WriteHeader(http.StatusOK)
_, err = writer.Write(respBody)
if err != nil {
log.Error().
Caller().
Err(err).
Msg("Failed to write response")
}
}
func (h *Headscale) handleMachineRefreshKey(
writer http.ResponseWriter,
req *http.Request,
machineKey key.MachinePublic,
registerRequest tailcfg.RegisterRequest,
machine Machine,
) {
resp := tailcfg.RegisterResponse{}
log.Debug().
Str("machine", machine.Hostname).
Msg("We have the OldNodeKey in the database. This is a key refresh")
machine.NodeKey = NodePublicKeyStripPrefix(registerRequest.NodeKey)
if err := h.db.Save(&machine).Error; err != nil {
log.Error().
Caller().
Err(err).
Msg("Failed to update machine key in the database")
http.Error(writer, "Internal server error", http.StatusInternalServerError)
return
}
resp.AuthURL = ""
resp.User = *machine.Namespace.toUser()
respBody, err := encode(resp, &machineKey, h.privateKey)
if err != nil {
log.Error().
Caller().
Err(err).
Msg("Cannot encode message")
http.Error(writer, "Internal server error", http.StatusInternalServerError)
return
}
writer.Header().Set("Content-Type", "application/json; charset=utf-8")
writer.WriteHeader(http.StatusOK)
_, err = writer.Write(respBody)
if err != nil {
log.Error().
Caller().
Err(err).
Msg("Failed to write response")
}
}
func (h *Headscale) handleMachineRegistrationNew(
writer http.ResponseWriter,
req *http.Request,
machineKey key.MachinePublic,
registerRequest tailcfg.RegisterRequest,
) {
resp := tailcfg.RegisterResponse{}
// The machine registration is new, redirect the client to the registration URL
log.Debug().
Str("machine", registerRequest.Hostinfo.Hostname).
Msg("The node is sending us a new NodeKey, sending auth url")
if h.cfg.OIDC.Issuer != "" {
resp.AuthURL = fmt.Sprintf(
"%s/oidc/register/%s",
strings.TrimSuffix(h.cfg.ServerURL, "/"),
machineKey.String(),
)
} else {
resp.AuthURL = fmt.Sprintf("%s/register?key=%s",
strings.TrimSuffix(h.cfg.ServerURL, "/"), MachinePublicKeyStripPrefix(machineKey))
}
respBody, err := encode(resp, &machineKey, h.privateKey)
if err != nil {
log.Error().
Caller().
Err(err).
Msg("Cannot encode message")
http.Error(writer, "Internal server error", http.StatusInternalServerError)
return
}
writer.Header().Set("Content-Type", "application/json; charset=utf-8")
writer.WriteHeader(http.StatusOK)
_, err = writer.Write(respBody)
if err != nil {
log.Error().
Caller().
Err(err).
Msg("Failed to write response")
}
}
// TODO: check if any locks are needed around IP allocation.
func (h *Headscale) handleAuthKey(
writer http.ResponseWriter,
req *http.Request,
machineKey key.MachinePublic,
registerRequest tailcfg.RegisterRequest,
) {
machineKeyStr := MachinePublicKeyStripPrefix(machineKey)
log.Debug().
Str("func", "handleAuthKey").
Str("machine", registerRequest.Hostinfo.Hostname).
Msgf("Processing auth key for %s", registerRequest.Hostinfo.Hostname)
resp := tailcfg.RegisterResponse{}
pak, err := h.checkKeyValidity(registerRequest.Auth.AuthKey)
if err != nil {
log.Error().
Caller().
Str("func", "handleAuthKey").
Str("machine", registerRequest.Hostinfo.Hostname).
Err(err).
Msg("Failed authentication via AuthKey")
resp.MachineAuthorized = false
respBody, err := encode(resp, &machineKey, h.privateKey)
if err != nil {
log.Error().
Caller().
Str("func", "handleAuthKey").
Str("machine", registerRequest.Hostinfo.Hostname).
Err(err).
Msg("Cannot encode message")
http.Error(writer, "Internal server error", http.StatusInternalServerError)
machineRegistrations.WithLabelValues("new", RegisterMethodAuthKey, "error", pak.Namespace.Name).
Inc()
return
}
writer.Header().Set("Content-Type", "application/json; charset=utf-8")
writer.WriteHeader(http.StatusUnauthorized)
_, err = writer.Write(respBody)
if err != nil {
log.Error().
Caller().
Err(err).
Msg("Failed to write response")
}
log.Error().
Caller().
Str("func", "handleAuthKey").
Str("machine", registerRequest.Hostinfo.Hostname).
Msg("Failed authentication via AuthKey")
if pak != nil {
machineRegistrations.WithLabelValues("new", RegisterMethodAuthKey, "error", pak.Namespace.Name).
Inc()
} else {
machineRegistrations.WithLabelValues("new", RegisterMethodAuthKey, "error", "unknown").Inc()
}
return
}
log.Debug().
Str("func", "handleAuthKey").
Str("machine", registerRequest.Hostinfo.Hostname).
Msg("Authentication key was valid, proceeding to acquire IP addresses")
nodeKey := NodePublicKeyStripPrefix(registerRequest.NodeKey)
// retrieve machine information if it exist
// The error is not important, because if it does not
// exist, then this is a new machine and we will move
// on to registration.
machine, _ := h.GetMachineByMachineKey(machineKey)
if machine != nil {
log.Trace().
Caller().
Str("machine", machine.Hostname).
Msg("machine already registered, refreshing with new auth key")
machine.NodeKey = nodeKey
machine.AuthKeyID = uint(pak.ID)
err := h.RefreshMachine(machine, registerRequest.Expiry)
if err != nil {
log.Error().
Caller().
Str("machine", machine.Hostname).
Err(err).
Msg("Failed to refresh machine")
return
}
} else {
now := time.Now().UTC()
givenName, err := h.GenerateGivenName(registerRequest.Hostinfo.Hostname)
if err != nil {
log.Error().
Caller().
Str("func", "RegistrationHandler").
Str("hostinfo.name", registerRequest.Hostinfo.Hostname).
Err(err)
return
}
machineToRegister := Machine{
Hostname: registerRequest.Hostinfo.Hostname,
GivenName: givenName,
NamespaceID: pak.Namespace.ID,
MachineKey: machineKeyStr,
RegisterMethod: RegisterMethodAuthKey,
Expiry: &registerRequest.Expiry,
NodeKey: nodeKey,
LastSeen: &now,
AuthKeyID: uint(pak.ID),
}
machine, err = h.RegisterMachine(
machineToRegister,
)
if err != nil {
log.Error().
Caller().
Err(err).
Msg("could not register machine")
machineRegistrations.WithLabelValues("new", RegisterMethodAuthKey, "error", pak.Namespace.Name).
Inc()
http.Error(writer, "Internal server error", http.StatusInternalServerError)
return
}
}
err = h.UsePreAuthKey(pak)
if err != nil {
log.Error().
Caller().
Err(err).
Msg("Failed to use pre-auth key")
machineRegistrations.WithLabelValues("new", RegisterMethodAuthKey, "error", pak.Namespace.Name).
Inc()
http.Error(writer, "Internal server error", http.StatusInternalServerError)
return
}
resp.MachineAuthorized = true
resp.User = *pak.Namespace.toUser()
respBody, err := encode(resp, &machineKey, h.privateKey)
if err != nil {
log.Error().
Caller().
Str("func", "handleAuthKey").
Str("machine", registerRequest.Hostinfo.Hostname).
Err(err).
Msg("Cannot encode message")
machineRegistrations.WithLabelValues("new", RegisterMethodAuthKey, "error", pak.Namespace.Name).
Inc()
http.Error(writer, "Internal server error", http.StatusInternalServerError)
return
}
machineRegistrations.WithLabelValues("new", RegisterMethodAuthKey, "success", pak.Namespace.Name).
Inc()
writer.Header().Set("Content-Type", "application/json; charset=utf-8")
writer.WriteHeader(http.StatusOK)
_, err = writer.Write(respBody)
if err != nil {
log.Error().
Caller().
Err(err).
Msg("Failed to write response")
}
log.Info().
Str("func", "handleAuthKey").
Str("machine", registerRequest.Hostinfo.Hostname).
Str("ips", strings.Join(machine.IPAddresses.ToStringSlice(), ", ")).
Msg("Successfully authenticated via AuthKey")
}

View File

@@ -1,80 +0,0 @@
package headscale
import (
"github.com/rs/zerolog/log"
"tailscale.com/tailcfg"
)
func (h *Headscale) generateMapResponse(
mapRequest tailcfg.MapRequest,
machine *Machine,
) (*tailcfg.MapResponse, error) {
log.Trace().
Str("func", "generateMapResponse").
Str("machine", mapRequest.Hostinfo.Hostname).
Msg("Creating Map response")
node, err := machine.toNode(h.cfg.BaseDomain, h.cfg.DNSConfig, true)
if err != nil {
log.Error().
Caller().
Str("func", "generateMapResponse").
Err(err).
Msg("Cannot convert to node")
return nil, err
}
peers, err := h.getValidPeers(machine)
if err != nil {
log.Error().
Caller().
Str("func", "generateMapResponse").
Err(err).
Msg("Cannot fetch peers")
return nil, err
}
profiles := getMapResponseUserProfiles(*machine, peers)
nodePeers, err := peers.toNodes(h.cfg.BaseDomain, h.cfg.DNSConfig, true)
if err != nil {
log.Error().
Caller().
Str("func", "generateMapResponse").
Err(err).
Msg("Failed to convert peers to Tailscale nodes")
return nil, err
}
dnsConfig := getMapResponseDNSConfig(
h.cfg.DNSConfig,
h.cfg.BaseDomain,
*machine,
peers,
)
resp := tailcfg.MapResponse{
KeepAlive: false,
Node: node,
Peers: nodePeers,
DNSConfig: dnsConfig,
Domain: h.cfg.BaseDomain,
PacketFilter: h.aclRules,
DERPMap: h.DERPMap,
UserProfiles: profiles,
Debug: &tailcfg.Debug{
DisableLogTail: !h.cfg.LogTail.Enabled,
RandomizeClientPort: h.cfg.RandomizeClientPort,
},
}
log.Trace().
Str("func", "generateMapResponse").
Str("machine", mapRequest.Hostinfo.Hostname).
// Interface("payload", resp).
Msgf("Generated map response: %s", tailMapResponseToString(resp))
return &resp, nil
}

View File

@@ -14,7 +14,7 @@ const (
apiPrefixLength = 7
apiKeyLength = 32
ErrAPIKeyFailedToParse = Error("Failed to parse ApiKey")
errAPIKeyFailedToParse = Error("Failed to parse ApiKey")
)
// APIKey describes the datamodel for API keys used to remotely authenticate with
@@ -116,7 +116,7 @@ func (h *Headscale) ExpireAPIKey(key *APIKey) error {
func (h *Headscale) ValidateAPIKey(keyStr string) (bool, error) {
prefix, hash, found := strings.Cut(keyStr, ".")
if !found {
return false, ErrAPIKeyFailedToParse
return false, errAPIKeyFailedToParse
}
key, err := h.GetAPIKey(prefix)

111
app.go
View File

@@ -18,7 +18,7 @@ import (
"github.com/coreos/go-oidc/v3/oidc"
"github.com/gorilla/mux"
grpcMiddleware "github.com/grpc-ecosystem/go-grpc-middleware"
grpc_middleware "github.com/grpc-ecosystem/go-grpc-middleware"
"github.com/grpc-ecosystem/grpc-gateway/v2/runtime"
v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
"github.com/patrickmn/go-cache"
@@ -51,10 +51,6 @@ const (
errUnsupportedLetsEncryptChallengeType = Error(
"unknown value for Lets Encrypt challenge type",
)
ErrFailedPrivateKey = Error("failed to read or create private key")
ErrFailedNoisePrivateKey = Error("failed to read or create Noise protocol private key")
ErrSamePrivateKeys = Error("private key and noise private key are the same")
)
const (
@@ -76,15 +72,12 @@ const (
// Headscale represents the base app of the service.
type Headscale struct {
cfg *Config
db *gorm.DB
dbString string
dbType string
dbDebug bool
privateKey *key.MachinePrivate
noisePrivateKey *key.MachinePrivate
noiseMux *mux.Router
cfg *Config
db *gorm.DB
dbString string
dbType string
dbDebug bool
privateKey *key.MachinePrivate
DERPMap *tailcfg.DERPMap
DERPServer *DERPServer
@@ -127,42 +120,22 @@ func LookupTLSClientAuthMode(mode string) (tls.ClientAuthType, bool) {
}
func NewHeadscale(cfg *Config) (*Headscale, error) {
privateKey, err := readOrCreatePrivateKey(cfg.PrivateKeyPath)
privKey, err := readOrCreatePrivateKey(cfg.PrivateKeyPath)
if err != nil {
return nil, ErrFailedPrivateKey
}
// TS2021 requires to have a different key from the legacy protocol.
noisePrivateKey, err := readOrCreatePrivateKey(cfg.NoisePrivateKeyPath)
if err != nil {
return nil, ErrFailedNoisePrivateKey
}
if privateKey.Equal(*noisePrivateKey) {
return nil, ErrSamePrivateKeys
return nil, fmt.Errorf("failed to read or create private key: %w", err)
}
var dbString string
switch cfg.DBtype {
case Postgres:
dbString = fmt.Sprintf(
"host=%s dbname=%s user=%s",
"host=%s port=%d dbname=%s user=%s password=%s sslmode=disable",
cfg.DBhost,
cfg.DBport,
cfg.DBname,
cfg.DBuser,
cfg.DBpass,
)
if !cfg.DBssl {
dbString += " sslmode=disable"
}
if cfg.DBport != 0 {
dbString += fmt.Sprintf(" port=%d", cfg.DBport)
}
if cfg.DBpass != "" {
dbString += fmt.Sprintf(" password=%s", cfg.DBpass)
}
case Sqlite:
dbString = cfg.DBpath
default:
@@ -178,8 +151,7 @@ func NewHeadscale(cfg *Config) (*Headscale, error) {
cfg: cfg,
dbType: cfg.DBtype,
dbString: dbString,
privateKey: privateKey,
noisePrivateKey: noisePrivateKey,
privateKey: privKey,
aclRules: tailcfg.FilterAllowAll, // default allowall
registrationCache: registrationCache,
pollNetMapStreamWG: sync.WaitGroup{},
@@ -275,7 +247,7 @@ func (h *Headscale) expireEphemeralNodesWorker() {
}
if expiredFound {
h.setLastStateChangeToNow()
h.setLastStateChangeToNow(namespace.Name)
}
}
}
@@ -443,14 +415,12 @@ func (h *Headscale) ensureUnixSocketIsAbsent() error {
func (h *Headscale) createRouter(grpcMux *runtime.ServeMux) *mux.Router {
router := mux.NewRouter()
router.HandleFunc(ts2021UpgradePath, h.NoiseUpgradeHandler).Methods(http.MethodPost)
router.HandleFunc("/health", h.HealthHandler).Methods(http.MethodGet)
router.HandleFunc("/key", h.KeyHandler).Methods(http.MethodGet)
router.HandleFunc("/register/{nkey}", h.RegisterWebAPI).Methods(http.MethodGet)
router.HandleFunc("/register", h.RegisterWebAPI).Methods(http.MethodGet)
router.HandleFunc("/machine/{mkey}/map", h.PollNetMapHandler).Methods(http.MethodPost)
router.HandleFunc("/machine/{mkey}", h.RegistrationHandler).Methods(http.MethodPost)
router.HandleFunc("/oidc/register/{nkey}", h.RegisterOIDC).Methods(http.MethodGet)
router.HandleFunc("/oidc/register/{mkey}", h.RegisterOIDC).Methods(http.MethodGet)
router.HandleFunc("/oidc/callback", h.OIDCCallback).Methods(http.MethodGet)
router.HandleFunc("/apple", h.AppleConfigMessage).Methods(http.MethodGet)
router.HandleFunc("/apple/{platform}", h.ApplePlatformConfig).Methods(http.MethodGet)
@@ -474,15 +444,6 @@ func (h *Headscale) createRouter(grpcMux *runtime.ServeMux) *mux.Router {
return router
}
func (h *Headscale) createNoiseMux() *mux.Router {
router := mux.NewRouter()
router.HandleFunc("/machine/register", h.NoiseRegistrationHandler).Methods(http.MethodPost)
router.HandleFunc("/machine/map", h.NoisePollNetMapHandler)
return router
}
// Serve launches a GIN server with the Headscale API.
func (h *Headscale) Serve() error {
var err error
@@ -601,7 +562,7 @@ func (h *Headscale) Serve() error {
grpcOptions := []grpc.ServerOption{
grpc.UnaryInterceptor(
grpcMiddleware.ChainUnaryServer(
grpc_middleware.ChainUnaryServer(
h.grpcAuthenticationInterceptor,
zerolog.NewUnaryServerInterceptor(),
),
@@ -636,15 +597,8 @@ func (h *Headscale) Serve() error {
//
// HTTP setup
//
// This is the regular router that we expose
// over our main Addr. It also serves the legacy Tailcale API
router := h.createRouter(grpcGatewayMux)
// This router is served only over the Noise connection, and exposes only the new API.
//
// The HTTP2 server that exposes this router is created for
// a single hijacked connection from /ts2021, using netutil.NewOneConnListener
h.noiseMux = h.createNoiseMux()
router := h.createRouter(grpcGatewayMux)
httpServer := &http.Server{
Addr: h.cfg.Addr,
@@ -738,10 +692,7 @@ func (h *Headscale) Serve() error {
h.pollNetMapStreamWG.Wait()
// Gracefully shut down servers
ctx, cancel := context.WithTimeout(
context.Background(),
HTTPShutdownTimeout,
)
ctx, cancel := context.WithTimeout(context.Background(), HTTPShutdownTimeout)
if err := promHTTPServer.Shutdown(ctx); err != nil {
log.Error().Err(err).Msg("Failed to shutdown prometheus http")
}
@@ -820,19 +771,10 @@ func (h *Headscale) getTLSSettings() (*tls.Config, error) {
// Configuration via autocert with HTTP-01. This requires listening on
// port 80 for the certificate validation in addition to the headscale
// service, which can be configured to run on any other port.
server := &http.Server{
Addr: h.cfg.TLS.LetsEncrypt.Listen,
Handler: certManager.HTTPHandler(http.HandlerFunc(h.redirect)),
ReadTimeout: HTTPReadTimeout,
}
err := server.ListenAndServe()
go func() {
log.Fatal().
Caller().
Err(err).
Err(http.ListenAndServe(h.cfg.TLS.LetsEncrypt.Listen, certManager.HTTPHandler(http.HandlerFunc(h.redirect)))).
Msg("failed to set up a HTTP server")
}()
@@ -869,17 +811,16 @@ func (h *Headscale) getTLSSettings() (*tls.Config, error) {
}
}
func (h *Headscale) setLastStateChangeToNow() {
func (h *Headscale) setLastStateChangeToNow(namespaces ...string) {
var err error
now := time.Now().UTC()
namespaces, err := h.ListNamespacesStr()
if err != nil {
log.Error().
Caller().
Err(err).
Msg("failed to fetch all namespaces, failing to update last changed state.")
if len(namespaces) == 0 {
namespaces, err = h.ListNamespacesStr()
if err != nil {
log.Error().Caller().Err(err).Msg("failed to fetch all namespaces, failing to update last changed state.")
}
}
for _, namespace := range namespaces {

View File

@@ -1,11 +1,12 @@
package headscale
import (
"net/netip"
"io/ioutil"
"os"
"testing"
"gopkg.in/check.v1"
"inet.af/netaddr"
)
func Test(t *testing.T) {
@@ -34,13 +35,13 @@ func (s *Suite) ResetDB(c *check.C) {
os.RemoveAll(tmpDir)
}
var err error
tmpDir, err = os.MkdirTemp("", "autoygg-client-test")
tmpDir, err = ioutil.TempDir("", "autoygg-client-test")
if err != nil {
c.Fatal(err)
}
cfg := Config{
IPPrefixes: []netip.Prefix{
netip.MustParsePrefix("10.27.0.0/23"),
IPPrefixes: []netaddr.IPPrefix{
netaddr.MustParseIPPrefix("10.27.0.0/23"),
},
}

View File

@@ -134,9 +134,7 @@ If you loose a key, create a new one and revoke (expire) the old one.`,
expiration := time.Now().UTC().Add(time.Duration(duration))
log.Trace().
Dur("expiration", time.Duration(duration)).
Msg("expiration has been set")
log.Trace().Dur("expiration", time.Duration(duration)).Msg("expiration has been set")
request.Expiration = timestamppb.New(expiration)

View File

@@ -3,7 +3,6 @@ package cli
import (
"fmt"
"log"
"net/netip"
"strconv"
"strings"
"time"
@@ -14,6 +13,7 @@ import (
"github.com/pterm/pterm"
"github.com/spf13/cobra"
"google.golang.org/grpc/status"
"inet.af/netaddr"
"tailscale.com/types/key"
)
@@ -108,7 +108,7 @@ var registerNodeCmd = &cobra.Command{
if err != nil {
ErrorOutput(
err,
fmt.Sprintf("Error getting node key from flag: %s", err),
fmt.Sprintf("Error getting machine key from flag: %s", err),
output,
)
@@ -557,7 +557,7 @@ func nodesToPtables(
var IPV4Address string
var IPV6Address string
for _, addr := range machine.IpAddresses {
if netip.MustParseAddr(addr).Is4() {
if netaddr.MustParseIP(addr).Is4() {
IPV4Address = addr
} else {
IPV6Address = addr

View File

@@ -164,9 +164,7 @@ var createPreAuthKeyCmd = &cobra.Command{
expiration := time.Now().UTC().Add(time.Duration(duration))
log.Trace().
Dur("expiration", time.Duration(duration)).
Msg("expiration has been set")
log.Trace().Dur("expiration", time.Duration(duration)).Msg("expiration has been set")
request.Expiration = timestamppb.New(expiration)

View File

@@ -25,18 +25,15 @@ func init() {
}
func initConfig() {
if cfgFile == "" {
cfgFile = os.Getenv("HEADSCALE_CONFIG")
}
if cfgFile != "" {
err := headscale.LoadConfig(cfgFile, true)
if err != nil {
log.Fatal().Caller().Err(err).Msgf("Error loading config file %s", cfgFile)
log.Fatal().Caller().Err(err)
}
} else {
err := headscale.LoadConfig("", false)
if err != nil {
log.Fatal().Caller().Err(err).Msgf("Error loading config")
log.Fatal().Caller().Err(err)
}
}
@@ -47,7 +44,7 @@ func initConfig() {
machineOutput := HasMachineOutputFlag()
zerolog.SetGlobalLevel(cfg.Log.Level)
zerolog.SetGlobalLevel(cfg.LogLevel)
// If the user has requested a "machine" readable format,
// then disable login so the output remains valid.
@@ -55,10 +52,6 @@ func initConfig() {
zerolog.SetGlobalLevel(zerolog.Disabled)
}
if cfg.Log.Format == headscale.JSONLogFormat {
log.Logger = log.Output(os.Stdout)
}
if !cfg.DisableUpdateCheck && !machineOutput {
if (runtime.GOOS == "linux" || runtime.GOOS == "darwin") &&
Version != "dev" {

View File

@@ -24,10 +24,7 @@ const (
func getHeadscaleApp() (*headscale.Headscale, error) {
cfg, err := headscale.GetHeadscaleConfig()
if err != nil {
return nil, fmt.Errorf(
"failed to load configuration while creating headscale instance: %w",
err,
)
return nil, fmt.Errorf("failed to load configuration while creating headscale instance: %w", err)
}
app, err := headscale.NewHeadscale(cfg)

View File

@@ -2,6 +2,7 @@ package main
import (
"io/fs"
"io/ioutil"
"os"
"path/filepath"
"strings"
@@ -27,7 +28,7 @@ func (s *Suite) TearDownSuite(c *check.C) {
}
func (*Suite) TestConfigFileLoading(c *check.C) {
tmpDir, err := os.MkdirTemp("", "headscale")
tmpDir, err := ioutil.TempDir("", "headscale")
if err != nil {
c.Fatal(err)
}
@@ -72,7 +73,7 @@ func (*Suite) TestConfigFileLoading(c *check.C) {
}
func (*Suite) TestConfigLoading(c *check.C) {
tmpDir, err := os.MkdirTemp("", "headscale")
tmpDir, err := ioutil.TempDir("", "headscale")
if err != nil {
c.Fatal(err)
}
@@ -116,7 +117,7 @@ func (*Suite) TestConfigLoading(c *check.C) {
}
func (*Suite) TestDNSConfigLoading(c *check.C) {
tmpDir, err := os.MkdirTemp("", "headscale")
tmpDir, err := ioutil.TempDir("", "headscale")
if err != nil {
c.Fatal(err)
}
@@ -151,24 +152,22 @@ func (*Suite) TestDNSConfigLoading(c *check.C) {
func writeConfig(c *check.C, tmpDir string, configYaml []byte) {
// Populate a custom config file
configFile := filepath.Join(tmpDir, "config.yaml")
err := os.WriteFile(configFile, configYaml, 0o600)
err := ioutil.WriteFile(configFile, configYaml, 0o600)
if err != nil {
c.Fatalf("Couldn't write file %s", configFile)
}
}
func (*Suite) TestTLSConfigValidation(c *check.C) {
tmpDir, err := os.MkdirTemp("", "headscale")
tmpDir, err := ioutil.TempDir("", "headscale")
if err != nil {
c.Fatal(err)
}
// defer os.RemoveAll(tmpDir)
configYaml := []byte(`---
tls_letsencrypt_hostname: example.com
tls_letsencrypt_challenge_type: ""
tls_cert_path: abc.pem
noise:
private_key_path: noise_private.key`)
configYaml := []byte(
"---\ntls_letsencrypt_hostname: \"example.com\"\ntls_letsencrypt_challenge_type: \"\"\ntls_cert_path: \"abc.pem\"",
)
writeConfig(c, tmpDir, configYaml)
// Check configuration validation errors (1)
@@ -193,13 +192,9 @@ noise:
)
// Check configuration validation errors (2)
configYaml = []byte(`---
noise:
private_key_path: noise_private.key
server_url: http://127.0.0.1:8080
tls_letsencrypt_hostname: example.com
tls_letsencrypt_challenge_type: TLS-ALPN-01
`)
configYaml = []byte(
"---\nserver_url: \"http://127.0.0.1:8080\"\ntls_letsencrypt_hostname: \"example.com\"\ntls_letsencrypt_challenge_type: \"TLS-ALPN-01\"",
)
writeConfig(c, tmpDir, configYaml)
err = headscale.LoadConfig(tmpDir, false)
c.Assert(err, check.IsNil)

View File

@@ -41,15 +41,6 @@ grpc_allow_insecure: false
# autogenerated if it's missing
private_key_path: /var/lib/headscale/private.key
# The Noise section includes specific configuration for the
# TS2021 Noise procotol
noise:
# The Noise private key is used to encrypt the
# traffic between headscale and Tailscale clients when
# using the new Noise-based protocol. It must be different
# from the legacy private key.
private_key_path: /var/lib/headscale/noise_private.key
# List of IP prefixes to allocate tailaddresses from.
# Each prefix consists of either an IPv4 or IPv6 address,
# and the associated prefix length, delimited by a slash.
@@ -123,14 +114,12 @@ db_type: sqlite3
db_path: /var/lib/headscale/db.sqlite
# # Postgres config
# If using a Unix socket to connect to Postgres, set the socket path in the 'host' field and leave 'port' blank.
# db_type: postgres
# db_host: localhost
# db_port: 5432
# db_name: headscale
# db_user: foo
# db_pass: bar
# db_ssl: false
### TLS configuration
#
@@ -172,10 +161,7 @@ tls_letsencrypt_listen: ":http"
tls_cert_path: ""
tls_key_path: ""
log:
# Output formatting for logs: text or json
format: text
level: info
log_level: info
# Path to a file containg ACL policies.
# ACLs can be defined as YAML or HUJSON.

View File

@@ -5,7 +5,6 @@ import (
"errors"
"fmt"
"io/fs"
"net/netip"
"net/url"
"strings"
"time"
@@ -14,7 +13,7 @@ import (
"github.com/rs/zerolog"
"github.com/rs/zerolog/log"
"github.com/spf13/viper"
"go4.org/netipx"
"inet.af/netaddr"
"tailscale.com/tailcfg"
"tailscale.com/types/dnstype"
)
@@ -22,9 +21,6 @@ import (
const (
tlsALPN01ChallengeType = "TLS-ALPN-01"
http01ChallengeType = "HTTP-01"
JSONLogFormat = "json"
TextLogFormat = "text"
)
// Config contains the initial Headscale configuration.
@@ -36,11 +32,10 @@ type Config struct {
GRPCAllowInsecure bool
EphemeralNodeInactivityTimeout time.Duration
NodeUpdateCheckInterval time.Duration
IPPrefixes []netip.Prefix
IPPrefixes []netaddr.IPPrefix
PrivateKeyPath string
NoisePrivateKeyPath string
BaseDomain string
Log LogConfig
LogLevel zerolog.Level
DisableUpdateCheck bool
DERP DERPConfig
@@ -52,7 +47,6 @@ type Config struct {
DBname string
DBuser string
DBpass string
DBssl bool
TLS TLSConfig
@@ -127,11 +121,6 @@ type ACLConfig struct {
PolicyPath string
}
type LogConfig struct {
Format string
Level zerolog.Level
}
func LoadConfig(path string, isFile bool) error {
if isFile {
viper.SetConfigFile(path)
@@ -155,8 +144,7 @@ func LoadConfig(path string, isFile bool) error {
viper.SetDefault("tls_letsencrypt_challenge_type", http01ChallengeType)
viper.SetDefault("tls_client_auth_mode", "relaxed")
viper.SetDefault("log.level", "info")
viper.SetDefault("log.format", TextLogFormat)
viper.SetDefault("log_level", "info")
viper.SetDefault("dns_config", nil)
@@ -195,10 +183,6 @@ func LoadConfig(path string, isFile bool) error {
errorText += "Fatal config error: set either tls_letsencrypt_hostname or tls_cert_path/tls_key_path, not both\n"
}
if !viper.IsSet("noise") || viper.GetString("noise.private_key_path") == "" {
errorText += "Fatal config error: headscale now requires a new `noise.private_key_path` field in the config file for the Tailscale v2 protocol\n"
}
if (viper.GetString("tls_letsencrypt_hostname") != "") &&
(viper.GetString("tls_letsencrypt_challenge_type") == tlsALPN01ChallengeType) &&
(!strings.HasSuffix(viper.GetString("listen_addr"), ":443")) {
@@ -343,34 +327,6 @@ func GetACLConfig() ACLConfig {
}
}
func GetLogConfig() LogConfig {
logLevelStr := viper.GetString("log.level")
logLevel, err := zerolog.ParseLevel(logLevelStr)
if err != nil {
logLevel = zerolog.DebugLevel
}
logFormatOpt := viper.GetString("log.format")
var logFormat string
switch logFormatOpt {
case "json":
logFormat = JSONLogFormat
case "text":
logFormat = TextLogFormat
case "":
logFormat = TextLogFormat
default:
log.Error().
Str("func", "GetLogConfig").
Msgf("Could not parse log format: %s. Valid choices are 'json' or 'text'", logFormatOpt)
}
return LogConfig{
Format: logFormat,
Level: logLevel,
}
}
func GetDNSConfig() (*tailcfg.DNSConfig, string) {
if viper.IsSet("dns_config") {
dnsConfig := &tailcfg.DNSConfig{}
@@ -378,11 +334,11 @@ func GetDNSConfig() (*tailcfg.DNSConfig, string) {
if viper.IsSet("dns_config.nameservers") {
nameserversStr := viper.GetStringSlice("dns_config.nameservers")
nameservers := make([]netip.Addr, len(nameserversStr))
nameservers := make([]netaddr.IP, len(nameserversStr))
resolvers := make([]*dnstype.Resolver, len(nameserversStr))
for index, nameserverStr := range nameserversStr {
nameserver, err := netip.ParseAddr(nameserverStr)
nameserver, err := netaddr.ParseIP(nameserverStr)
if err != nil {
log.Error().
Str("func", "getDNSConfig").
@@ -412,7 +368,7 @@ func GetDNSConfig() (*tailcfg.DNSConfig, string) {
len(restrictedNameservers),
)
for index, nameserverStr := range restrictedNameservers {
nameserver, err := netip.ParseAddr(nameserverStr)
nameserver, err := netaddr.ParseIP(nameserverStr)
if err != nil {
log.Error().
Str("func", "getDNSConfig").
@@ -465,7 +421,13 @@ func GetHeadscaleConfig() (*Config, error) {
randomizeClientPort := viper.GetBool("randomize_client_port")
configuredPrefixes := viper.GetStringSlice("ip_prefixes")
parsedPrefixes := make([]netip.Prefix, 0, len(configuredPrefixes)+1)
parsedPrefixes := make([]netaddr.IPPrefix, 0, len(configuredPrefixes)+1)
logLevelStr := viper.GetString("log_level")
logLevel, err := zerolog.ParseLevel(logLevelStr)
if err != nil {
logLevel = zerolog.DebugLevel
}
legacyPrefixField := viper.GetString("ip_prefix")
if len(legacyPrefixField) > 0 {
@@ -476,7 +438,7 @@ func GetHeadscaleConfig() (*Config, error) {
"use of 'ip_prefix' for configuration is deprecated",
"please see 'ip_prefixes' in the shipped example.",
)
legacyPrefix, err := netip.ParsePrefix(legacyPrefixField)
legacyPrefix, err := netaddr.ParseIPPrefix(legacyPrefixField)
if err != nil {
panic(fmt.Errorf("failed to parse ip_prefix: %w", err))
}
@@ -484,19 +446,19 @@ func GetHeadscaleConfig() (*Config, error) {
}
for i, prefixInConfig := range configuredPrefixes {
prefix, err := netip.ParsePrefix(prefixInConfig)
prefix, err := netaddr.ParseIPPrefix(prefixInConfig)
if err != nil {
panic(fmt.Errorf("failed to parse ip_prefixes[%d]: %w", i, err))
}
parsedPrefixes = append(parsedPrefixes, prefix)
}
prefixes := make([]netip.Prefix, 0, len(parsedPrefixes))
prefixes := make([]netaddr.IPPrefix, 0, len(parsedPrefixes))
{
// dedup
normalizedPrefixes := make(map[string]int, len(parsedPrefixes))
for i, p := range parsedPrefixes {
normalized, _ := netipx.RangeOfPrefix(p).Prefix()
normalized, _ := p.Range().Prefix()
normalizedPrefixes[normalized.String()] = i
}
@@ -507,7 +469,7 @@ func GetHeadscaleConfig() (*Config, error) {
}
if len(prefixes) < 1 {
prefixes = append(prefixes, netip.MustParsePrefix("100.64.0.0/10"))
prefixes = append(prefixes, netaddr.MustParseIPPrefix("100.64.0.0/10"))
log.Warn().
Msgf("'ip_prefixes' not configured, falling back to default: %v", prefixes)
}
@@ -519,14 +481,12 @@ func GetHeadscaleConfig() (*Config, error) {
GRPCAddr: viper.GetString("grpc_listen_addr"),
GRPCAllowInsecure: viper.GetBool("grpc_allow_insecure"),
DisableUpdateCheck: viper.GetBool("disable_check_updates"),
LogLevel: logLevel,
IPPrefixes: prefixes,
PrivateKeyPath: AbsolutePathFromConfigPath(
viper.GetString("private_key_path"),
),
NoisePrivateKeyPath: AbsolutePathFromConfigPath(
viper.GetString("noise.private_key_path"),
),
BaseDomain: baseDomain,
DERP: derpConfig,
@@ -546,7 +506,6 @@ func GetHeadscaleConfig() (*Config, error) {
DBname: viper.GetString("db_name"),
DBuser: viper.GetString("db_user"),
DBpass: viper.GetString("db_pass"),
DBssl: viper.GetBool("db_ssl"),
TLS: GetTLSConfig(),
@@ -580,7 +539,5 @@ func GetHeadscaleConfig() (*Config, error) {
},
ACL: GetACLConfig(),
Log: GetLogConfig(),
}, nil
}

14
db.go
View File

@@ -6,7 +6,6 @@ import (
"encoding/json"
"errors"
"fmt"
"net/netip"
"time"
"github.com/glebarez/sqlite"
@@ -14,6 +13,7 @@ import (
"gorm.io/driver/postgres"
"gorm.io/gorm"
"gorm.io/gorm/logger"
"inet.af/netaddr"
"tailscale.com/tailcfg"
)
@@ -221,8 +221,8 @@ func (h *Headscale) setValue(key string, value string) error {
return nil
}
func (h *Headscale) pingDB(ctx context.Context) error {
ctx, cancel := context.WithTimeout(ctx, time.Second)
func (h *Headscale) pingDB() error {
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
defer cancel()
db, err := h.db.DB()
if err != nil {
@@ -248,7 +248,7 @@ func (hi *HostInfo) Scan(destination interface{}) error {
return json.Unmarshal([]byte(value), hi)
default:
return fmt.Errorf("%w: unexpected data type %T", ErrMachineAddressesInvalid, destination)
return fmt.Errorf("%w: unexpected data type %T", errMachineAddressesInvalid, destination)
}
}
@@ -259,7 +259,7 @@ func (hi HostInfo) Value() (driver.Value, error) {
return string(bytes), err
}
type IPPrefixes []netip.Prefix
type IPPrefixes []netaddr.IPPrefix
func (i *IPPrefixes) Scan(destination interface{}) error {
switch value := destination.(type) {
@@ -270,7 +270,7 @@ func (i *IPPrefixes) Scan(destination interface{}) error {
return json.Unmarshal([]byte(value), i)
default:
return fmt.Errorf("%w: unexpected data type %T", ErrMachineAddressesInvalid, destination)
return fmt.Errorf("%w: unexpected data type %T", errMachineAddressesInvalid, destination)
}
}
@@ -292,7 +292,7 @@ func (i *StringList) Scan(destination interface{}) error {
return json.Unmarshal([]byte(value), i)
default:
return fmt.Errorf("%w: unexpected data type %T", ErrMachineAddressesInvalid, destination)
return fmt.Errorf("%w: unexpected data type %T", errMachineAddressesInvalid, destination)
}
}

View File

@@ -4,6 +4,7 @@ import (
"context"
"encoding/json"
"io"
"io/ioutil"
"net/http"
"net/url"
"os"
@@ -34,7 +35,7 @@ func loadDERPMapFromURL(addr url.URL) (*tailcfg.DERPMap, error) {
ctx, cancel := context.WithTimeout(context.Background(), HTTPReadTimeout)
defer cancel()
req, err := http.NewRequestWithContext(ctx, http.MethodGet, addr.String(), nil)
req, err := http.NewRequestWithContext(ctx, "GET", addr.String(), nil)
if err != nil {
return nil, err
}
@@ -49,7 +50,7 @@ func loadDERPMapFromURL(addr url.URL) (*tailcfg.DERPMap, error) {
}
defer resp.Body.Close()
body, err := io.ReadAll(resp.Body)
body, err := ioutil.ReadAll(resp.Body)
if err != nil {
return nil, err
}

View File

@@ -6,7 +6,6 @@ import (
"fmt"
"net"
"net/http"
"net/netip"
"net/url"
"strconv"
"strings"
@@ -99,13 +98,10 @@ func (h *Headscale) DERPHandler(
req *http.Request,
) {
log.Trace().Caller().Msgf("/derp request from %v", req.RemoteAddr)
upgrade := strings.ToLower(req.Header.Get("Upgrade"))
if upgrade != "websocket" && upgrade != "derp" {
if upgrade != "" {
log.Warn().
Caller().
Msg("No Upgrade header in DERP server request. If headscale is behind a reverse proxy, make sure it is configured to pass WebSockets through.")
up := strings.ToLower(req.Header.Get("Upgrade"))
if up != "websocket" && up != "derp" {
if up != "" {
log.Warn().Caller().Msgf("Weird websockets connection upgrade: %q", up)
}
writer.Header().Set("Content-Type", "text/plain")
writer.WriteHeader(http.StatusUpgradeRequired)
@@ -157,7 +153,7 @@ func (h *Headscale) DERPHandler(
if !fastStart {
pubKey := h.privateKey.Public()
pubKeyStr := pubKey.UntypedHexString() //nolint
pubKeyStr := pubKey.UntypedHexString() // nolint
fmt.Fprintf(conn, "HTTP/1.1 101 Switching Protocols\r\n"+
"Upgrade: DERP\r\n"+
"Connection: Upgrade\r\n"+
@@ -167,7 +163,7 @@ func (h *Headscale) DERPHandler(
pubKeyStr)
}
h.DERPServer.tailscaleDERP.Accept(req.Context(), netConn, conn, netConn.RemoteAddr().String())
h.DERPServer.tailscaleDERP.Accept(netConn, conn, netConn.RemoteAddr().String())
}
// DERPProbeHandler is the endpoint that js/wasm clients hit to measure
@@ -177,7 +173,7 @@ func (h *Headscale) DERPProbeHandler(
req *http.Request,
) {
switch req.Method {
case http.MethodHead, http.MethodGet:
case "HEAD", "GET":
writer.Header().Set("Access-Control-Allow-Origin", "*")
writer.WriteHeader(http.StatusOK)
default:
@@ -205,7 +201,7 @@ func (h *Headscale) DERPBootstrapDNSHandler(
) {
dnsEntries := make(map[string][]net.IP)
resolvCtx, cancel := context.WithTimeout(req.Context(), time.Minute)
resolvCtx, cancel := context.WithTimeout(context.Background(), time.Minute)
defer cancel()
var resolver net.Resolver
for _, region := range h.DERPMap.Regions {
@@ -280,8 +276,7 @@ func serverSTUNListener(ctx context.Context, packetConn *net.UDPConn) {
continue
}
addr, _ := netip.AddrFromSlice(udpAddr.IP)
res := stun.Response(txid, netip.AddrPortFrom(addr, uint16(udpAddr.Port)))
res := stun.Response(txid, udpAddr.IP, uint16(udpAddr.Port))
_, err = packetConn.WriteTo(res, udpAddr)
if err != nil {
log.Trace().Caller().Err(err).Msgf("Issue writing to UDP")

21
dns.go
View File

@@ -2,11 +2,10 @@ package headscale
import (
"fmt"
"net/netip"
"strings"
mapset "github.com/deckarep/golang-set/v2"
"go4.org/netipx"
"inet.af/netaddr"
"tailscale.com/tailcfg"
"tailscale.com/util/dnsname"
)
@@ -40,11 +39,11 @@ const (
// From the netmask we can find out the wildcard bits (the bits that are not set in the netmask).
// This allows us to then calculate the subnets included in the subsequent class block and generate the entries.
func generateMagicDNSRootDomains(ipPrefixes []netip.Prefix) []dnsname.FQDN {
func generateMagicDNSRootDomains(ipPrefixes []netaddr.IPPrefix) []dnsname.FQDN {
fqdns := make([]dnsname.FQDN, 0, len(ipPrefixes))
for _, ipPrefix := range ipPrefixes {
var generateDNSRoot func(netip.Prefix) []dnsname.FQDN
switch ipPrefix.Addr().BitLen() {
var generateDNSRoot func(netaddr.IPPrefix) []dnsname.FQDN
switch ipPrefix.IP().BitLen() {
case ipv4AddressLength:
generateDNSRoot = generateIPv4DNSRootDomain
@@ -55,7 +54,7 @@ func generateMagicDNSRootDomains(ipPrefixes []netip.Prefix) []dnsname.FQDN {
panic(
fmt.Sprintf(
"unsupported IP version with address length %d",
ipPrefix.Addr().BitLen(),
ipPrefix.IP().BitLen(),
),
)
}
@@ -66,9 +65,9 @@ func generateMagicDNSRootDomains(ipPrefixes []netip.Prefix) []dnsname.FQDN {
return fqdns
}
func generateIPv4DNSRootDomain(ipPrefix netip.Prefix) []dnsname.FQDN {
func generateIPv4DNSRootDomain(ipPrefix netaddr.IPPrefix) []dnsname.FQDN {
// Conversion to the std lib net.IPnet, a bit easier to operate
netRange := netipx.PrefixIPNet(ipPrefix)
netRange := ipPrefix.IPNet()
maskBits, _ := netRange.Mask.Size()
// lastOctet is the last IP byte covered by the mask
@@ -102,11 +101,11 @@ func generateIPv4DNSRootDomain(ipPrefix netip.Prefix) []dnsname.FQDN {
return fqdns
}
func generateIPv6DNSRootDomain(ipPrefix netip.Prefix) []dnsname.FQDN {
func generateIPv6DNSRootDomain(ipPrefix netaddr.IPPrefix) []dnsname.FQDN {
const nibbleLen = 4
maskBits, _ := netipx.PrefixIPNet(ipPrefix).Mask.Size()
expanded := ipPrefix.Addr().StringExpanded()
maskBits, _ := ipPrefix.IPNet().Mask.Size()
expanded := ipPrefix.IP().StringExpanded()
nibbleStr := strings.Map(func(r rune) rune {
if r == ':' {
return -1

View File

@@ -2,16 +2,16 @@ package headscale
import (
"fmt"
"net/netip"
"gopkg.in/check.v1"
"inet.af/netaddr"
"tailscale.com/tailcfg"
"tailscale.com/types/dnstype"
)
func (s *Suite) TestMagicDNSRootDomains100(c *check.C) {
prefixes := []netip.Prefix{
netip.MustParsePrefix("100.64.0.0/10"),
prefixes := []netaddr.IPPrefix{
netaddr.MustParseIPPrefix("100.64.0.0/10"),
}
domains := generateMagicDNSRootDomains(prefixes)
@@ -47,8 +47,8 @@ func (s *Suite) TestMagicDNSRootDomains100(c *check.C) {
}
func (s *Suite) TestMagicDNSRootDomains172(c *check.C) {
prefixes := []netip.Prefix{
netip.MustParsePrefix("172.16.0.0/16"),
prefixes := []netaddr.IPPrefix{
netaddr.MustParseIPPrefix("172.16.0.0/16"),
}
domains := generateMagicDNSRootDomains(prefixes)
@@ -75,8 +75,8 @@ func (s *Suite) TestMagicDNSRootDomains172(c *check.C) {
// Happens when netmask is a multiple of 4 bits (sounds likely).
func (s *Suite) TestMagicDNSRootDomainsIPv6Single(c *check.C) {
prefixes := []netip.Prefix{
netip.MustParsePrefix("fd7a:115c:a1e0::/48"),
prefixes := []netaddr.IPPrefix{
netaddr.MustParseIPPrefix("fd7a:115c:a1e0::/48"),
}
domains := generateMagicDNSRootDomains(prefixes)
@@ -89,8 +89,8 @@ func (s *Suite) TestMagicDNSRootDomainsIPv6Single(c *check.C) {
}
func (s *Suite) TestMagicDNSRootDomainsIPv6SingleMultiple(c *check.C) {
prefixes := []netip.Prefix{
netip.MustParsePrefix("fd7a:115c:a1e0::/50"),
prefixes := []netaddr.IPPrefix{
netaddr.MustParseIPPrefix("fd7a:115c:a1e0::/50"),
}
domains := generateMagicDNSRootDomains(prefixes)
@@ -165,7 +165,7 @@ func (s *Suite) TestDNSConfigMapResponseWithMagicDNS(c *check.C) {
NamespaceID: namespaceShared1.ID,
Namespace: *namespaceShared1,
RegisterMethod: RegisterMethodAuthKey,
IPAddresses: []netip.Addr{netip.MustParseAddr("100.64.0.1")},
IPAddresses: []netaddr.IP{netaddr.MustParseIP("100.64.0.1")},
AuthKeyID: uint(preAuthKeyInShared1.ID),
}
app.db.Save(machineInShared1)
@@ -182,7 +182,7 @@ func (s *Suite) TestDNSConfigMapResponseWithMagicDNS(c *check.C) {
NamespaceID: namespaceShared2.ID,
Namespace: *namespaceShared2,
RegisterMethod: RegisterMethodAuthKey,
IPAddresses: []netip.Addr{netip.MustParseAddr("100.64.0.2")},
IPAddresses: []netaddr.IP{netaddr.MustParseIP("100.64.0.2")},
AuthKeyID: uint(preAuthKeyInShared2.ID),
}
app.db.Save(machineInShared2)
@@ -199,7 +199,7 @@ func (s *Suite) TestDNSConfigMapResponseWithMagicDNS(c *check.C) {
NamespaceID: namespaceShared3.ID,
Namespace: *namespaceShared3,
RegisterMethod: RegisterMethodAuthKey,
IPAddresses: []netip.Addr{netip.MustParseAddr("100.64.0.3")},
IPAddresses: []netaddr.IP{netaddr.MustParseIP("100.64.0.3")},
AuthKeyID: uint(preAuthKeyInShared3.ID),
}
app.db.Save(machineInShared3)
@@ -216,7 +216,7 @@ func (s *Suite) TestDNSConfigMapResponseWithMagicDNS(c *check.C) {
NamespaceID: namespaceShared1.ID,
Namespace: *namespaceShared1,
RegisterMethod: RegisterMethodAuthKey,
IPAddresses: []netip.Addr{netip.MustParseAddr("100.64.0.4")},
IPAddresses: []netaddr.IP{netaddr.MustParseIP("100.64.0.4")},
AuthKeyID: uint(PreAuthKey2InShared1.ID),
}
app.db.Save(machine2InShared1)
@@ -308,7 +308,7 @@ func (s *Suite) TestDNSConfigMapResponseWithoutMagicDNS(c *check.C) {
NamespaceID: namespaceShared1.ID,
Namespace: *namespaceShared1,
RegisterMethod: RegisterMethodAuthKey,
IPAddresses: []netip.Addr{netip.MustParseAddr("100.64.0.1")},
IPAddresses: []netaddr.IP{netaddr.MustParseIP("100.64.0.1")},
AuthKeyID: uint(preAuthKeyInShared1.ID),
}
app.db.Save(machineInShared1)
@@ -325,7 +325,7 @@ func (s *Suite) TestDNSConfigMapResponseWithoutMagicDNS(c *check.C) {
NamespaceID: namespaceShared2.ID,
Namespace: *namespaceShared2,
RegisterMethod: RegisterMethodAuthKey,
IPAddresses: []netip.Addr{netip.MustParseAddr("100.64.0.2")},
IPAddresses: []netaddr.IP{netaddr.MustParseIP("100.64.0.2")},
AuthKeyID: uint(preAuthKeyInShared2.ID),
}
app.db.Save(machineInShared2)
@@ -342,7 +342,7 @@ func (s *Suite) TestDNSConfigMapResponseWithoutMagicDNS(c *check.C) {
NamespaceID: namespaceShared3.ID,
Namespace: *namespaceShared3,
RegisterMethod: RegisterMethodAuthKey,
IPAddresses: []netip.Addr{netip.MustParseAddr("100.64.0.3")},
IPAddresses: []netaddr.IP{netaddr.MustParseIP("100.64.0.3")},
AuthKeyID: uint(preAuthKeyInShared3.ID),
}
app.db.Save(machineInShared3)
@@ -359,7 +359,7 @@ func (s *Suite) TestDNSConfigMapResponseWithoutMagicDNS(c *check.C) {
NamespaceID: namespaceShared1.ID,
Namespace: *namespaceShared1,
RegisterMethod: RegisterMethodAuthKey,
IPAddresses: []netip.Addr{netip.MustParseAddr("100.64.0.4")},
IPAddresses: []netaddr.IP{netaddr.MustParseIP("100.64.0.4")},
AuthKeyID: uint(preAuthKey2InShared1.ID),
}
app.db.Save(machine2InShared1)

View File

@@ -36,7 +36,7 @@ ACLs could be written either on [huJSON](https://github.com/tailscale/hujson)
or YAML. Check the [test ACLs](../tests/acls) for further information.
When registering the servers we will need to add the flag
`--advertise-tags=tag:<tag1>,tag:<tag2>`, and the user (namespace) that is
`--advertised-tags=tag:<tag1>,tag:<tag2>`, and the user (namespace) that is
registering the server should be allowed to do it. Since anyone can add tags to
a server they can register, the check of the tags is done on headscale server
and only valid tags are applied. A tag is valid if the namespace that is

View File

@@ -1,19 +0,0 @@
# Connecting an Android client
## Goal
This documentation has the goal of showing how a user can use the official Android [Tailscale](https://tailscale.com) client with `headscale`.
## Installation
Install the official Tailscale Android client from the [Google Play Store](https://play.google.com/store/apps/details?id=com.tailscale.ipn) or [F-Droid](https://f-droid.org/packages/com.tailscale.ipn/).
Ensure that the installed version is at least 1.30.0, as that is the first release to support custom URLs.
## Configuring the headscale URL
After opening the app, the kebab menu icon (three dots) on the top bar on the right must be repeatedly opened and closed until the _Change server_ option appears in the menu. This is where you can enter your headscale URL.
A screen recording of this process can be seen in the `tailscale-android` PR which implemented this functionality: <https://github.com/tailscale/tailscale-android/pull/55>
After saving and restarting the app, selecting the regular _Sign in_ option (non-SSO) should open up the headscale authentication page.

View File

@@ -0,0 +1,32 @@
# Build docker from scratch
The Dockerfiles included in the repository are using the [buildx plugin](https://docs.docker.com/buildx/working-with-buildx/). This plugin is includes in docker newer than Docker-ce CLI 19.03.2. The plugin is used to be able to build different container arches. Building the Dockerfiles without buildx is not possible.
# Build native
To build the container on the native arch you can just use:
```
$ sudo docker buildx build -t headscale:custom-arch .
```
For example: This will build a amd64(x86_64) container if your hostsystem is amd64(x86_64). Or a arm64 container on a arm64 hostsystem (raspberry pi4).
# Build cross platform
To build a arm64 container on a amd64 hostsystem you could use:
```
$ sudo docker buildx build --platform linux/arm64 -t headscale:custom-arm64 .
```
**Import: Currently arm32 build are not supported as there is a problem with a library used by headscale. Hopefully this will be fixed soon.**
# Build multiple arches
To build multiple archres you could use:
```
$ sudo docker buildx create --use
$ sudo docker buildx build --platform linux/amd64,linux/arm64 .
```

Binary file not shown.

Binary file not shown.

Before

Width:  |  Height:  |  Size: 34 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 49 KiB

View File

@@ -53,9 +53,6 @@ server_url: http://your-host-name:8080 # Change to your hostname or host IP
metrics_listen_addr: 0.0.0.0:9090
# The default /var/lib/headscale path is not writable in the container
private_key_path: /etc/headscale/private.key
# The default /var/lib/headscale path is not writable in the container
noise:
private_key_path: /var/lib/headscale/noise_private.key
# The default /var/lib/headscale path is not writable in the container
db_path: /etc/headscale/db.sqlite
```
@@ -66,6 +63,7 @@ db_path: /etc/headscale/db.sqlite
docker run \
--name headscale \
--detach \
--rm \
--volume $(pwd)/config:/etc/headscale/ \
--publish 127.0.0.1:8080:8080 \
--publish 127.0.0.1:9090:9090 \

View File

@@ -17,7 +17,7 @@ describing how to make `headscale` run properly in a server environment.
```shell
# Install prerequistes
# 1. go v1.19+: headscale newer than 0.17 needs go 1.19+ to compile
# 1. go v1.18+: headscale newer than 0.15 needs go 1.18+ to compile
# 2. gmake: Makefile in the headscale repo is written in GNU make syntax
pkg_add -D snap go
pkg_add gmake
@@ -46,7 +46,7 @@ cp headscale /usr/local/sbin
```shell
# Install prerequistes
# 1. go v1.19+: headscale newer than 0.17 needs go 1.19+ to compile
# 1. go v1.18+: headscale newer than 0.15 needs go 1.18+ to compile
# 2. gmake: Makefile in the headscale repo is written in GNU make syntax
git clone https://github.com/juanfont/headscale.git

14
flake.lock generated
View File

@@ -2,11 +2,11 @@
"nodes": {
"flake-utils": {
"locked": {
"lastModified": 1659877975,
"narHash": "sha256-zllb8aq3YO3h8B/U0/J1WBgAL8EX5yWf5pMj3G0NAmc=",
"lastModified": 1653893745,
"narHash": "sha256-0jntwV3Z8//YwuOjzhV2sgJJPt+HY6KhU7VZUL0fKZQ=",
"owner": "numtide",
"repo": "flake-utils",
"rev": "c0e246b9b83f637f4681389ecabcb2681b4f3af0",
"rev": "1ed9fb1935d260de5fe1c2f7ee0ebaae17ed2fa1",
"type": "github"
},
"original": {
@@ -17,16 +17,16 @@
},
"nixpkgs": {
"locked": {
"lastModified": 1662019588,
"narHash": "sha256-oPEjHKGGVbBXqwwL+UjsveJzghWiWV0n9ogo1X6l4cw=",
"lastModified": 1654847188,
"narHash": "sha256-MC+eP7XOGE1LAswOPqdcGoUqY9mEQ3ZaaxamVTbc0hM=",
"owner": "NixOS",
"repo": "nixpkgs",
"rev": "2da64a81275b68fdad38af669afeda43d401e94b",
"rev": "8b66e3f2ebcc644b78cec9d6f152192f4e7d322f",
"type": "github"
},
"original": {
"owner": "NixOS",
"ref": "nixos-unstable",
"ref": "nixos-22.05",
"repo": "nixpkgs",
"type": "github"
}

View File

@@ -2,7 +2,7 @@
description = "headscale - Open Source Tailscale Control server";
inputs = {
nixpkgs.url = "github:NixOS/nixpkgs/nixos-unstable";
nixpkgs.url = "github:NixOS/nixpkgs/nixos-22.05";
flake-utils.url = "github:numtide/flake-utils";
};
@@ -17,14 +17,14 @@
in
rec {
headscale =
pkgs.buildGo119Module rec {
pkgs.buildGo118Module rec {
pname = "headscale";
version = headscaleVersion;
src = pkgs.lib.cleanSource self;
# When updating go.mod or go.sum, a new sha will need to be calculated,
# update this if you have a mismatch after doing a change to thos files.
vendorSha256 = "sha256-kc8EU+TkwRlsKM2+ljm/88aWe5h2QMgd/ZGPSgdd9QQ=";
vendorSha256 = "sha256-2o78hsi0B9U5NOcYXRqkBmg34p71J/R8FibXsgwEcSo=";
ldflags = [ "-s" "-w" "-X github.com/juanfont/headscale/cmd/headscale/cli.Version=v${version}" ];
};
@@ -95,7 +95,7 @@
overlays = [ self.overlay ];
inherit system;
};
buildDeps = with pkgs; [ git go_1_19 gnumake ];
buildDeps = with pkgs; [ git go_1_18 gnumake ];
devDeps = with pkgs;
buildDeps ++ [
golangci-lint

99
go.mod
View File

@@ -1,74 +1,66 @@
module github.com/juanfont/headscale
go 1.19
go 1.18
require (
github.com/AlecAivazis/survey/v2 v2.3.5
github.com/AlecAivazis/survey/v2 v2.3.4
github.com/ccding/go-stun/stun v0.0.0-20200514191101-4dc67bcdb029
github.com/coreos/go-oidc/v3 v3.3.0
github.com/coreos/go-oidc/v3 v3.1.0
github.com/deckarep/golang-set/v2 v2.1.0
github.com/efekarakus/termcolor v1.0.1
github.com/glebarez/sqlite v1.4.6
github.com/glebarez/sqlite v1.4.3
github.com/gofrs/uuid v4.2.0+incompatible
github.com/gorilla/mux v1.8.0
github.com/grpc-ecosystem/go-grpc-middleware v1.3.0
github.com/grpc-ecosystem/grpc-gateway/v2 v2.11.3
github.com/klauspost/compress v1.15.9
github.com/oauth2-proxy/mockoidc v0.0.0-20220308204021-b9169deeb282
github.com/grpc-ecosystem/grpc-gateway/v2 v2.10.0
github.com/klauspost/compress v1.15.4
github.com/ory/dockertest/v3 v3.9.1
github.com/patrickmn/go-cache v2.1.0+incompatible
github.com/philip-bui/grpc-zerolog v1.0.1
github.com/prometheus/client_golang v1.13.0
github.com/prometheus/common v0.37.0
github.com/pterm/pterm v0.12.45
github.com/puzpuzpuz/xsync v1.4.3
github.com/rs/zerolog v1.28.0
github.com/spf13/cobra v1.5.0
github.com/spf13/viper v1.12.0
github.com/stretchr/testify v1.8.0
github.com/tailscale/hujson v0.0.0-20220630195928-54599719472f
github.com/prometheus/client_golang v1.12.1
github.com/prometheus/common v0.32.1
github.com/pterm/pterm v0.12.41
github.com/puzpuzpuz/xsync v1.2.1
github.com/rs/zerolog v1.26.1
github.com/spf13/cobra v1.4.0
github.com/spf13/viper v1.11.0
github.com/stretchr/testify v1.7.1
github.com/tailscale/hujson v0.0.0-20220506202205-92b4b88a9e17
github.com/tcnksm/go-latest v0.0.0-20170313132115-e3007ae9052e
go4.org/netipx v0.0.0-20220812043211-3cc044ffd68d
golang.org/x/crypto v0.0.0-20220829220503-c86fa9a7ed90
golang.org/x/net v0.0.0-20220826154423-83b083e8dc8b
golang.org/x/oauth2 v0.0.0-20220822191816-0ebed06d0094
golang.org/x/sync v0.0.0-20220819030929-7fc1605a5dde
google.golang.org/genproto v0.0.0-20220902135211-223410557253
google.golang.org/grpc v1.49.0
google.golang.org/protobuf v1.28.1
golang.org/x/crypto v0.0.0-20220427172511-eb4f295cb31f
golang.org/x/oauth2 v0.0.0-20220411215720-9780585627b5
golang.org/x/sync v0.0.0-20220513210516-0976fa681c29
google.golang.org/genproto v0.0.0-20220422154200-b37d22cd5731
google.golang.org/grpc v1.46.0
google.golang.org/protobuf v1.28.0
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c
gopkg.in/yaml.v2 v2.4.0
gopkg.in/yaml.v3 v3.0.1
gorm.io/driver/postgres v1.3.9
gorm.io/gorm v1.23.8
tailscale.com v1.30.0
gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b
gorm.io/driver/postgres v1.3.5
gorm.io/gorm v1.23.4
inet.af/netaddr v0.0.0-20211027220019-c74959edd3b6
tailscale.com v1.26.0
)
require (
atomicgo.dev/cursor v0.1.1 // indirect
atomicgo.dev/keyboard v0.2.8 // indirect
filippo.io/edwards25519 v1.0.0-rc.1 // indirect
github.com/Azure/go-ansiterm v0.0.0-20170929234023-d6e3b3328b78 // indirect
github.com/Microsoft/go-winio v0.5.2 // indirect
github.com/Nvveen/Gotty v0.0.0-20120604004816-cd527374f1e5 // indirect
github.com/akutz/memconn v0.1.0 // indirect
github.com/alexbrainman/sspi v0.0.0-20210105120005-909beea2cc74 // indirect
github.com/atomicgo/cursor v0.0.1 // indirect
github.com/beorn7/perks v1.0.1 // indirect
github.com/cenkalti/backoff/v4 v4.1.3 // indirect
github.com/cespare/xxhash/v2 v2.1.2 // indirect
github.com/containerd/console v1.0.3 // indirect
github.com/containerd/continuity v0.3.0 // indirect
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/docker/cli v20.10.16+incompatible // indirect
github.com/docker/docker v20.10.16+incompatible // indirect
github.com/docker/go-connections v0.4.0 // indirect
github.com/docker/go-units v0.4.0 // indirect
github.com/fsnotify/fsnotify v1.5.4 // indirect
github.com/fxamacker/cbor/v2 v2.4.0 // indirect
github.com/glebarez/go-sqlite v1.17.3 // indirect
github.com/fsnotify/fsnotify v1.5.1 // indirect
github.com/glebarez/go-sqlite v1.16.0 // indirect
github.com/gogo/protobuf v1.3.2 // indirect
github.com/golang-jwt/jwt v3.2.2+incompatible // indirect
github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da // indirect
github.com/golang/protobuf v1.5.2 // indirect
github.com/google/go-cmp v0.5.8 // indirect
github.com/google/go-github v17.0.0+incompatible // indirect
@@ -78,17 +70,16 @@ require (
github.com/gookit/color v1.5.0 // indirect
github.com/hashicorp/go-version v1.4.0 // indirect
github.com/hashicorp/hcl v1.0.0 // indirect
github.com/hdevalence/ed25519consensus v0.0.0-20220222234857-c00d1f31bab3 // indirect
github.com/imdario/mergo v0.3.12 // indirect
github.com/inconshreveable/mousetrap v1.0.0 // indirect
github.com/jackc/chunkreader/v2 v2.0.1 // indirect
github.com/jackc/pgconn v1.12.1 // indirect
github.com/jackc/pgconn v1.12.0 // indirect
github.com/jackc/pgio v1.0.0 // indirect
github.com/jackc/pgpassfile v1.0.0 // indirect
github.com/jackc/pgproto3/v2 v2.3.0 // indirect
github.com/jackc/pgservicefile v0.0.0-20200714003250-2b9c44734f2b // indirect
github.com/jackc/pgtype v1.11.0 // indirect
github.com/jackc/pgx/v4 v4.16.1 // indirect
github.com/jackc/pgx/v4 v4.16.0 // indirect
github.com/jinzhu/inflection v1.0.0 // indirect
github.com/jinzhu/now v1.1.4 // indirect
github.com/josharian/native v1.0.0 // indirect
@@ -96,7 +87,6 @@ require (
github.com/kballard/go-shellquote v0.0.0-20180428030007-95032a82bc51 // indirect
github.com/kr/pretty v0.3.0 // indirect
github.com/kr/text v0.2.0 // indirect
github.com/lithammer/fuzzysearch v1.1.5 // indirect
github.com/magiconair/properties v1.8.6 // indirect
github.com/mattn/go-colorable v0.1.12 // indirect
github.com/mattn/go-isatty v0.0.14 // indirect
@@ -106,43 +96,44 @@ require (
github.com/mdlayher/socket v0.2.3 // indirect
github.com/mgutz/ansi v0.0.0-20170206155736-9520e82c474b // indirect
github.com/mitchellh/go-ps v1.0.0 // indirect
github.com/mitchellh/mapstructure v1.5.0 // indirect
github.com/mitchellh/mapstructure v1.4.3 // indirect
github.com/moby/term v0.0.0-20201216013528-df9cb8a40635 // indirect
github.com/opencontainers/go-digest v1.0.0 // indirect
github.com/opencontainers/image-spec v1.0.3-0.20220114050600-8b9d41f48198 // indirect
github.com/opencontainers/runc v1.1.2 // indirect
github.com/pelletier/go-toml v1.9.5 // indirect
github.com/pelletier/go-toml/v2 v2.0.1 // indirect
github.com/pelletier/go-toml v1.9.4 // indirect
github.com/pelletier/go-toml/v2 v2.0.0-beta.8 // indirect
github.com/pkg/errors v0.9.1 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/prometheus/client_model v0.2.0 // indirect
github.com/prometheus/procfs v0.8.0 // indirect
github.com/prometheus/procfs v0.7.3 // indirect
github.com/remyoudompheng/bigfft v0.0.0-20200410134404-eec4a21b6bb0 // indirect
github.com/rivo/uniseg v0.2.0 // indirect
github.com/rogpeppe/go-internal v1.8.1-0.20211023094830-115ce09fd6b4 // indirect
github.com/sirupsen/logrus v1.8.1 // indirect
github.com/spf13/afero v1.8.2 // indirect
github.com/spf13/cast v1.5.0 // indirect
github.com/spf13/cast v1.4.1 // indirect
github.com/spf13/jwalterweatherman v1.1.0 // indirect
github.com/spf13/pflag v1.0.5 // indirect
github.com/subosito/gotenv v1.3.0 // indirect
github.com/x448/float16 v0.8.4 // indirect
github.com/subosito/gotenv v1.2.0 // indirect
github.com/xeipuuv/gojsonpointer v0.0.0-20180127040702-4e3ac2762d5f // indirect
github.com/xeipuuv/gojsonreference v0.0.0-20180127040603-bd5ef7bd5415 // indirect
github.com/xeipuuv/gojsonschema v1.2.0 // indirect
github.com/xo/terminfo v0.0.0-20210125001918-ca9a967f8778 // indirect
go4.org/intern v0.0.0-20211027215823-ae77deb06f29 // indirect
go4.org/mem v0.0.0-20210711025021-927187094b94 // indirect
golang.org/x/sys v0.0.0-20220728004956-3c1f35247d10 // indirect
go4.org/unsafe/assume-no-moving-gc v0.0.0-20211027215541-db492cf91b37 // indirect
golang.org/x/net v0.0.0-20220516155154-20f960328961 // indirect
golang.org/x/sys v0.0.0-20220513210249-45d2b4557a2a // indirect
golang.org/x/term v0.0.0-20220411215600-e5f449aeb171 // indirect
golang.org/x/text v0.3.7 // indirect
golang.org/x/time v0.0.0-20211116232009-f0f3c7e86c11 // indirect
golang.zx2c4.com/wireguard/windows v0.4.10 // indirect
google.golang.org/appengine v1.6.7 // indirect
gopkg.in/ini.v1 v1.66.4 // indirect
gopkg.in/square/go-jose.v2 v2.6.0 // indirect
modernc.org/libc v1.16.8 // indirect
gopkg.in/square/go-jose.v2 v2.5.1 // indirect
modernc.org/libc v1.14.12 // indirect
modernc.org/mathutil v1.4.1 // indirect
modernc.org/memory v1.1.1 // indirect
modernc.org/sqlite v1.17.3 // indirect
nhooyr.io/websocket v1.8.7 // indirect
modernc.org/memory v1.0.7 // indirect
modernc.org/sqlite v1.16.0 // indirect
)

601
go.sum

File diff suppressed because it is too large Load Diff

View File

@@ -159,7 +159,7 @@ func (api headscaleV1APIServer) RegisterMachine(
) (*v1.RegisterMachineResponse, error) {
log.Trace().
Str("namespace", request.GetNamespace()).
Str("node_key", request.GetKey()).
Str("machine_key", request.GetKey()).
Msg("Registering machine")
machine, err := api.h.RegisterMachineFromAuthCallback(
@@ -199,7 +199,7 @@ func (api headscaleV1APIServer) SetTags(
err := validateTag(tag)
if err != nil {
return &v1.SetTagsResponse{
Machine: nil,
Machine: nil,
}, status.Error(codes.InvalidArgument, err.Error())
}
}

View File

@@ -1,4 +1,5 @@
//go:build integration_cli
//go:build integration
// +build integration
package headscale
@@ -49,7 +50,7 @@ func (s *IntegrationCLITestSuite) SetupTest() {
}
headscaleBuildOptions := &dockertest.BuildOptions{
Dockerfile: "Dockerfile",
Dockerfile: "Dockerfile.tmp-integration",
ContextDir: ".",
}
@@ -69,31 +70,24 @@ func (s *IntegrationCLITestSuite) SetupTest() {
err = s.pool.RemoveContainerByName(headscaleHostname)
if err != nil {
s.FailNow(
fmt.Sprintf(
"Could not remove existing container before building test: %s",
err,
),
"",
)
s.FailNow(fmt.Sprintf("Could not remove existing container before building test: %s", err), "")
}
fmt.Println("Creating headscale container for CLI tests")
fmt.Println("Creating headscale container")
if pheadscale, err := s.pool.BuildAndRunWithBuildOptions(headscaleBuildOptions, headscaleOptions, DockerRestartPolicy); err == nil {
s.headscale = *pheadscale
} else {
s.FailNow(fmt.Sprintf("Could not start headscale container: %s", err), "")
}
fmt.Println("Created headscale container for CLI tests")
fmt.Println("Created headscale container")
fmt.Println("Waiting for headscale to be ready for CLI tests")
fmt.Println("Waiting for headscale to be ready")
hostEndpoint := fmt.Sprintf("localhost:%s", s.headscale.GetPort("8080/tcp"))
if err := s.pool.Retry(func() error {
url := fmt.Sprintf("http://%s/health", hostEndpoint)
resp, err := http.Get(url)
if err != nil {
fmt.Printf("headscale for CLI test is not ready: %s\n", err)
return err
}
if resp.StatusCode != http.StatusOK {
@@ -108,7 +102,7 @@ func (s *IntegrationCLITestSuite) SetupTest() {
// https://github.com/stretchr/testify/issues/849
return // fmt.Errorf("Could not connect to headscale: %s", err)
}
fmt.Println("headscale container is ready for CLI tests")
fmt.Println("headscale container is ready")
}
func (s *IntegrationCLITestSuite) TearDownTest() {
@@ -1739,8 +1733,6 @@ func (s *IntegrationCLITestSuite) TestLoadConfigFromCommand() {
assert.Nil(s.T(), err)
altConfig, err := os.ReadFile("integration_test/etc/alt-config.dump.gold.yaml")
assert.Nil(s.T(), err)
altEnvConfig, err := os.ReadFile("integration_test/etc/alt-env-config.dump.gold.yaml")
assert.Nil(s.T(), err)
_, err = ExecuteCommand(
&s.headscale,
@@ -1773,40 +1765,4 @@ func (s *IntegrationCLITestSuite) TestLoadConfigFromCommand() {
assert.Nil(s.T(), err)
assert.YAMLEq(s.T(), string(altConfig), string(altDumpConfig))
_, err = ExecuteCommand(
&s.headscale,
[]string{
"headscale",
"dumpConfig",
},
[]string{
"HEADSCALE_CONFIG=/etc/headscale/alt-env-config.yaml",
},
)
assert.Nil(s.T(), err)
altEnvDumpConfig, err := os.ReadFile("integration_test/etc/config.dump.yaml")
assert.Nil(s.T(), err)
assert.YAMLEq(s.T(), string(altEnvConfig), string(altEnvDumpConfig))
_, err = ExecuteCommand(
&s.headscale,
[]string{
"headscale",
"-c",
"/etc/headscale/alt-config.yaml",
"dumpConfig",
},
[]string{
"HEADSCALE_CONFIG=/etc/headscale/alt-env-config.yaml",
},
)
assert.Nil(s.T(), err)
altDumpConfig, err = os.ReadFile("integration_test/etc/config.dump.yaml")
assert.Nil(s.T(), err)
assert.YAMLEq(s.T(), string(altConfig), string(altDumpConfig))
}

View File

@@ -1,4 +1,5 @@
//go:build integration
// +build integration
package headscale
@@ -7,7 +8,6 @@ import (
"encoding/json"
"errors"
"fmt"
"net/netip"
"os"
"strconv"
"strings"
@@ -16,25 +16,23 @@ import (
v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
"github.com/ory/dockertest/v3"
"github.com/ory/dockertest/v3/docker"
"inet.af/netaddr"
)
const (
headscaleHostname = "headscale-derp"
DOCKER_EXECUTE_TIMEOUT = 10 * time.Second
)
var (
errEnvVarEmpty = errors.New("getenv: environment variable empty")
IpPrefix4 = netip.MustParsePrefix("100.64.0.0/10")
IpPrefix6 = netip.MustParsePrefix("fd7a:115c:a1e0::/48")
IpPrefix4 = netaddr.MustParseIPPrefix("100.64.0.0/10")
IpPrefix6 = netaddr.MustParseIPPrefix("fd7a:115c:a1e0::/48")
tailscaleVersions = []string{
// "head",
// "unstable",
"1.30.0",
"1.28.0",
"1.26.2",
"head",
"unstable",
"1.26.0",
"1.24.2",
"1.22.2",
"1.20.4",
@@ -195,8 +193,8 @@ func getDockerBuildOptions(version string) *dockertest.BuildOptions {
func getIPs(
tailscales map[string]dockertest.Resource,
) (map[string][]netip.Addr, error) {
ips := make(map[string][]netip.Addr)
) (map[string][]netaddr.IP, error) {
ips := make(map[string][]netaddr.IP)
for hostname, tailscale := range tailscales {
command := []string{"tailscale", "ip"}
@@ -214,7 +212,7 @@ func getIPs(
if len(address) < 1 {
continue
}
ip, err := netip.ParseAddr(address)
ip, err := netaddr.ParseIP(address)
if err != nil {
return nil, err
}
@@ -228,6 +226,7 @@ func getIPs(
func getDNSNames(
headscale *dockertest.Resource,
) ([]string, error) {
listAllResult, err := ExecuteCommand(
headscale,
[]string{
@@ -261,6 +260,7 @@ func getDNSNames(
func getMagicFQDN(
headscale *dockertest.Resource,
) ([]string, error) {
listAllResult, err := ExecuteCommand(
headscale,
[]string{
@@ -285,11 +285,7 @@ func getMagicFQDN(
hostnames := make([]string, len(listAll))
for index := range listAll {
hostnames[index] = fmt.Sprintf(
"%s.%s.headscale.net",
listAll[index].GetGivenName(),
listAll[index].GetNamespace().GetName(),
)
hostnames[index] = fmt.Sprintf("%s.%s.headscale.net", listAll[index].GetGivenName(), listAll[index].GetNamespace().GetName())
}
return hostnames, nil

View File

@@ -1,4 +1,4 @@
//go:build integration_derp
//go:build integration
package headscale
@@ -8,6 +8,7 @@ import (
"crypto/tls"
"encoding/json"
"fmt"
"io/ioutil"
"log"
"net/http"
"os"
@@ -27,8 +28,9 @@ import (
)
const (
namespaceName = "derpnamespace"
totalContainers = 3
headscaleHostname = "headscale-derp"
namespaceName = "derpnamespace"
totalContainers = 3
)
type IntegrationDERPTestSuite struct {
@@ -102,7 +104,7 @@ func (s *IntegrationDERPTestSuite) SetupSuite() {
}
headscaleBuildOptions := &dockertest.BuildOptions{
Dockerfile: "Dockerfile",
Dockerfile: "Dockerfile.tmp-integration",
ContextDir: ".",
}
@@ -129,24 +131,18 @@ func (s *IntegrationDERPTestSuite) SetupSuite() {
err = s.pool.RemoveContainerByName(headscaleHostname)
if err != nil {
s.FailNow(
fmt.Sprintf(
"Could not remove existing container before building test: %s",
err,
),
"",
)
s.FailNow(fmt.Sprintf("Could not remove existing container before building test: %s", err), "")
}
log.Println("Creating headscale container for DERP integration tests")
log.Println("Creating headscale container")
if pheadscale, err := s.pool.BuildAndRunWithBuildOptions(headscaleBuildOptions, headscaleOptions, DockerRestartPolicy); err == nil {
s.headscale = *pheadscale
} else {
s.FailNow(fmt.Sprintf("Could not start headscale container: %s", err), "")
}
log.Println("Created headscale container for embedded DERP tests")
log.Println("Created headscale container to test DERP")
log.Println("Creating tailscale containers for embedded DERP tests")
log.Println("Creating tailscale containers")
for i := 0; i < totalContainers; i++ {
version := tailscaleVersions[i%len(tailscaleVersions)]
@@ -158,7 +154,7 @@ func (s *IntegrationDERPTestSuite) SetupSuite() {
s.tailscales[hostname] = *container
}
log.Println("Waiting for headscale to be ready for embedded DERP tests")
log.Println("Waiting for headscale to be ready")
hostEndpoint := fmt.Sprintf("localhost:%s", s.headscale.GetPort("8443/tcp"))
if err := s.pool.Retry(func() error {
@@ -168,7 +164,6 @@ func (s *IntegrationDERPTestSuite) SetupSuite() {
client := &http.Client{Transport: insecureTransport}
resp, err := client.Get(url)
if err != nil {
fmt.Printf("headscale for embedded DERP tests is not ready: %s\n", err)
return err
}
@@ -184,7 +179,7 @@ func (s *IntegrationDERPTestSuite) SetupSuite() {
// https://github.com/stretchr/testify/issues/849
return // fmt.Errorf("Could not connect to headscale: %s", err)
}
log.Println("headscale container is ready for embedded DERP tests")
log.Println("headscale container is ready")
log.Printf("Creating headscale namespace: %s\n", namespaceName)
result, err := ExecuteCommand(
@@ -367,7 +362,7 @@ func (s *IntegrationDERPTestSuite) saveLog(
log.Printf("Saving logs for %s to %s\n", resource.Container.Name, basePath)
err = os.WriteFile(
err = ioutil.WriteFile(
path.Join(basePath, resource.Container.Name+".stdout.log"),
[]byte(stdout.String()),
0o644,
@@ -376,7 +371,7 @@ func (s *IntegrationDERPTestSuite) saveLog(
return err
}
err = os.WriteFile(
err = ioutil.WriteFile(
path.Join(basePath, resource.Container.Name+".stderr.log"),
[]byte(stdout.String()),
0o644,

View File

@@ -1,4 +1,5 @@
//go:build integration_general
//go:build integration
// +build integration
package headscale
@@ -8,9 +9,9 @@ import (
"encoding/json"
"errors"
"fmt"
"io/ioutil"
"log"
"net/http"
"net/netip"
"os"
"path"
"strings"
@@ -23,6 +24,7 @@ import (
"github.com/ory/dockertest/v3/docker"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/suite"
"inet.af/netaddr"
"tailscale.com/client/tailscale/apitype"
"tailscale.com/ipn/ipnstate"
)
@@ -51,7 +53,7 @@ func TestIntegrationTestSuite(t *testing.T) {
s.namespaces = map[string]TestNamespace{
"thisspace": {
count: 10,
count: 5,
tailscales: make(map[string]dockertest.Resource),
},
"otherspace": {
@@ -123,7 +125,7 @@ func (s *IntegrationTestSuite) saveLog(
log.Printf("Saving logs for %s to %s\n", resource.Container.Name, basePath)
err = os.WriteFile(
err = ioutil.WriteFile(
path.Join(basePath, resource.Container.Name+".stdout.log"),
[]byte(stdout.String()),
0o644,
@@ -132,7 +134,7 @@ func (s *IntegrationTestSuite) saveLog(
return err
}
err = os.WriteFile(
err = ioutil.WriteFile(
path.Join(basePath, resource.Container.Name+".stderr.log"),
[]byte(stdout.String()),
0o644,
@@ -226,7 +228,7 @@ func (s *IntegrationTestSuite) SetupSuite() {
}
headscaleBuildOptions := &dockertest.BuildOptions{
Dockerfile: "Dockerfile",
Dockerfile: "Dockerfile.tmp-integration",
ContextDir: ".",
}
@@ -246,24 +248,18 @@ func (s *IntegrationTestSuite) SetupSuite() {
err = s.pool.RemoveContainerByName(headscaleHostname)
if err != nil {
s.FailNow(
fmt.Sprintf(
"Could not remove existing container before building test: %s",
err,
),
"",
)
s.FailNow(fmt.Sprintf("Could not remove existing container before building test: %s", err), "")
}
log.Println("Creating headscale container for core integration tests")
log.Println("Creating headscale container")
if pheadscale, err := s.pool.BuildAndRunWithBuildOptions(headscaleBuildOptions, headscaleOptions, DockerRestartPolicy); err == nil {
s.headscale = *pheadscale
} else {
s.FailNow(fmt.Sprintf("Could not start headscale container for core integration tests: %s", err), "")
s.FailNow(fmt.Sprintf("Could not start headscale container: %s", err), "")
}
log.Println("Created headscale container for core integration tests")
log.Println("Created headscale container")
log.Println("Creating tailscale containers for core integration tests")
log.Println("Creating tailscale containers")
for namespace, scales := range s.namespaces {
for i := 0; i < scales.count; i++ {
version := tailscaleVersions[i%len(tailscaleVersions)]
@@ -277,7 +273,7 @@ func (s *IntegrationTestSuite) SetupSuite() {
}
}
log.Println("Waiting for headscale to be ready for core integration tests")
log.Println("Waiting for headscale to be ready")
hostEndpoint := fmt.Sprintf("localhost:%s", s.headscale.GetPort("8080/tcp"))
if err := s.pool.Retry(func() error {
@@ -285,7 +281,6 @@ func (s *IntegrationTestSuite) SetupSuite() {
resp, err := http.Get(url)
if err != nil {
fmt.Printf("headscale for core integration test is not ready: %s\n", err)
return err
}
@@ -301,7 +296,7 @@ func (s *IntegrationTestSuite) SetupSuite() {
// https://github.com/stretchr/testify/issues/849
return // fmt.Errorf("Could not connect to headscale: %s", err)
}
log.Println("headscale container is ready for core integration tests")
log.Println("headscale container is ready")
for namespace, scales := range s.namespaces {
log.Printf("Creating headscale namespace: %s\n", namespace)
@@ -477,8 +472,8 @@ func (s *IntegrationTestSuite) TestGetIpAddresses() {
// }
// }
func getIPsfromIPNstate(status ipnstate.Status) []netip.Addr {
ips := make([]netip.Addr, 0)
func getIPsfromIPNstate(status ipnstate.Status) []netaddr.IP {
ips := make([]netaddr.IP, 0)
for _, peer := range status.Peer {
ips = append(ips, peer.TailscaleIPs...)
@@ -562,25 +557,13 @@ func (s *IntegrationTestSuite) TestTailDrop() {
if peername == hostname {
continue
}
var ip4 netip.Addr
for _, ip := range ips[peername] {
if ip.Is4() {
ip4 = ip
break
}
}
if ip4.IsUnspecified() {
panic("no ipv4 address found")
}
s.T().Run(fmt.Sprintf("%s-%s", hostname, peername), func(t *testing.T) {
command := []string{
"tailscale", "file", "cp",
fmt.Sprintf("/tmp/file_from_%s", hostname),
fmt.Sprintf("%s:", ip4),
fmt.Sprintf("%s:", ips[peername][1]),
}
err := retry(10, 1*time.Second, func() error {
retry(10, 1*time.Second, func() error {
log.Printf(
"Sending file from %s to %s\n",
hostname,
@@ -594,7 +577,6 @@ func (s *IntegrationTestSuite) TestTailDrop() {
)
return err
})
assert.Nil(t, err)
})
}
@@ -696,18 +678,6 @@ func (s *IntegrationTestSuite) TestMagicDNS() {
ips, err := getIPs(scales.tailscales)
assert.Nil(s.T(), err)
retry := func(times int, sleepInverval time.Duration, doWork func() (string, error)) (result string, err error) {
for attempts := 0; attempts < times; attempts++ {
result, err = doWork()
if err == nil {
return
}
time.Sleep(sleepInverval)
}
return
}
for hostname, tailscale := range scales.tailscales {
for _, peername := range hostnames {
if strings.Contains(peername, hostname) {
@@ -718,20 +688,17 @@ func (s *IntegrationTestSuite) TestMagicDNS() {
command := []string{
"tailscale", "ip", peername,
}
result, err := retry(10, 1*time.Second, func() (string, error) {
log.Printf(
"Resolving name %s from %s\n",
peername,
hostname,
)
result, err := ExecuteCommand(
&tailscale,
command,
[]string{},
)
return result, err
})
log.Printf(
"Resolving name %s from %s\n",
peername,
hostname,
)
result, err := ExecuteCommand(
&tailscale,
command,
[]string{},
)
assert.Nil(t, err)
log.Printf("Result for %s: %s\n", hostname, result)
@@ -748,8 +715,8 @@ func (s *IntegrationTestSuite) TestMagicDNS() {
func getAPIURLs(
tailscales map[string]dockertest.Resource,
) (map[netip.Addr]string, error) {
fts := make(map[netip.Addr]string)
) (map[netaddr.IP]string, error) {
fts := make(map[netaddr.IP]string)
for _, tailscale := range tailscales {
command := []string{
"curl",
@@ -773,11 +740,11 @@ func getAPIURLs(
for _, ft := range pft {
n := ft.Node
for _, a := range n.Addresses { // just add all the addresses
if _, ok := fts[a.Addr()]; !ok {
if _, ok := fts[a.IP()]; !ok {
if ft.PeerAPIURL == "" {
return nil, errors.New("api url is empty")
}
fts[a.Addr()] = ft.PeerAPIURL
fts[a.IP()] = ft.PeerAPIURL
}
}
}

View File

@@ -18,7 +18,6 @@ dns_config:
domains: []
magic_dns: true
nameservers:
- 127.0.0.11
- 1.1.1.1
ephemeral_node_inactivity_timeout: 30m
node_update_check_interval: 10s
@@ -28,9 +27,7 @@ ip_prefixes:
- fd7a:115c:a1e0::/48
- 100.64.0.0/10
listen_addr: 0.0.0.0:18080
log:
level: disabled
format: text
log_level: disabled
logtail:
enabled: false
metrics_listen_addr: 127.0.0.1:19090
@@ -41,8 +38,6 @@ oidc:
- email
strip_email_domain: true
private_key_path: private.key
noise:
private_key_path: noise_private.key
server_url: http://headscale:18080
tls_client_auth_mode: relaxed
tls_letsencrypt_cache_dir: /var/www/.cache

View File

@@ -1,5 +1,4 @@
log:
level: trace
log_level: trace
acl_policy_path: ""
db_type: sqlite3
ephemeral_node_inactivity_timeout: 30m
@@ -12,12 +11,9 @@ dns_config:
magic_dns: true
domains: []
nameservers:
- 127.0.0.11
- 1.1.1.1
db_path: /tmp/integration_test_db.sqlite3
private_key_path: private.key
noise:
private_key_path: noise_private.key
listen_addr: 0.0.0.0:18080
metrics_listen_addr: 127.0.0.1:19090
server_url: http://headscale:18080

View File

@@ -1,51 +0,0 @@
acl_policy_path: ""
cli:
insecure: false
timeout: 5s
db_path: /tmp/integration_test_db.sqlite3
db_type: sqlite3
derp:
auto_update_enabled: false
server:
enabled: false
stun:
enabled: true
update_frequency: 1m
urls:
- https://controlplane.tailscale.com/derpmap/default
dns_config:
base_domain: headscale.net
domains: []
magic_dns: true
nameservers:
- 1.1.1.1
ephemeral_node_inactivity_timeout: 30m
node_update_check_interval: 30s
grpc_allow_insecure: false
grpc_listen_addr: :50443
ip_prefixes:
- fd7a:115c:a1e0::/48
- 100.64.0.0/10
listen_addr: 0.0.0.0:18080
log:
level: disabled
format: text
logtail:
enabled: false
metrics_listen_addr: 127.0.0.1:19090
oidc:
scope:
- openid
- profile
- email
strip_email_domain: true
private_key_path: private.key
noise:
private_key_path: noise_private.key
server_url: http://headscale:18080
tls_client_auth_mode: relaxed
tls_letsencrypt_cache_dir: /var/www/.cache
tls_letsencrypt_challenge_type: HTTP-01
unix_socket: /var/run/headscale.sock
unix_socket_permission: "0o770"
randomize_client_port: false

View File

@@ -1,28 +0,0 @@
log:
level: trace
acl_policy_path: ""
db_type: sqlite3
ephemeral_node_inactivity_timeout: 30m
node_update_check_interval: 30s
ip_prefixes:
- fd7a:115c:a1e0::/48
- 100.64.0.0/10
dns_config:
base_domain: headscale.net
magic_dns: true
domains: []
nameservers:
- 1.1.1.1
db_path: /tmp/integration_test_db.sqlite3
private_key_path: private.key
noise:
private_key_path: noise_private.key
listen_addr: 0.0.0.0:18080
metrics_listen_addr: 127.0.0.1:19090
server_url: http://headscale:18080
derp:
urls:
- https://controlplane.tailscale.com/derpmap/default
auto_update_enabled: false
update_frequency: 1m

View File

@@ -18,7 +18,6 @@ dns_config:
domains: []
magic_dns: true
nameservers:
- 127.0.0.11
- 1.1.1.1
ephemeral_node_inactivity_timeout: 30m
node_update_check_interval: 10s
@@ -28,9 +27,7 @@ ip_prefixes:
- fd7a:115c:a1e0::/48
- 100.64.0.0/10
listen_addr: 0.0.0.0:8080
log:
format: text
level: disabled
log_level: disabled
logtail:
enabled: false
metrics_listen_addr: 127.0.0.1:9090
@@ -41,8 +38,6 @@ oidc:
- email
strip_email_domain: true
private_key_path: private.key
noise:
private_key_path: noise_private.key
server_url: http://headscale:8080
tls_client_auth_mode: relaxed
tls_letsencrypt_cache_dir: /var/www/.cache

View File

@@ -1,5 +1,4 @@
log:
level: trace
log_level: trace
acl_policy_path: ""
db_type: sqlite3
ephemeral_node_inactivity_timeout: 30m
@@ -12,12 +11,9 @@ dns_config:
magic_dns: true
domains: []
nameservers:
- 127.0.0.11
- 1.1.1.1
db_path: /tmp/integration_test_db.sqlite3
private_key_path: private.key
noise:
private_key_path: noise_private.key
listen_addr: 0.0.0.0:8080
metrics_listen_addr: 127.0.0.1:9090
server_url: http://headscale:8080

View File

@@ -14,10 +14,8 @@ dns_config:
- 1.1.1.1
db_path: /tmp/integration_test_db.sqlite3
private_key_path: private.key
noise:
private_key_path: noise_private.key
listen_addr: 0.0.0.0:443
server_url: https://headscale:443
listen_addr: 0.0.0.0:8443
server_url: https://headscale:8443
tls_cert_path: "/etc/headscale/tls/server.crt"
tls_key_path: "/etc/headscale/tls/server.key"
tls_client_auth_mode: disabled

View File

@@ -4,7 +4,6 @@ import (
"database/sql/driver"
"errors"
"fmt"
"net/netip"
"sort"
"strconv"
"strings"
@@ -13,35 +12,28 @@ import (
v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
"github.com/rs/zerolog/log"
"google.golang.org/protobuf/types/known/timestamppb"
"inet.af/netaddr"
"tailscale.com/tailcfg"
"tailscale.com/types/key"
)
const (
ErrMachineNotFound = Error("machine not found")
ErrMachineRouteIsNotAvailable = Error("route is not available on machine")
ErrMachineAddressesInvalid = Error("failed to parse machine addresses")
ErrMachineNotFoundRegistrationCache = Error(
errMachineNotFound = Error("machine not found")
errMachineRouteIsNotAvailable = Error("route is not available on machine")
errMachineAddressesInvalid = Error("failed to parse machine addresses")
errMachineNotFoundRegistrationCache = Error(
"machine not found in registration cache",
)
ErrCouldNotConvertMachineInterface = Error("failed to convert machine interface")
ErrHostnameTooLong = Error("Hostname too long")
ErrDifferentRegisteredNamespace = Error(
"machine was previously registered with a different namespace",
)
MachineGivenNameHashLength = 8
MachineGivenNameTrimSize = 2
errCouldNotConvertMachineInterface = Error("failed to convert machine interface")
errHostnameTooLong = Error("Hostname too long")
MachineGivenNameHashLength = 8
MachineGivenNameTrimSize = 2
)
const (
maxHostnameLength = 255
)
var (
ExitRouteV4 = netip.MustParsePrefix("0.0.0.0/0")
ExitRouteV6 = netip.MustParsePrefix("::/0")
)
// Machine is a Headscale client.
type Machine struct {
ID uint64 `gorm:"primary_key"`
@@ -90,7 +82,7 @@ type (
MachinesP []*Machine
)
type MachineAddresses []netip.Addr
type MachineAddresses []netaddr.IP
func (ma MachineAddresses) ToStringSlice() []string {
strSlice := make([]string, 0, len(ma))
@@ -110,7 +102,7 @@ func (ma *MachineAddresses) Scan(destination interface{}) error {
if len(addr) < 1 {
continue
}
parsed, err := netip.ParseAddr(addr)
parsed, err := netaddr.ParseIP(addr)
if err != nil {
return err
}
@@ -120,7 +112,7 @@ func (ma *MachineAddresses) Scan(destination interface{}) error {
return nil
default:
return fmt.Errorf("%w: unexpected data type %T", ErrMachineAddressesInvalid, destination)
return fmt.Errorf("%w: unexpected data type %T", errMachineAddressesInvalid, destination)
}
}
@@ -252,8 +244,8 @@ func (h *Headscale) ListPeers(machine *Machine) (Machines, error) {
Msg("Finding direct peers")
machines := Machines{}
if err := h.db.Preload("AuthKey").Preload("AuthKey.Namespace").Preload("Namespace").Where("node_key <> ?",
machine.NodeKey).Find(&machines).Error; err != nil {
if err := h.db.Preload("AuthKey").Preload("AuthKey.Namespace").Preload("Namespace").Where("machine_key <> ?",
machine.MachineKey).Find(&machines).Error; err != nil {
log.Error().Err(err).Msg("Error accessing db")
return Machines{}, err
@@ -345,7 +337,7 @@ func (h *Headscale) GetMachine(namespace string, name string) (*Machine, error)
}
}
return nil, ErrMachineNotFound
return nil, errMachineNotFound
}
// GetMachineByID finds a Machine by ID and returns the Machine struct.
@@ -358,7 +350,7 @@ func (h *Headscale) GetMachineByID(id uint64) (*Machine, error) {
return &m, nil
}
// GetMachineByMachineKey finds a Machine by its MachineKey and returns the Machine struct.
// GetMachineByMachineKey finds a Machine by ID and returns the Machine struct.
func (h *Headscale) GetMachineByMachineKey(
machineKey key.MachinePublic,
) (*Machine, error) {
@@ -370,32 +362,6 @@ func (h *Headscale) GetMachineByMachineKey(
return &m, nil
}
// GetMachineByNodeKey finds a Machine by its current NodeKey.
func (h *Headscale) GetMachineByNodeKey(
nodeKey key.NodePublic,
) (*Machine, error) {
machine := Machine{}
if result := h.db.Preload("Namespace").First(&machine, "node_key = ?",
NodePublicKeyStripPrefix(nodeKey)); result.Error != nil {
return nil, result.Error
}
return &machine, nil
}
// GetMachineByAnyNodeKey finds a Machine by its current NodeKey or the old one, and returns the Machine struct.
func (h *Headscale) GetMachineByAnyNodeKey(
nodeKey key.NodePublic, oldNodeKey key.NodePublic,
) (*Machine, error) {
machine := Machine{}
if result := h.db.Preload("Namespace").First(&machine, "node_key = ? OR node_key = ?",
NodePublicKeyStripPrefix(nodeKey), NodePublicKeyStripPrefix(oldNodeKey)); result.Error != nil {
return nil, result.Error
}
return &machine, nil
}
// UpdateMachineFromDatabase takes a Machine struct pointer (typically already loaded from database
// and updates it with the latest data from the database.
func (h *Headscale) UpdateMachineFromDatabase(machine *Machine) error {
@@ -418,7 +384,7 @@ func (h *Headscale) SetTags(machine *Machine, tags []string) error {
if err := h.UpdateACLRules(); err != nil && !errors.Is(err, errEmptyPolicy) {
return err
}
h.setLastStateChangeToNow()
h.setLastStateChangeToNow(machine.Namespace.Name)
if err := h.db.Save(machine).Error; err != nil {
return fmt.Errorf("failed to update tags for machine in the database: %w", err)
@@ -432,7 +398,7 @@ func (h *Headscale) ExpireMachine(machine *Machine) error {
now := time.Now()
machine.Expiry = &now
h.setLastStateChangeToNow()
h.setLastStateChangeToNow(machine.Namespace.Name)
if err := h.db.Save(machine).Error; err != nil {
return fmt.Errorf("failed to expire machine in the database: %w", err)
@@ -459,7 +425,7 @@ func (h *Headscale) RenameMachine(machine *Machine, newName string) error {
}
machine.GivenName = newName
h.setLastStateChangeToNow()
h.setLastStateChangeToNow(machine.Namespace.Name)
if err := h.db.Save(machine).Error; err != nil {
return fmt.Errorf("failed to rename machine in the database: %w", err)
@@ -475,7 +441,7 @@ func (h *Headscale) RefreshMachine(machine *Machine, expiry time.Time) error {
machine.LastSuccessfulUpdate = &now
machine.Expiry = &expiry
h.setLastStateChangeToNow()
h.setLastStateChangeToNow(machine.Namespace.Name)
if err := h.db.Save(machine).Error; err != nil {
return fmt.Errorf(
@@ -608,14 +574,11 @@ func (machine Machine) toNode(
}
var machineKey key.MachinePublic
// MachineKey is only used in the legacy protocol
if machine.MachineKey != "" {
err = machineKey.UnmarshalText(
[]byte(MachinePublicKeyEnsurePrefix(machine.MachineKey)),
)
if err != nil {
return nil, fmt.Errorf("failed to parse machine public key: %w", err)
}
err = machineKey.UnmarshalText(
[]byte(MachinePublicKeyEnsurePrefix(machine.MachineKey)),
)
if err != nil {
return nil, fmt.Errorf("failed to parse machine public key: %w", err)
}
var discoKey key.DiscoPublic
@@ -630,32 +593,20 @@ func (machine Machine) toNode(
discoKey = key.DiscoPublic{}
}
addrs := []netip.Prefix{}
addrs := []netaddr.IPPrefix{}
for _, machineAddress := range machine.IPAddresses {
ip := netip.PrefixFrom(machineAddress, machineAddress.BitLen())
ip := netaddr.IPPrefixFrom(machineAddress, machineAddress.BitLen())
addrs = append(addrs, ip)
}
allowedIPs := append(
[]netip.Prefix{},
[]netaddr.IPPrefix{},
addrs...) // we append the node own IP, as it is required by the clients
allowedIPs = append(allowedIPs, machine.EnabledRoutes...)
// TODO(kradalby): This is kind of a hack where we say that
// all the announced routes (except exit), is presented as primary
// routes. This might be problematic if two nodes expose the same route.
// This was added to address an issue where subnet routers stopped working
// when we only populated AllowedIPs.
primaryRoutes := []netip.Prefix{}
if len(machine.EnabledRoutes) > 0 {
for _, route := range machine.EnabledRoutes {
if route == ExitRouteV4 || route == ExitRouteV6 {
continue
}
primaryRoutes = append(primaryRoutes, route)
}
// TODO(kradalby): Needs investigation, We probably dont need this condition
// now that we dont have shared nodes
if includeRoutes {
allowedIPs = append(allowedIPs, machine.EnabledRoutes...)
}
var derp string
@@ -684,7 +635,7 @@ func (machine Machine) toNode(
return nil, fmt.Errorf(
"hostname %q is too long it cannot except 255 ASCII chars: %w",
hostname,
ErrHostnameTooLong,
errHostnameTooLong,
)
}
} else {
@@ -702,17 +653,16 @@ func (machine Machine) toNode(
StableID: tailcfg.StableNodeID(
strconv.FormatUint(machine.ID, Base10),
), // in headscale, unlike tailcontrol server, IDs are permanent
Name: hostname,
User: tailcfg.UserID(machine.NamespaceID),
Key: nodeKey,
KeyExpiry: keyExpiry,
Machine: machineKey,
DiscoKey: discoKey,
Addresses: addrs,
AllowedIPs: allowedIPs,
PrimaryRoutes: primaryRoutes,
Endpoints: machine.Endpoints,
DERP: derp,
Name: hostname,
User: tailcfg.UserID(machine.NamespaceID),
Key: nodeKey,
KeyExpiry: keyExpiry,
Machine: machineKey,
DiscoKey: discoKey,
Addresses: addrs,
AllowedIPs: allowedIPs,
Endpoints: machine.Endpoints,
DERP: derp,
Online: &online,
Hostinfo: hostInfo.View(),
@@ -812,11 +762,11 @@ func getTags(
}
func (h *Headscale) RegisterMachineFromAuthCallback(
nodeKeyStr string,
machineKeyStr string,
namespaceName string,
registrationMethod string,
) (*Machine, error) {
if machineInterface, ok := h.registrationCache.Get(nodeKeyStr); ok {
if machineInterface, ok := h.registrationCache.Get(machineKeyStr); ok {
if registrationMachine, ok := machineInterface.(Machine); ok {
namespace, err := h.GetNamespace(namespaceName)
if err != nil {
@@ -826,12 +776,6 @@ func (h *Headscale) RegisterMachineFromAuthCallback(
)
}
// Registration of expired machine with different namespace
if registrationMachine.ID != 0 &&
registrationMachine.NamespaceID != namespace.ID {
return nil, ErrDifferentRegisteredNamespace
}
registrationMachine.NamespaceID = namespace.ID
registrationMachine.RegisterMethod = registrationMethod
@@ -839,17 +783,13 @@ func (h *Headscale) RegisterMachineFromAuthCallback(
registrationMachine,
)
if err == nil {
h.registrationCache.Delete(nodeKeyStr)
}
return machine, err
} else {
return nil, ErrCouldNotConvertMachineInterface
return nil, errCouldNotConvertMachineInterface
}
}
return nil, ErrMachineNotFoundRegistrationCache
return nil, errMachineNotFoundRegistrationCache
}
// RegisterMachine is executed from the CLI to register a new Machine using its MachineKey.
@@ -894,16 +834,16 @@ func (h *Headscale) RegisterMachine(machine Machine,
return &machine, nil
}
func (machine *Machine) GetAdvertisedRoutes() []netip.Prefix {
func (machine *Machine) GetAdvertisedRoutes() []netaddr.IPPrefix {
return machine.HostInfo.RoutableIPs
}
func (machine *Machine) GetEnabledRoutes() []netip.Prefix {
func (machine *Machine) GetEnabledRoutes() []netaddr.IPPrefix {
return machine.EnabledRoutes
}
func (machine *Machine) IsRoutesEnabled(routeStr string) bool {
route, err := netip.ParsePrefix(routeStr)
route, err := netaddr.ParseIPPrefix(routeStr)
if err != nil {
return false
}
@@ -922,9 +862,9 @@ func (machine *Machine) IsRoutesEnabled(routeStr string) bool {
// EnableNodeRoute enables new routes based on a list of new routes. It will _replace_ the
// previous list of routes.
func (h *Headscale) EnableRoutes(machine *Machine, routeStrs ...string) error {
newRoutes := make([]netip.Prefix, len(routeStrs))
newRoutes := make([]netaddr.IPPrefix, len(routeStrs))
for index, routeStr := range routeStrs {
route, err := netip.ParsePrefix(routeStr)
route, err := netaddr.ParseIPPrefix(routeStr)
if err != nil {
return err
}
@@ -937,7 +877,7 @@ func (h *Headscale) EnableRoutes(machine *Machine, routeStrs ...string) error {
return fmt.Errorf(
"route (%s) is not available on node %s: %w",
machine.Hostname,
newRoute, ErrMachineRouteIsNotAvailable,
newRoute, errMachineRouteIsNotAvailable,
)
}
}

View File

@@ -2,7 +2,6 @@ package headscale
import (
"fmt"
"net/netip"
"reflect"
"strconv"
"strings"
@@ -10,8 +9,8 @@ import (
"time"
"gopkg.in/check.v1"
"inet.af/netaddr"
"tailscale.com/tailcfg"
"tailscale.com/types/key"
)
func (s *Suite) TestGetMachine(c *check.C) {
@@ -66,63 +65,6 @@ func (s *Suite) TestGetMachineByID(c *check.C) {
c.Assert(err, check.IsNil)
}
func (s *Suite) TestGetMachineByNodeKey(c *check.C) {
namespace, err := app.CreateNamespace("test")
c.Assert(err, check.IsNil)
pak, err := app.CreatePreAuthKey(namespace.Name, false, false, nil)
c.Assert(err, check.IsNil)
_, err = app.GetMachineByID(0)
c.Assert(err, check.NotNil)
nodeKey := key.NewNode()
machine := Machine{
ID: 0,
MachineKey: "foo",
NodeKey: NodePublicKeyStripPrefix(nodeKey.Public()),
DiscoKey: "faa",
Hostname: "testmachine",
NamespaceID: namespace.ID,
RegisterMethod: RegisterMethodAuthKey,
AuthKeyID: uint(pak.ID),
}
app.db.Save(&machine)
_, err = app.GetMachineByNodeKey(nodeKey.Public())
c.Assert(err, check.IsNil)
}
func (s *Suite) TestGetMachineByAnyNodeKey(c *check.C) {
namespace, err := app.CreateNamespace("test")
c.Assert(err, check.IsNil)
pak, err := app.CreatePreAuthKey(namespace.Name, false, false, nil)
c.Assert(err, check.IsNil)
_, err = app.GetMachineByID(0)
c.Assert(err, check.NotNil)
nodeKey := key.NewNode()
oldNodeKey := key.NewNode()
machine := Machine{
ID: 0,
MachineKey: "foo",
NodeKey: NodePublicKeyStripPrefix(nodeKey.Public()),
DiscoKey: "faa",
Hostname: "testmachine",
NamespaceID: namespace.ID,
RegisterMethod: RegisterMethodAuthKey,
AuthKeyID: uint(pak.ID),
}
app.db.Save(&machine)
_, err = app.GetMachineByAnyNodeKey(nodeKey.Public(), oldNodeKey.Public())
c.Assert(err, check.IsNil)
}
func (s *Suite) TestDeleteMachine(c *check.C) {
namespace, err := app.CreateNamespace("test")
c.Assert(err, check.IsNil)
@@ -229,7 +171,7 @@ func (s *Suite) TestGetACLFilteredPeers(c *check.C) {
NodeKey: "bar" + strconv.Itoa(index),
DiscoKey: "faa" + strconv.Itoa(index),
IPAddresses: MachineAddresses{
netip.MustParseAddr(fmt.Sprintf("100.64.0.%v", strconv.Itoa(index+1))),
netaddr.MustParseIP(fmt.Sprintf("100.64.0.%v", strconv.Itoa(index+1))),
},
Hostname: "testmachine" + strconv.Itoa(index),
NamespaceID: stor[index%2].namespace.ID,
@@ -243,19 +185,11 @@ func (s *Suite) TestGetACLFilteredPeers(c *check.C) {
Groups: map[string][]string{
"group:test": {"admin"},
},
Hosts: map[string]netip.Prefix{},
Hosts: map[string]netaddr.IPPrefix{},
TagOwners: map[string][]string{},
ACLs: []ACL{
{
Action: "accept",
Sources: []string{"admin"},
Destinations: []string{"*:*"},
},
{
Action: "accept",
Sources: []string{"test"},
Destinations: []string{"test:*"},
},
{Action: "accept", Sources: []string{"admin"}, Destinations: []string{"*:*"}},
{Action: "accept", Sources: []string{"test"}, Destinations: []string{"test:*"}},
},
Tests: []ACLTest{},
}
@@ -326,9 +260,9 @@ func (s *Suite) TestExpireMachine(c *check.C) {
}
func (s *Suite) TestSerdeAddressStrignSlice(c *check.C) {
input := MachineAddresses([]netip.Addr{
netip.MustParseAddr("192.0.2.1"),
netip.MustParseAddr("2001:db8::1"),
input := MachineAddresses([]netaddr.IP{
netaddr.MustParseIP("192.0.2.1"),
netaddr.MustParseIP("2001:db8::1"),
})
serialized, err := input.Value()
c.Assert(err, check.IsNil)
@@ -540,6 +474,7 @@ func Test_getTags(t *testing.T) {
}
}
// nolint
func Test_getFilteredByACLPeers(t *testing.T) {
type args struct {
machines []Machine
@@ -558,21 +493,21 @@ func Test_getFilteredByACLPeers(t *testing.T) {
{
ID: 1,
IPAddresses: MachineAddresses{
netip.MustParseAddr("100.64.0.1"),
netaddr.MustParseIP("100.64.0.1"),
},
Namespace: Namespace{Name: "joe"},
},
{
ID: 2,
IPAddresses: MachineAddresses{
netip.MustParseAddr("100.64.0.2"),
netaddr.MustParseIP("100.64.0.2"),
},
Namespace: Namespace{Name: "marc"},
},
{
ID: 3,
IPAddresses: MachineAddresses{
netip.MustParseAddr("100.64.0.3"),
netaddr.MustParseIP("100.64.0.3"),
},
Namespace: Namespace{Name: "mickael"},
},
@@ -587,19 +522,19 @@ func Test_getFilteredByACLPeers(t *testing.T) {
},
machine: &Machine{ // current machine
ID: 1,
IPAddresses: MachineAddresses{netip.MustParseAddr("100.64.0.1")},
IPAddresses: MachineAddresses{netaddr.MustParseIP("100.64.0.1")},
Namespace: Namespace{Name: "joe"},
},
},
want: Machines{
{
ID: 2,
IPAddresses: MachineAddresses{netip.MustParseAddr("100.64.0.2")},
IPAddresses: MachineAddresses{netaddr.MustParseIP("100.64.0.2")},
Namespace: Namespace{Name: "marc"},
},
{
ID: 3,
IPAddresses: MachineAddresses{netip.MustParseAddr("100.64.0.3")},
IPAddresses: MachineAddresses{netaddr.MustParseIP("100.64.0.3")},
Namespace: Namespace{Name: "mickael"},
},
},
@@ -611,21 +546,21 @@ func Test_getFilteredByACLPeers(t *testing.T) {
{
ID: 1,
IPAddresses: MachineAddresses{
netip.MustParseAddr("100.64.0.1"),
netaddr.MustParseIP("100.64.0.1"),
},
Namespace: Namespace{Name: "joe"},
},
{
ID: 2,
IPAddresses: MachineAddresses{
netip.MustParseAddr("100.64.0.2"),
netaddr.MustParseIP("100.64.0.2"),
},
Namespace: Namespace{Name: "marc"},
},
{
ID: 3,
IPAddresses: MachineAddresses{
netip.MustParseAddr("100.64.0.3"),
netaddr.MustParseIP("100.64.0.3"),
},
Namespace: Namespace{Name: "mickael"},
},
@@ -640,14 +575,14 @@ func Test_getFilteredByACLPeers(t *testing.T) {
},
machine: &Machine{ // current machine
ID: 1,
IPAddresses: MachineAddresses{netip.MustParseAddr("100.64.0.1")},
IPAddresses: MachineAddresses{netaddr.MustParseIP("100.64.0.1")},
Namespace: Namespace{Name: "joe"},
},
},
want: Machines{
{
ID: 2,
IPAddresses: MachineAddresses{netip.MustParseAddr("100.64.0.2")},
IPAddresses: MachineAddresses{netaddr.MustParseIP("100.64.0.2")},
Namespace: Namespace{Name: "marc"},
},
},
@@ -659,21 +594,21 @@ func Test_getFilteredByACLPeers(t *testing.T) {
{
ID: 1,
IPAddresses: MachineAddresses{
netip.MustParseAddr("100.64.0.1"),
netaddr.MustParseIP("100.64.0.1"),
},
Namespace: Namespace{Name: "joe"},
},
{
ID: 2,
IPAddresses: MachineAddresses{
netip.MustParseAddr("100.64.0.2"),
netaddr.MustParseIP("100.64.0.2"),
},
Namespace: Namespace{Name: "marc"},
},
{
ID: 3,
IPAddresses: MachineAddresses{
netip.MustParseAddr("100.64.0.3"),
netaddr.MustParseIP("100.64.0.3"),
},
Namespace: Namespace{Name: "mickael"},
},
@@ -688,14 +623,14 @@ func Test_getFilteredByACLPeers(t *testing.T) {
},
machine: &Machine{ // current machine
ID: 2,
IPAddresses: MachineAddresses{netip.MustParseAddr("100.64.0.2")},
IPAddresses: MachineAddresses{netaddr.MustParseIP("100.64.0.2")},
Namespace: Namespace{Name: "marc"},
},
},
want: Machines{
{
ID: 3,
IPAddresses: MachineAddresses{netip.MustParseAddr("100.64.0.3")},
IPAddresses: MachineAddresses{netaddr.MustParseIP("100.64.0.3")},
Namespace: Namespace{Name: "mickael"},
},
},
@@ -707,21 +642,21 @@ func Test_getFilteredByACLPeers(t *testing.T) {
{
ID: 1,
IPAddresses: MachineAddresses{
netip.MustParseAddr("100.64.0.1"),
netaddr.MustParseIP("100.64.0.1"),
},
Namespace: Namespace{Name: "joe"},
},
{
ID: 2,
IPAddresses: MachineAddresses{
netip.MustParseAddr("100.64.0.2"),
netaddr.MustParseIP("100.64.0.2"),
},
Namespace: Namespace{Name: "marc"},
},
{
ID: 3,
IPAddresses: MachineAddresses{
netip.MustParseAddr("100.64.0.3"),
netaddr.MustParseIP("100.64.0.3"),
},
Namespace: Namespace{Name: "mickael"},
},
@@ -737,7 +672,7 @@ func Test_getFilteredByACLPeers(t *testing.T) {
machine: &Machine{ // current machine
ID: 1,
IPAddresses: MachineAddresses{
netip.MustParseAddr("100.64.0.1"),
netaddr.MustParseIP("100.64.0.1"),
},
Namespace: Namespace{Name: "joe"},
},
@@ -746,7 +681,7 @@ func Test_getFilteredByACLPeers(t *testing.T) {
{
ID: 2,
IPAddresses: MachineAddresses{
netip.MustParseAddr("100.64.0.2"),
netaddr.MustParseIP("100.64.0.2"),
},
Namespace: Namespace{Name: "marc"},
},
@@ -759,21 +694,21 @@ func Test_getFilteredByACLPeers(t *testing.T) {
{
ID: 1,
IPAddresses: MachineAddresses{
netip.MustParseAddr("100.64.0.1"),
netaddr.MustParseIP("100.64.0.1"),
},
Namespace: Namespace{Name: "joe"},
},
{
ID: 2,
IPAddresses: MachineAddresses{
netip.MustParseAddr("100.64.0.2"),
netaddr.MustParseIP("100.64.0.2"),
},
Namespace: Namespace{Name: "marc"},
},
{
ID: 3,
IPAddresses: MachineAddresses{
netip.MustParseAddr("100.64.0.3"),
netaddr.MustParseIP("100.64.0.3"),
},
Namespace: Namespace{Name: "mickael"},
},
@@ -789,7 +724,7 @@ func Test_getFilteredByACLPeers(t *testing.T) {
machine: &Machine{ // current machine
ID: 2,
IPAddresses: MachineAddresses{
netip.MustParseAddr("100.64.0.2"),
netaddr.MustParseIP("100.64.0.2"),
},
Namespace: Namespace{Name: "marc"},
},
@@ -798,14 +733,14 @@ func Test_getFilteredByACLPeers(t *testing.T) {
{
ID: 1,
IPAddresses: MachineAddresses{
netip.MustParseAddr("100.64.0.1"),
netaddr.MustParseIP("100.64.0.1"),
},
Namespace: Namespace{Name: "joe"},
},
{
ID: 3,
IPAddresses: MachineAddresses{
netip.MustParseAddr("100.64.0.3"),
netaddr.MustParseIP("100.64.0.3"),
},
Namespace: Namespace{Name: "mickael"},
},
@@ -818,21 +753,21 @@ func Test_getFilteredByACLPeers(t *testing.T) {
{
ID: 1,
IPAddresses: MachineAddresses{
netip.MustParseAddr("100.64.0.1"),
netaddr.MustParseIP("100.64.0.1"),
},
Namespace: Namespace{Name: "joe"},
},
{
ID: 2,
IPAddresses: MachineAddresses{
netip.MustParseAddr("100.64.0.2"),
netaddr.MustParseIP("100.64.0.2"),
},
Namespace: Namespace{Name: "marc"},
},
{
ID: 3,
IPAddresses: MachineAddresses{
netip.MustParseAddr("100.64.0.3"),
netaddr.MustParseIP("100.64.0.3"),
},
Namespace: Namespace{Name: "mickael"},
},
@@ -847,7 +782,7 @@ func Test_getFilteredByACLPeers(t *testing.T) {
},
machine: &Machine{ // current machine
ID: 2,
IPAddresses: MachineAddresses{netip.MustParseAddr("100.64.0.2")},
IPAddresses: MachineAddresses{netaddr.MustParseIP("100.64.0.2")},
Namespace: Namespace{Name: "marc"},
},
},
@@ -855,13 +790,13 @@ func Test_getFilteredByACLPeers(t *testing.T) {
{
ID: 1,
IPAddresses: MachineAddresses{
netip.MustParseAddr("100.64.0.1"),
netaddr.MustParseIP("100.64.0.1"),
},
Namespace: Namespace{Name: "joe"},
},
{
ID: 3,
IPAddresses: MachineAddresses{netip.MustParseAddr("100.64.0.3")},
IPAddresses: MachineAddresses{netaddr.MustParseIP("100.64.0.3")},
Namespace: Namespace{Name: "mickael"},
},
},
@@ -873,21 +808,21 @@ func Test_getFilteredByACLPeers(t *testing.T) {
{
ID: 1,
IPAddresses: MachineAddresses{
netip.MustParseAddr("100.64.0.1"),
netaddr.MustParseIP("100.64.0.1"),
},
Namespace: Namespace{Name: "joe"},
},
{
ID: 2,
IPAddresses: MachineAddresses{
netip.MustParseAddr("100.64.0.2"),
netaddr.MustParseIP("100.64.0.2"),
},
Namespace: Namespace{Name: "marc"},
},
{
ID: 3,
IPAddresses: MachineAddresses{
netip.MustParseAddr("100.64.0.3"),
netaddr.MustParseIP("100.64.0.3"),
},
Namespace: Namespace{Name: "mickael"},
},
@@ -896,7 +831,7 @@ func Test_getFilteredByACLPeers(t *testing.T) {
},
machine: &Machine{ // current machine
ID: 2,
IPAddresses: MachineAddresses{netip.MustParseAddr("100.64.0.2")},
IPAddresses: MachineAddresses{netaddr.MustParseIP("100.64.0.2")},
Namespace: Namespace{Name: "marc"},
},
},

View File

@@ -16,10 +16,10 @@ import (
)
const (
ErrNamespaceExists = Error("Namespace already exists")
ErrNamespaceNotFound = Error("Namespace not found")
ErrNamespaceNotEmptyOfNodes = Error("Namespace not empty: node(s) found")
ErrInvalidNamespaceName = Error("Invalid namespace name")
errNamespaceExists = Error("Namespace already exists")
errNamespaceNotFound = Error("Namespace not found")
errNamespaceNotEmptyOfNodes = Error("Namespace not empty: node(s) found")
errInvalidNamespaceName = Error("Invalid namespace name")
)
const (
@@ -47,7 +47,7 @@ func (h *Headscale) CreateNamespace(name string) (*Namespace, error) {
}
namespace := Namespace{}
if err := h.db.Where("name = ?", name).First(&namespace).Error; err == nil {
return nil, ErrNamespaceExists
return nil, errNamespaceExists
}
namespace.Name = name
if err := h.db.Create(&namespace).Error; err != nil {
@@ -67,7 +67,7 @@ func (h *Headscale) CreateNamespace(name string) (*Namespace, error) {
func (h *Headscale) DestroyNamespace(name string) error {
namespace, err := h.GetNamespace(name)
if err != nil {
return ErrNamespaceNotFound
return errNamespaceNotFound
}
machines, err := h.ListMachinesInNamespace(name)
@@ -75,7 +75,7 @@ func (h *Headscale) DestroyNamespace(name string) error {
return err
}
if len(machines) > 0 {
return ErrNamespaceNotEmptyOfNodes
return errNamespaceNotEmptyOfNodes
}
keys, err := h.ListPreAuthKeys(name)
@@ -110,9 +110,9 @@ func (h *Headscale) RenameNamespace(oldName, newName string) error {
}
_, err = h.GetNamespace(newName)
if err == nil {
return ErrNamespaceExists
return errNamespaceExists
}
if !errors.Is(err, ErrNamespaceNotFound) {
if !errors.Is(err, errNamespaceNotFound) {
return err
}
@@ -132,7 +132,7 @@ func (h *Headscale) GetNamespace(name string) (*Namespace, error) {
result.Error,
gorm.ErrRecordNotFound,
) {
return nil, ErrNamespaceNotFound
return nil, errNamespaceNotFound
}
return &namespace, nil
@@ -272,7 +272,7 @@ func NormalizeToFQDNRules(name string, stripEmailDomain bool) (string, error) {
return "", fmt.Errorf(
"label %v is more than 63 chars: %w",
elt,
ErrInvalidNamespaceName,
errInvalidNamespaceName,
)
}
}
@@ -285,21 +285,21 @@ func CheckForFQDNRules(name string) error {
return fmt.Errorf(
"DNS segment must not be over 63 chars. %v doesn't comply with this rule: %w",
name,
ErrInvalidNamespaceName,
errInvalidNamespaceName,
)
}
if strings.ToLower(name) != name {
return fmt.Errorf(
"DNS segment should be lowercase. %v doesn't comply with this rule: %w",
name,
ErrInvalidNamespaceName,
errInvalidNamespaceName,
)
}
if invalidCharsInNamespaceRegex.MatchString(name) {
return fmt.Errorf(
"DNS segment should only be composed of lowercase ASCII letters numbers, hyphen and dots. %v doesn't comply with theses rules: %w",
name,
ErrInvalidNamespaceName,
errInvalidNamespaceName,
)
}

View File

@@ -1,11 +1,11 @@
package headscale
import (
"net/netip"
"testing"
"gopkg.in/check.v1"
"gorm.io/gorm"
"inet.af/netaddr"
)
func (s *Suite) TestCreateAndDestroyNamespace(c *check.C) {
@@ -26,7 +26,7 @@ func (s *Suite) TestCreateAndDestroyNamespace(c *check.C) {
func (s *Suite) TestDestroyNamespaceErrors(c *check.C) {
err := app.DestroyNamespace("test")
c.Assert(err, check.Equals, ErrNamespaceNotFound)
c.Assert(err, check.Equals, errNamespaceNotFound)
namespace, err := app.CreateNamespace("test")
c.Assert(err, check.IsNil)
@@ -60,7 +60,7 @@ func (s *Suite) TestDestroyNamespaceErrors(c *check.C) {
app.db.Save(&machine)
err = app.DestroyNamespace("test")
c.Assert(err, check.Equals, ErrNamespaceNotEmptyOfNodes)
c.Assert(err, check.Equals, errNamespaceNotEmptyOfNodes)
}
func (s *Suite) TestRenameNamespace(c *check.C) {
@@ -76,20 +76,20 @@ func (s *Suite) TestRenameNamespace(c *check.C) {
c.Assert(err, check.IsNil)
_, err = app.GetNamespace("test")
c.Assert(err, check.Equals, ErrNamespaceNotFound)
c.Assert(err, check.Equals, errNamespaceNotFound)
_, err = app.GetNamespace("test-renamed")
c.Assert(err, check.IsNil)
err = app.RenameNamespace("test-does-not-exit", "test")
c.Assert(err, check.Equals, ErrNamespaceNotFound)
c.Assert(err, check.Equals, errNamespaceNotFound)
namespaceTest2, err := app.CreateNamespace("test2")
c.Assert(err, check.IsNil)
c.Assert(namespaceTest2.Name, check.Equals, "test2")
err = app.RenameNamespace("test2", "test-renamed")
c.Assert(err, check.Equals, ErrNamespaceExists)
c.Assert(err, check.Equals, errNamespaceExists)
}
func (s *Suite) TestGetMapResponseUserProfiles(c *check.C) {
@@ -146,7 +146,7 @@ func (s *Suite) TestGetMapResponseUserProfiles(c *check.C) {
NamespaceID: namespaceShared1.ID,
Namespace: *namespaceShared1,
RegisterMethod: RegisterMethodAuthKey,
IPAddresses: []netip.Addr{netip.MustParseAddr("100.64.0.1")},
IPAddresses: []netaddr.IP{netaddr.MustParseIP("100.64.0.1")},
AuthKeyID: uint(preAuthKeyShared1.ID),
}
app.db.Save(machineInShared1)
@@ -163,7 +163,7 @@ func (s *Suite) TestGetMapResponseUserProfiles(c *check.C) {
NamespaceID: namespaceShared2.ID,
Namespace: *namespaceShared2,
RegisterMethod: RegisterMethodAuthKey,
IPAddresses: []netip.Addr{netip.MustParseAddr("100.64.0.2")},
IPAddresses: []netaddr.IP{netaddr.MustParseIP("100.64.0.2")},
AuthKeyID: uint(preAuthKeyShared2.ID),
}
app.db.Save(machineInShared2)
@@ -180,7 +180,7 @@ func (s *Suite) TestGetMapResponseUserProfiles(c *check.C) {
NamespaceID: namespaceShared3.ID,
Namespace: *namespaceShared3,
RegisterMethod: RegisterMethodAuthKey,
IPAddresses: []netip.Addr{netip.MustParseAddr("100.64.0.3")},
IPAddresses: []netaddr.IP{netaddr.MustParseIP("100.64.0.3")},
AuthKeyID: uint(preAuthKeyShared3.ID),
}
app.db.Save(machineInShared3)
@@ -197,7 +197,7 @@ func (s *Suite) TestGetMapResponseUserProfiles(c *check.C) {
NamespaceID: namespaceShared1.ID,
Namespace: *namespaceShared1,
RegisterMethod: RegisterMethodAuthKey,
IPAddresses: []netip.Addr{netip.MustParseAddr("100.64.0.4")},
IPAddresses: []netaddr.IP{netaddr.MustParseIP("100.64.0.4")},
AuthKeyID: uint(preAuthKey2Shared1.ID),
}
app.db.Save(machine2InShared1)
@@ -402,7 +402,7 @@ func (s *Suite) TestSetMachineNamespace(c *check.C) {
c.Assert(machine.Namespace.Name, check.Equals, newNamespace.Name)
err = app.SetMachineNamespace(&machine, "non-existing-namespace")
c.Assert(err, check.Equals, ErrNamespaceNotFound)
c.Assert(err, check.Equals, errNamespaceNotFound)
err = app.SetMachineNamespace(&machine, newNamespace.Name)
c.Assert(err, check.IsNil)

View File

@@ -1,55 +0,0 @@
package headscale
import (
"net/http"
"github.com/rs/zerolog/log"
"golang.org/x/net/http2"
"golang.org/x/net/http2/h2c"
"tailscale.com/control/controlhttp"
"tailscale.com/net/netutil"
)
const (
// ts2021UpgradePath is the path that the server listens on for the WebSockets upgrade.
ts2021UpgradePath = "/ts2021"
)
// NoiseUpgradeHandler is to upgrade the connection and hijack the net.Conn
// in order to use the Noise-based TS2021 protocol. Listens in /ts2021.
func (h *Headscale) NoiseUpgradeHandler(
writer http.ResponseWriter,
req *http.Request,
) {
log.Trace().Caller().Msgf("Noise upgrade handler for client %s", req.RemoteAddr)
upgrade := req.Header.Get("Upgrade")
if upgrade == "" {
// This probably means that the user is running Headscale behind an
// improperly configured reverse proxy. TS2021 requires WebSockets to
// be passed to Headscale. Let's give them a hint.
log.Warn().
Caller().
Msg("No Upgrade header in TS2021 request. If headscale is behind a reverse proxy, make sure it is configured to pass WebSockets through.")
http.Error(writer, "Internal error", http.StatusInternalServerError)
return
}
noiseConn, err := controlhttp.AcceptHTTP(req.Context(), writer, req, *h.noisePrivateKey)
if err != nil {
log.Error().Err(err).Msg("noise upgrade failed")
http.Error(writer, err.Error(), http.StatusInternalServerError)
return
}
server := http.Server{
ReadTimeout: HTTPReadTimeout,
}
server.Handler = h2c.NewHandler(h.noiseMux, &http2.Server{})
err = server.Serve(netutil.NewOneConnListener(noiseConn, nil))
if err != nil {
log.Info().Err(err).Msg("The HTTP2 server was closed")
}
}

375
oidc.go
View File

@@ -21,13 +21,6 @@ import (
const (
randomByteSize = 16
errEmptyOIDCCallbackParams = Error("empty OIDC callback params")
errNoOIDCIDToken = Error("could not extract ID Token for OIDC callback")
errOIDCAllowedDomains = Error("authenticated principal does not match any allowed domain")
errOIDCAllowedUsers = Error("authenticated principal does not match any allowed user")
errOIDCInvalidMachineState = Error("requested machine state key expired before authorisation completed")
errOIDCNodeKeyMissing = Error("could not get node key from cache")
)
type IDTokenClaims struct {
@@ -68,26 +61,26 @@ func (h *Headscale) initOIDC() error {
}
// RegisterOIDC redirects to the OIDC provider for authentication
// Puts NodeKey in cache so the callback can retrieve it using the oidc state param
// Listens in /oidc/register/:nKey.
// Puts machine key in cache so the callback can retrieve it using the oidc state param
// Listens in /oidc/register/:mKey.
func (h *Headscale) RegisterOIDC(
writer http.ResponseWriter,
req *http.Request,
) {
vars := mux.Vars(req)
nodeKeyStr, ok := vars["nkey"]
if !ok || nodeKeyStr == "" {
machineKeyStr, ok := vars["mkey"]
if !ok || machineKeyStr == "" {
log.Error().
Caller().
Msg("Missing node key in URL")
http.Error(writer, "Missing node key in URL", http.StatusBadRequest)
Msg("Missing machine key in URL")
http.Error(writer, "Missing machine key in URL", http.StatusBadRequest)
return
}
log.Trace().
Caller().
Str("node_key", nodeKeyStr).
Str("machine_key", machineKeyStr).
Msg("Received oidc register call")
randomBlob := make([]byte, randomByteSize)
@@ -102,8 +95,8 @@ func (h *Headscale) RegisterOIDC(
stateStr := hex.EncodeToString(randomBlob)[:32]
// place the node key into the state cache, so it can be retrieved later
h.registrationCache.Set(stateStr, nodeKeyStr, registerCacheExpiration)
// place the machine key into the state cache, so it can be retrieved later
h.registrationCache.Set(stateStr, machineKeyStr, registerCacheExpiration)
// Add any extra parameter provided in the configuration to the Authorize Endpoint request
extras := make([]oauth2.AuthCodeOption, 0, len(h.cfg.OIDC.ExtraParams))
@@ -135,7 +128,7 @@ var oidcCallbackTemplate = template.Must(
)
// OIDCCallback handles the callback from the OIDC endpoint
// Retrieves the nkey from the state cache and adds the machine to the users email namespace
// Retrieves the mkey from the state cache and adds the machine to the users email namespace
// TODO: A confirmation page for new machines should be added to avoid phishing vulnerabilities
// TODO: Add groups information from OIDC tokens into machine HostInfo
// Listens in /oidc/callback.
@@ -143,82 +136,6 @@ func (h *Headscale) OIDCCallback(
writer http.ResponseWriter,
req *http.Request,
) {
code, state, err := validateOIDCCallbackParams(writer, req)
if err != nil {
return
}
rawIDToken, err := h.getIDTokenForOIDCCallback(req.Context(), writer, code, state)
if err != nil {
return
}
idToken, err := h.verifyIDTokenForOIDCCallback(req.Context(), writer, rawIDToken)
if err != nil {
return
}
// TODO: we can use userinfo at some point to grab additional information about the user (groups membership, etc)
// userInfo, err := oidcProvider.UserInfo(context.Background(), oauth2.StaticTokenSource(oauth2Token))
// if err != nil {
// c.String(http.StatusBadRequest, fmt.Sprintf("Failed to retrieve userinfo"))
// return
// }
claims, err := extractIDTokenClaims(writer, idToken)
if err != nil {
return
}
if err := validateOIDCAllowedDomains(writer, h.cfg.OIDC.AllowedDomains, claims); err != nil {
return
}
if err := validateOIDCAllowedUsers(writer, h.cfg.OIDC.AllowedUsers, claims); err != nil {
return
}
nodeKey, machineExists, err := h.validateMachineForOIDCCallback(writer, state, claims)
if err != nil || machineExists {
return
}
namespaceName, err := getNamespaceName(writer, claims, h.cfg.OIDC.StripEmaildomain)
if err != nil {
return
}
// register the machine if it's new
log.Debug().Msg("Registering new machine after successful callback")
namespace, err := h.findOrCreateNewNamespaceForOIDCCallback(writer, namespaceName)
if err != nil {
return
}
if err := h.registerMachineForOIDCCallback(writer, namespace, nodeKey); err != nil {
return
}
content, err := renderOIDCCallbackTemplate(writer, claims)
if err != nil {
return
}
writer.Header().Set("Content-Type", "text/html; charset=utf-8")
writer.WriteHeader(http.StatusOK)
if _, err := writer.Write(content.Bytes()); err != nil {
log.Error().
Caller().
Err(err).
Msg("Failed to write response")
}
}
func validateOIDCCallbackParams(
writer http.ResponseWriter,
req *http.Request,
) (string, string, error) {
code := req.URL.Query().Get("code")
state := req.URL.Query().Get("state")
@@ -233,18 +150,10 @@ func validateOIDCCallbackParams(
Msg("Failed to write response")
}
return "", "", errEmptyOIDCCallbackParams
return
}
return code, state, nil
}
func (h *Headscale) getIDTokenForOIDCCallback(
ctx context.Context,
writer http.ResponseWriter,
code, state string,
) (string, error) {
oauth2Token, err := h.oauth2Config.Exchange(ctx, code)
oauth2Token, err := h.oauth2Config.Exchange(context.Background(), code)
if err != nil {
log.Error().
Err(err).
@@ -252,15 +161,15 @@ func (h *Headscale) getIDTokenForOIDCCallback(
Msg("Could not exchange code for token")
writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
writer.WriteHeader(http.StatusBadRequest)
_, werr := writer.Write([]byte("Could not exchange code for token"))
if werr != nil {
_, err := writer.Write([]byte("Could not exchange code for token"))
if err != nil {
log.Error().
Caller().
Err(werr).
Err(err).
Msg("Failed to write response")
}
return "", err
return
}
log.Trace().
@@ -281,19 +190,12 @@ func (h *Headscale) getIDTokenForOIDCCallback(
Msg("Failed to write response")
}
return "", errNoOIDCIDToken
return
}
return rawIDToken, nil
}
func (h *Headscale) verifyIDTokenForOIDCCallback(
ctx context.Context,
writer http.ResponseWriter,
rawIDToken string,
) (*oidc.IDToken, error) {
verifier := h.oidcProvider.Verifier(&oidc.Config{ClientID: h.cfg.OIDC.ClientID})
idToken, err := verifier.Verify(ctx, rawIDToken)
idToken, err := verifier.Verify(context.Background(), rawIDToken)
if err != nil {
log.Error().
Err(err).
@@ -301,56 +203,48 @@ func (h *Headscale) verifyIDTokenForOIDCCallback(
Msg("failed to verify id token")
writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
writer.WriteHeader(http.StatusBadRequest)
_, werr := writer.Write([]byte("Failed to verify id token"))
if werr != nil {
_, err := writer.Write([]byte("Failed to verify id token"))
if err != nil {
log.Error().
Caller().
Err(werr).
Err(err).
Msg("Failed to write response")
}
return nil, err
return
}
return idToken, nil
}
// TODO: we can use userinfo at some point to grab additional information about the user (groups membership, etc)
// userInfo, err := oidcProvider.UserInfo(context.Background(), oauth2.StaticTokenSource(oauth2Token))
// if err != nil {
// c.String(http.StatusBadRequest, fmt.Sprintf("Failed to retrieve userinfo"))
// return
// }
func extractIDTokenClaims(
writer http.ResponseWriter,
idToken *oidc.IDToken,
) (*IDTokenClaims, error) {
// Extract custom claims
var claims IDTokenClaims
if err := idToken.Claims(&claims); err != nil {
if err = idToken.Claims(&claims); err != nil {
log.Error().
Err(err).
Caller().
Msg("Failed to decode id token claims")
writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
writer.WriteHeader(http.StatusBadRequest)
_, werr := writer.Write([]byte("Failed to decode id token claims"))
if werr != nil {
_, err := writer.Write([]byte("Failed to decode id token claims"))
if err != nil {
log.Error().
Caller().
Err(werr).
Err(err).
Msg("Failed to write response")
}
return nil, err
return
}
return &claims, nil
}
// validateOIDCAllowedDomains checks that if AllowedDomains is provided,
// that the authenticated principal ends with @<alloweddomain>.
func validateOIDCAllowedDomains(
writer http.ResponseWriter,
allowedDomains []string,
claims *IDTokenClaims,
) error {
if len(allowedDomains) > 0 {
// If AllowedDomains is provided, check that the authenticated principal ends with @<alloweddomain>.
if len(h.cfg.OIDC.AllowedDomains) > 0 {
if at := strings.LastIndex(claims.Email, "@"); at < 0 ||
!IsStringInSlice(allowedDomains, claims.Email[at+1:]) {
!IsStringInSlice(h.cfg.OIDC.AllowedDomains, claims.Email[at+1:]) {
log.Error().Msg("authenticated principal does not match any allowed domain")
writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
writer.WriteHeader(http.StatusBadRequest)
@@ -362,22 +256,13 @@ func validateOIDCAllowedDomains(
Msg("Failed to write response")
}
return errOIDCAllowedDomains
return
}
}
return nil
}
// validateOIDCAllowedUsers checks that if AllowedUsers is provided,
// that the authenticated principal is part of that list.
func validateOIDCAllowedUsers(
writer http.ResponseWriter,
allowedUsers []string,
claims *IDTokenClaims,
) error {
if len(allowedUsers) > 0 &&
!IsStringInSlice(allowedUsers, claims.Email) {
// If AllowedUsers is provided, check that the authenticated princial is part of that list.
if len(h.cfg.OIDC.AllowedUsers) > 0 &&
!IsStringInSlice(h.cfg.OIDC.AllowedUsers, claims.Email) {
log.Error().Msg("authenticated principal does not match any allowed user")
writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
writer.WriteHeader(http.StatusBadRequest)
@@ -389,23 +274,12 @@ func validateOIDCAllowedUsers(
Msg("Failed to write response")
}
return errOIDCAllowedUsers
return
}
return nil
}
// validateMachine retrieves machine information if it exist
// The error is not important, because if it does not
// exist, then this is a new machine and we will move
// on to registration.
func (h *Headscale) validateMachineForOIDCCallback(
writer http.ResponseWriter,
state string,
claims *IDTokenClaims,
) (*key.NodePublic, bool, error) {
// retrieve machinekey from state cache
machineKeyIf, machineKeyFound := h.registrationCache.Get(state)
if !machineKeyFound {
log.Error().
Msg("requested machine state key expired before authorisation completed")
@@ -419,35 +293,21 @@ func (h *Headscale) validateMachineForOIDCCallback(
Msg("Failed to write response")
}
return nil, false, errOIDCInvalidMachineState
return
}
var nodeKey key.NodePublic
nodeKeyFromCache, nodeKeyOK := machineKeyIf.(string)
err := nodeKey.UnmarshalText(
[]byte(NodePublicKeyEnsurePrefix(nodeKeyFromCache)),
machineKeyFromCache, machineKeyOK := machineKeyIf.(string)
var machineKey key.MachinePublic
err = machineKey.UnmarshalText(
[]byte(MachinePublicKeyEnsurePrefix(machineKeyFromCache)),
)
if err != nil {
log.Error().
Msg("could not parse node public key")
Msg("could not parse machine public key")
writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
writer.WriteHeader(http.StatusBadRequest)
_, werr := writer.Write([]byte("could not parse public key"))
if werr != nil {
log.Error().
Caller().
Err(werr).
Msg("Failed to write response")
}
return nil, false, err
}
if !nodeKeyOK {
log.Error().Msg("could not get node key from cache")
writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
writer.WriteHeader(http.StatusInternalServerError)
_, err := writer.Write([]byte("could not get node key from cache"))
_, err := writer.Write([]byte("could not parse public key"))
if err != nil {
log.Error().
Caller().
@@ -455,14 +315,29 @@ func (h *Headscale) validateMachineForOIDCCallback(
Msg("Failed to write response")
}
return nil, false, errOIDCNodeKeyMissing
return
}
if !machineKeyOK {
log.Error().Msg("could not get machine key from cache")
writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
writer.WriteHeader(http.StatusInternalServerError)
_, err := writer.Write([]byte("could not get machine key from cache"))
if err != nil {
log.Error().
Caller().
Err(err).
Msg("Failed to write response")
}
return
}
// retrieve machine information if it exist
// The error is not important, because if it does not
// exist, then this is a new machine and we will move
// on to registration.
machine, _ := h.GetMachineByNodeKey(nodeKey)
machine, _ := h.GetMachineByMachineKey(machineKey)
if machine != nil {
log.Trace().
@@ -476,13 +351,9 @@ func (h *Headscale) validateMachineForOIDCCallback(
Caller().
Err(err).
Msg("Failed to refresh machine")
http.Error(
writer,
"Failed to refresh machine",
http.StatusInternalServerError,
)
http.Error(writer, "Failed to refresh machine", http.StatusInternalServerError)
return nil, true, err
return
}
var content bytes.Buffer
@@ -498,15 +369,15 @@ func (h *Headscale) validateMachineForOIDCCallback(
writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
writer.WriteHeader(http.StatusInternalServerError)
_, werr := writer.Write([]byte("Could not render OIDC callback template"))
if werr != nil {
_, err := writer.Write([]byte("Could not render OIDC callback template"))
if err != nil {
log.Error().
Caller().
Err(werr).
Err(err).
Msg("Failed to write response")
}
return nil, true, err
return
}
writer.Header().Set("Content-Type", "text/html; charset=utf-8")
@@ -519,45 +390,33 @@ func (h *Headscale) validateMachineForOIDCCallback(
Msg("Failed to write response")
}
return nil, true, nil
return
}
return &nodeKey, false, nil
}
func getNamespaceName(
writer http.ResponseWriter,
claims *IDTokenClaims,
stripEmaildomain bool,
) (string, error) {
namespaceName, err := NormalizeToFQDNRules(
claims.Email,
stripEmaildomain,
h.cfg.OIDC.StripEmaildomain,
)
if err != nil {
log.Error().Err(err).Caller().Msgf("couldn't normalize email")
writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
writer.WriteHeader(http.StatusInternalServerError)
_, werr := writer.Write([]byte("couldn't normalize email"))
if werr != nil {
_, err := writer.Write([]byte("couldn't normalize email"))
if err != nil {
log.Error().
Caller().
Err(werr).
Err(err).
Msg("Failed to write response")
}
return "", err
return
}
return namespaceName, nil
}
// register the machine if it's new
log.Debug().Msg("Registering new machine after successful callback")
func (h *Headscale) findOrCreateNewNamespaceForOIDCCallback(
writer http.ResponseWriter,
namespaceName string,
) (*Namespace, error) {
namespace, err := h.GetNamespace(namespaceName)
if errors.Is(err, ErrNamespaceNotFound) {
if errors.Is(err, errNamespaceNotFound) {
namespace, err = h.CreateNamespace(namespaceName)
if err != nil {
@@ -567,15 +426,15 @@ func (h *Headscale) findOrCreateNewNamespaceForOIDCCallback(
Msgf("could not create new namespace '%s'", namespaceName)
writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
writer.WriteHeader(http.StatusInternalServerError)
_, werr := writer.Write([]byte("could not create namespace"))
if werr != nil {
_, err := writer.Write([]byte("could not create namespace"))
if err != nil {
log.Error().
Caller().
Err(werr).
Err(err).
Msg("Failed to write response")
}
return nil, err
return
}
} else if err != nil {
log.Error().
@@ -585,56 +444,42 @@ func (h *Headscale) findOrCreateNewNamespaceForOIDCCallback(
Msg("could not find or create namespace")
writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
writer.WriteHeader(http.StatusInternalServerError)
_, werr := writer.Write([]byte("could not find or create namespace"))
if werr != nil {
_, err := writer.Write([]byte("could not find or create namespace"))
if err != nil {
log.Error().
Caller().
Err(werr).
Err(err).
Msg("Failed to write response")
}
return nil, err
return
}
return namespace, nil
}
machineKeyStr := MachinePublicKeyStripPrefix(machineKey)
func (h *Headscale) registerMachineForOIDCCallback(
writer http.ResponseWriter,
namespace *Namespace,
nodeKey *key.NodePublic,
) error {
nodeKeyStr := NodePublicKeyStripPrefix(*nodeKey)
if _, err := h.RegisterMachineFromAuthCallback(
nodeKeyStr,
_, err = h.RegisterMachineFromAuthCallback(
machineKeyStr,
namespace.Name,
RegisterMethodOIDC,
); err != nil {
)
if err != nil {
log.Error().
Caller().
Err(err).
Msg("could not register machine")
writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
writer.WriteHeader(http.StatusInternalServerError)
_, werr := writer.Write([]byte("could not register machine"))
if werr != nil {
_, err := writer.Write([]byte("could not register machine"))
if err != nil {
log.Error().
Caller().
Err(werr).
Err(err).
Msg("Failed to write response")
}
return err
return
}
return nil
}
func renderOIDCCallbackTemplate(
writer http.ResponseWriter,
claims *IDTokenClaims,
) (*bytes.Buffer, error) {
var content bytes.Buffer
if err := oidcCallbackTemplate.Execute(&content, oidcCallbackTemplateConfig{
User: claims.Email,
@@ -648,16 +493,24 @@ func renderOIDCCallbackTemplate(
writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
writer.WriteHeader(http.StatusInternalServerError)
_, werr := writer.Write([]byte("Could not render OIDC callback template"))
if werr != nil {
_, err := writer.Write([]byte("Could not render OIDC callback template"))
if err != nil {
log.Error().
Caller().
Err(werr).
Err(err).
Msg("Failed to write response")
}
return nil, err
return
}
return &content, nil
writer.Header().Set("Content-Type", "text/html; charset=utf-8")
writer.WriteHeader(http.StatusOK)
_, err = writer.Write(content.Bytes())
if err != nil {
log.Error().
Caller().
Err(err).
Msg("Failed to write response")
}
}

View File

@@ -325,9 +325,7 @@ func (h *Headscale) ApplePlatformConfig(
default:
writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
writer.WriteHeader(http.StatusBadRequest)
_, err := writer.Write(
[]byte("Invalid platform, only ios and macos is supported"),
)
_, err := writer.Write([]byte("Invalid platform, only ios and macos is supported"))
if err != nil {
log.Error().
Caller().
@@ -364,8 +362,7 @@ func (h *Headscale) ApplePlatformConfig(
return
}
writer.Header().
Set("Content-Type", "application/x-apple-aspen-config; charset=utf-8")
writer.Header().Set("Content-Type", "application/x-apple-aspen-config; charset=utf-8")
writer.WriteHeader(http.StatusOK)
_, err = writer.Write(content.Bytes())
if err != nil {

View File

@@ -2,12 +2,17 @@ package headscale
import (
"context"
"errors"
"fmt"
"io"
"net/http"
"time"
"github.com/gorilla/mux"
"github.com/rs/zerolog/log"
"gorm.io/gorm"
"tailscale.com/tailcfg"
"tailscale.com/types/key"
)
const (
@@ -18,15 +23,83 @@ type contextKey string
const machineNameContextKey = contextKey("machineName")
// handlePollCommon is the common code for the legacy and Noise protocols to
// managed the poll loop.
func (h *Headscale) handlePollCommon(
// PollNetMapHandler takes care of /machine/:id/map
//
// This is the busiest endpoint, as it keeps the HTTP long poll that updates
// the clients when something in the network changes.
//
// The clients POST stuff like HostInfo and their Endpoints here, but
// only after their first request (marked with the ReadOnly field).
//
// At this moment the updates are sent in a quite horrendous way, but they kinda work.
func (h *Headscale) PollNetMapHandler(
writer http.ResponseWriter,
ctx context.Context,
machine *Machine,
mapRequest tailcfg.MapRequest,
isNoise bool,
req *http.Request,
) {
vars := mux.Vars(req)
machineKeyStr, ok := vars["mkey"]
if !ok || machineKeyStr == "" {
log.Error().
Str("handler", "PollNetMap").
Msg("No machine key in request")
http.Error(writer, "No machine key in request", http.StatusBadRequest)
return
}
log.Trace().
Str("handler", "PollNetMap").
Str("id", machineKeyStr).
Msg("PollNetMapHandler called")
body, _ := io.ReadAll(req.Body)
var machineKey key.MachinePublic
err := machineKey.UnmarshalText([]byte(MachinePublicKeyEnsurePrefix(machineKeyStr)))
if err != nil {
log.Error().
Str("handler", "PollNetMap").
Err(err).
Msg("Cannot parse client key")
http.Error(writer, "Cannot parse client key", http.StatusBadRequest)
return
}
mapRequest := tailcfg.MapRequest{}
err = decode(body, &mapRequest, &machineKey, h.privateKey)
if err != nil {
log.Error().
Str("handler", "PollNetMap").
Err(err).
Msg("Cannot decode message")
http.Error(writer, "Cannot decode message", http.StatusBadRequest)
return
}
machine, err := h.GetMachineByMachineKey(machineKey)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
log.Warn().
Str("handler", "PollNetMap").
Msgf("Ignoring request, cannot find machine with key %s", machineKey.String())
http.Error(writer, "", http.StatusUnauthorized)
return
}
log.Error().
Str("handler", "PollNetMap").
Msgf("Failed to fetch machine from the database with Machine key: %s", machineKey.String())
http.Error(writer, "", http.StatusInternalServerError)
return
}
log.Trace().
Str("handler", "PollNetMap").
Str("id", machineKeyStr).
Str("machine", machine.Hostname).
Msg("Found machine in database")
machine.Hostname = mapRequest.Hostinfo.Hostname
machine.HostInfo = HostInfo(*mapRequest.Hostinfo)
machine.DiscoKey = DiscoPublicKeyStripPrefix(mapRequest.DiscoKey)
@@ -34,11 +107,11 @@ func (h *Headscale) handlePollCommon(
// update ACLRules with peer informations (to update server tags if necessary)
if h.aclPolicy != nil {
err := h.UpdateACLRules()
err = h.UpdateACLRules()
if err != nil {
log.Error().
Caller().
Bool("noise", isNoise).
Str("func", "handleAuthKey").
Str("machine", machine.Hostname).
Err(err)
}
@@ -60,8 +133,7 @@ func (h *Headscale) handlePollCommon(
if err != nil {
log.Error().
Str("handler", "PollNetMap").
Bool("noise", isNoise).
Str("node_key", machine.NodeKey).
Str("id", machineKeyStr).
Str("machine", machine.Hostname).
Err(err).
Msg("Failed to persist/update machine in the database")
@@ -71,12 +143,11 @@ func (h *Headscale) handlePollCommon(
}
}
mapResp, err := h.getMapResponseData(mapRequest, machine, isNoise)
data, err := h.getMapResponse(machineKey, mapRequest, machine)
if err != nil {
log.Error().
Str("handler", "PollNetMap").
Bool("noise", isNoise).
Str("node_key", machine.NodeKey).
Str("id", machineKeyStr).
Str("machine", machine.Hostname).
Err(err).
Msg("Failed to get Map response")
@@ -92,7 +163,7 @@ func (h *Headscale) handlePollCommon(
// Details on the protocol can be found in https://github.com/tailscale/tailscale/blob/main/tailcfg/tailcfg.go#L696
log.Debug().
Str("handler", "PollNetMap").
Bool("noise", isNoise).
Str("id", machineKeyStr).
Str("machine", machine.Hostname).
Bool("readOnly", mapRequest.ReadOnly).
Bool("omitPeers", mapRequest.OmitPeers).
@@ -102,13 +173,12 @@ func (h *Headscale) handlePollCommon(
if mapRequest.ReadOnly {
log.Info().
Str("handler", "PollNetMap").
Bool("noise", isNoise).
Str("machine", machine.Hostname).
Msg("Client is starting up. Probably interested in a DERP map")
writer.Header().Set("Content-Type", "application/json; charset=utf-8")
writer.WriteHeader(http.StatusOK)
_, err := writer.Write(mapResp)
_, err := writer.Write(data)
if err != nil {
log.Error().
Caller().
@@ -116,24 +186,20 @@ func (h *Headscale) handlePollCommon(
Msg("Failed to write response")
}
if f, ok := writer.(http.Flusher); ok {
f.Flush()
}
return
}
// There has been an update to _any_ of the nodes that the other nodes would
// need to know about
h.setLastStateChangeToNow()
h.setLastStateChangeToNow(machine.Namespace.Name)
// The request is not ReadOnly, so we need to set up channels for updating
// peers via longpoll
// Only create update channel if it has not been created
log.Trace().
Caller().
Bool("noise", isNoise).
Str("handler", "PollNetMap").
Str("id", machineKeyStr).
Str("machine", machine.Hostname).
Msg("Loading or creating update channel")
@@ -148,12 +214,11 @@ func (h *Headscale) handlePollCommon(
if mapRequest.OmitPeers && !mapRequest.Stream {
log.Info().
Str("handler", "PollNetMap").
Bool("noise", isNoise).
Str("machine", machine.Hostname).
Msg("Client sent endpoint update and is ok with a response without peer list")
writer.Header().Set("Content-Type", "application/json; charset=utf-8")
writer.WriteHeader(http.StatusOK)
_, err := writer.Write(mapResp)
_, err := writer.Write(data)
if err != nil {
log.Error().
Caller().
@@ -170,7 +235,6 @@ func (h *Headscale) handlePollCommon(
} else if mapRequest.OmitPeers && mapRequest.Stream {
log.Warn().
Str("handler", "PollNetMap").
Bool("noise", isNoise).
Str("machine", machine.Hostname).
Msg("Ignoring request, don't know how to handle it")
http.Error(writer, "", http.StatusBadRequest)
@@ -180,59 +244,56 @@ func (h *Headscale) handlePollCommon(
log.Info().
Str("handler", "PollNetMap").
Bool("noise", isNoise).
Str("machine", machine.Hostname).
Msg("Client is ready to access the tailnet")
log.Info().
Str("handler", "PollNetMap").
Bool("noise", isNoise).
Str("machine", machine.Hostname).
Msg("Sending initial map")
pollDataChan <- mapResp
pollDataChan <- data
log.Info().
Str("handler", "PollNetMap").
Bool("noise", isNoise).
Str("machine", machine.Hostname).
Msg("Notifying peers")
updateRequestsFromNode.WithLabelValues(machine.Namespace.Name, machine.Hostname, "full-update").
Inc()
updateChan <- struct{}{}
h.pollNetMapStream(
h.PollNetMapStream(
writer,
ctx,
req,
machine,
mapRequest,
machineKey,
pollDataChan,
keepAliveChan,
updateChan,
isNoise,
)
log.Trace().
Str("handler", "PollNetMap").
Bool("noise", isNoise).
Str("id", machineKeyStr).
Str("machine", machine.Hostname).
Msg("Finished stream, closing PollNetMap session")
}
// pollNetMapStream stream logic for /machine/map,
// ensuring we communicate updates and data to the connected clients.
func (h *Headscale) pollNetMapStream(
// PollNetMapStream takes care of /machine/:id/map
// stream logic, ensuring we communicate updates and data
// to the connected clients.
func (h *Headscale) PollNetMapStream(
writer http.ResponseWriter,
ctxReq context.Context,
req *http.Request,
machine *Machine,
mapRequest tailcfg.MapRequest,
machineKey key.MachinePublic,
pollDataChan chan []byte,
keepAliveChan chan []byte,
updateChan chan struct{},
isNoise bool,
) {
h.pollNetMapStreamWG.Add(1)
defer h.pollNetMapStreamWG.Done()
ctx := context.WithValue(ctxReq, machineNameContextKey, machine.Hostname)
ctx := context.WithValue(req.Context(), machineNameContextKey, machine.Hostname)
ctx, cancel := context.WithCancel(ctx)
defer cancel()
@@ -241,20 +302,18 @@ func (h *Headscale) pollNetMapStream(
ctx,
updateChan,
keepAliveChan,
machineKey,
mapRequest,
machine,
isNoise,
)
log.Trace().
Str("handler", "pollNetMapStream").
Bool("noise", isNoise).
Str("handler", "PollNetMapStream").
Str("machine", machine.Hostname).
Msg("Waiting for data to stream...")
log.Trace().
Str("handler", "pollNetMapStream").
Bool("noise", isNoise).
Str("handler", "PollNetMapStream").
Str("machine", machine.Hostname).
Msgf("pollData is %#v, keepAliveChan is %#v, updateChan is %#v", pollDataChan, keepAliveChan, updateChan)
@@ -263,7 +322,6 @@ func (h *Headscale) pollNetMapStream(
case data := <-pollDataChan:
log.Trace().
Str("handler", "PollNetMapStream").
Bool("noise", isNoise).
Str("machine", machine.Hostname).
Str("channel", "pollData").
Int("bytes", len(data)).
@@ -272,7 +330,6 @@ func (h *Headscale) pollNetMapStream(
if err != nil {
log.Error().
Str("handler", "PollNetMapStream").
Bool("noise", isNoise).
Str("machine", machine.Hostname).
Str("channel", "pollData").
Err(err).
@@ -286,7 +343,6 @@ func (h *Headscale) pollNetMapStream(
log.Error().
Caller().
Str("handler", "PollNetMapStream").
Bool("noise", isNoise).
Str("machine", machine.Hostname).
Str("channel", "pollData").
Msg("Cannot cast writer to http.Flusher")
@@ -296,7 +352,6 @@ func (h *Headscale) pollNetMapStream(
log.Trace().
Str("handler", "PollNetMapStream").
Bool("noise", isNoise).
Str("machine", machine.Hostname).
Str("channel", "pollData").
Int("bytes", len(data)).
@@ -308,7 +363,6 @@ func (h *Headscale) pollNetMapStream(
if err != nil {
log.Error().
Str("handler", "PollNetMapStream").
Bool("noise", isNoise).
Str("machine", machine.Hostname).
Str("channel", "pollData").
Err(err).
@@ -329,7 +383,6 @@ func (h *Headscale) pollNetMapStream(
if err != nil {
log.Error().
Str("handler", "PollNetMapStream").
Bool("noise", isNoise).
Str("machine", machine.Hostname).
Str("channel", "pollData").
Err(err).
@@ -340,7 +393,6 @@ func (h *Headscale) pollNetMapStream(
log.Trace().
Str("handler", "PollNetMapStream").
Bool("noise", isNoise).
Str("machine", machine.Hostname).
Str("channel", "pollData").
Int("bytes", len(data)).
@@ -357,7 +409,6 @@ func (h *Headscale) pollNetMapStream(
if err != nil {
log.Error().
Str("handler", "PollNetMapStream").
Bool("noise", isNoise).
Str("machine", machine.Hostname).
Str("channel", "keepAlive").
Err(err).
@@ -370,7 +421,6 @@ func (h *Headscale) pollNetMapStream(
log.Error().
Caller().
Str("handler", "PollNetMapStream").
Bool("noise", isNoise).
Str("machine", machine.Hostname).
Str("channel", "keepAlive").
Msg("Cannot cast writer to http.Flusher")
@@ -380,7 +430,6 @@ func (h *Headscale) pollNetMapStream(
log.Trace().
Str("handler", "PollNetMapStream").
Bool("noise", isNoise).
Str("machine", machine.Hostname).
Str("channel", "keepAlive").
Int("bytes", len(data)).
@@ -392,7 +441,6 @@ func (h *Headscale) pollNetMapStream(
if err != nil {
log.Error().
Str("handler", "PollNetMapStream").
Bool("noise", isNoise).
Str("machine", machine.Hostname).
Str("channel", "keepAlive").
Err(err).
@@ -408,7 +456,6 @@ func (h *Headscale) pollNetMapStream(
if err != nil {
log.Error().
Str("handler", "PollNetMapStream").
Bool("noise", isNoise).
Str("machine", machine.Hostname).
Str("channel", "keepAlive").
Err(err).
@@ -419,7 +466,6 @@ func (h *Headscale) pollNetMapStream(
log.Trace().
Str("handler", "PollNetMapStream").
Bool("noise", isNoise).
Str("machine", machine.Hostname).
Str("channel", "keepAlive").
Int("bytes", len(data)).
@@ -428,7 +474,6 @@ func (h *Headscale) pollNetMapStream(
case <-updateChan:
log.Trace().
Str("handler", "PollNetMapStream").
Bool("noise", isNoise).
Str("machine", machine.Hostname).
Str("channel", "update").
Msg("Received a request for update")
@@ -442,16 +487,14 @@ func (h *Headscale) pollNetMapStream(
}
log.Debug().
Str("handler", "PollNetMapStream").
Bool("noise", isNoise).
Str("machine", machine.Hostname).
Time("last_successful_update", lastUpdate).
Time("last_state_change", h.getLastStateChange(machine.Namespace.Name)).
Msgf("There has been updates since the last successful update to %s", machine.Hostname)
data, err := h.getMapResponseData(mapRequest, machine, false)
data, err := h.getMapResponse(machineKey, mapRequest, machine)
if err != nil {
log.Error().
Str("handler", "PollNetMapStream").
Bool("noise", isNoise).
Str("machine", machine.Hostname).
Str("channel", "update").
Err(err).
@@ -463,7 +506,6 @@ func (h *Headscale) pollNetMapStream(
if err != nil {
log.Error().
Str("handler", "PollNetMapStream").
Bool("noise", isNoise).
Str("machine", machine.Hostname).
Str("channel", "update").
Err(err).
@@ -479,7 +521,6 @@ func (h *Headscale) pollNetMapStream(
log.Error().
Caller().
Str("handler", "PollNetMapStream").
Bool("noise", isNoise).
Str("machine", machine.Hostname).
Str("channel", "update").
Msg("Cannot cast writer to http.Flusher")
@@ -489,7 +530,6 @@ func (h *Headscale) pollNetMapStream(
log.Trace().
Str("handler", "PollNetMapStream").
Bool("noise", isNoise).
Str("machine", machine.Hostname).
Str("channel", "update").
Msg("Updated Map has been sent")
@@ -507,7 +547,6 @@ func (h *Headscale) pollNetMapStream(
if err != nil {
log.Error().
Str("handler", "PollNetMapStream").
Bool("noise", isNoise).
Str("machine", machine.Hostname).
Str("channel", "update").
Err(err).
@@ -527,7 +566,6 @@ func (h *Headscale) pollNetMapStream(
if err != nil {
log.Error().
Str("handler", "PollNetMapStream").
Bool("noise", isNoise).
Str("machine", machine.Hostname).
Str("channel", "update").
Err(err).
@@ -542,7 +580,6 @@ func (h *Headscale) pollNetMapStream(
}
log.Trace().
Str("handler", "PollNetMapStream").
Bool("noise", isNoise).
Str("machine", machine.Hostname).
Time("last_successful_update", lastUpdate).
Time("last_state_change", h.getLastStateChange(machine.Namespace.Name)).
@@ -561,7 +598,6 @@ func (h *Headscale) pollNetMapStream(
if err != nil {
log.Error().
Str("handler", "PollNetMapStream").
Bool("noise", isNoise).
Str("machine", machine.Hostname).
Str("channel", "Done").
Err(err).
@@ -577,7 +613,6 @@ func (h *Headscale) pollNetMapStream(
if err != nil {
log.Error().
Str("handler", "PollNetMapStream").
Bool("noise", isNoise).
Str("machine", machine.Hostname).
Str("channel", "Done").
Err(err).
@@ -590,7 +625,6 @@ func (h *Headscale) pollNetMapStream(
case <-h.shutdownChan:
log.Info().
Str("handler", "PollNetMapStream").
Bool("noise", isNoise).
Str("machine", machine.Hostname).
Msg("The long-poll handler is shutting down")
@@ -603,9 +637,9 @@ func (h *Headscale) scheduledPollWorker(
ctx context.Context,
updateChan chan struct{},
keepAliveChan chan []byte,
machineKey key.MachinePublic,
mapRequest tailcfg.MapRequest,
machine *Machine,
isNoise bool,
) {
keepAliveTicker := time.NewTicker(keepAliveInterval)
updateCheckerTicker := time.NewTicker(h.cfg.NodeUpdateCheckInterval)
@@ -627,11 +661,10 @@ func (h *Headscale) scheduledPollWorker(
return
case <-keepAliveTicker.C:
data, err := h.getMapKeepAliveResponseData(mapRequest, machine, isNoise)
data, err := h.getMapKeepAliveResponse(machineKey, mapRequest)
if err != nil {
log.Error().
Str("func", "keepAlive").
Bool("noise", isNoise).
Err(err).
Msg("Error generating the keep alive msg")
@@ -641,7 +674,6 @@ func (h *Headscale) scheduledPollWorker(
log.Debug().
Str("func", "keepAlive").
Str("machine", machine.Hostname).
Bool("noise", isNoise).
Msg("Sending keepalive")
keepAliveChan <- data
@@ -649,7 +681,6 @@ func (h *Headscale) scheduledPollWorker(
log.Debug().
Str("func", "scheduledPollWorker").
Str("machine", machine.Hostname).
Bool("noise", isNoise).
Msg("Sending update request")
updateRequestsFromNode.WithLabelValues(machine.Namespace.Name, machine.Hostname, "scheduled-update").
Inc()

View File

@@ -14,10 +14,10 @@ import (
)
const (
ErrPreAuthKeyNotFound = Error("AuthKey not found")
ErrPreAuthKeyExpired = Error("AuthKey expired")
ErrSingleUseAuthKeyHasBeenUsed = Error("AuthKey has already been used")
ErrNamespaceMismatch = Error("namespace mismatch")
errPreAuthKeyNotFound = Error("AuthKey not found")
errPreAuthKeyExpired = Error("AuthKey expired")
errSingleUseAuthKeyHasBeenUsed = Error("AuthKey has already been used")
errNamespaceMismatch = Error("namespace mismatch")
)
// PreAuthKey describes a pre-authorization key usable in a particular namespace.
@@ -92,7 +92,7 @@ func (h *Headscale) GetPreAuthKey(namespace string, key string) (*PreAuthKey, er
}
if pak.Namespace.Name != namespace {
return nil, ErrNamespaceMismatch
return nil, errNamespaceMismatch
}
return pak, nil
@@ -135,11 +135,11 @@ func (h *Headscale) checkKeyValidity(k string) (*PreAuthKey, error) {
result.Error,
gorm.ErrRecordNotFound,
) {
return nil, ErrPreAuthKeyNotFound
return nil, errPreAuthKeyNotFound
}
if pak.Expiration != nil && pak.Expiration.Before(time.Now()) {
return nil, ErrPreAuthKeyExpired
return nil, errPreAuthKeyExpired
}
if pak.Reusable || pak.Ephemeral { // we don't need to check if has been used before
@@ -152,7 +152,7 @@ func (h *Headscale) checkKeyValidity(k string) (*PreAuthKey, error) {
}
if len(machines) != 0 || pak.Used {
return nil, ErrSingleUseAuthKeyHasBeenUsed
return nil, errSingleUseAuthKeyHasBeenUsed
}
return &pak, nil

View File

@@ -44,13 +44,13 @@ func (*Suite) TestExpiredPreAuthKey(c *check.C) {
c.Assert(err, check.IsNil)
key, err := app.checkKeyValidity(pak.Key)
c.Assert(err, check.Equals, ErrPreAuthKeyExpired)
c.Assert(err, check.Equals, errPreAuthKeyExpired)
c.Assert(key, check.IsNil)
}
func (*Suite) TestPreAuthKeyDoesNotExist(c *check.C) {
key, err := app.checkKeyValidity("potatoKey")
c.Assert(err, check.Equals, ErrPreAuthKeyNotFound)
c.Assert(err, check.Equals, errPreAuthKeyNotFound)
c.Assert(key, check.IsNil)
}
@@ -86,7 +86,7 @@ func (*Suite) TestAlreadyUsedKey(c *check.C) {
app.db.Save(&machine)
key, err := app.checkKeyValidity(pak.Key)
c.Assert(err, check.Equals, ErrSingleUseAuthKeyHasBeenUsed)
c.Assert(err, check.Equals, errSingleUseAuthKeyHasBeenUsed)
c.Assert(key, check.IsNil)
}
@@ -174,7 +174,7 @@ func (*Suite) TestExpirePreauthKey(c *check.C) {
c.Assert(pak.Expiration, check.NotNil)
key, err := app.checkKeyValidity(pak.Key)
c.Assert(err, check.Equals, ErrPreAuthKeyExpired)
c.Assert(err, check.Equals, errPreAuthKeyExpired)
c.Assert(key, check.IsNil)
}
@@ -188,5 +188,5 @@ func (*Suite) TestNotReusableMarkedAsUsed(c *check.C) {
app.db.Save(&pak)
_, err = app.checkKeyValidity(pak.Key)
c.Assert(err, check.Equals, ErrSingleUseAuthKeyHasBeenUsed)
c.Assert(err, check.Equals, errSingleUseAuthKeyHasBeenUsed)
}

View File

@@ -4,12 +4,21 @@ deps:
- remote: buf.build
owner: googleapis
repository: googleapis
commit: 62f35d8aed1149c291d606d958a7ce32
branch: main
commit: cd101b0abb7b4404a0b1ecc1afd4ce10
digest: b1-H4GHwHVHcJBbVPg-Cdmnx812reFCDQws_QoQ0W2hYQA=
create_time: 2021-10-23T15:04:06.087748Z
- remote: buf.build
owner: grpc-ecosystem
repository: grpc-gateway
commit: bc28b723cd774c32b6fbc77621518765
branch: main
commit: ff83506eb9cc4cf8972f49ce87e6ed3e
digest: b1-iLPHgLaoeWWinMiXXqPnxqE4BThtY3eSbswVGh9GOGI=
create_time: 2021-10-23T16:26:52.283938Z
- remote: buf.build
owner: ufoundit-dev
repository: protoc-gen-gorm
branch: main
commit: e2ecbaa0d37843298104bd29fd866df8
digest: b1-SV9yKH_8P-IKTOlHZxP-bb0ALANYeEqH_mtPA0EWfLc=
create_time: 2021-10-08T06:03:05.64876Z

View File

@@ -1,742 +0,0 @@
package headscale
import (
"encoding/json"
"errors"
"fmt"
"net/http"
"strconv"
"strings"
"time"
"github.com/rs/zerolog/log"
"gorm.io/gorm"
"tailscale.com/tailcfg"
"tailscale.com/types/key"
)
const (
// The CapabilityVersion is used by Tailscale clients to indicate
// their codebase version. Tailscale clients can communicate over TS2021
// from CapabilityVersion 28, but we only have good support for it
// since https://github.com/tailscale/tailscale/pull/4323 (Noise in any HTTPS port).
//
// Related to this change, there is https://github.com/tailscale/tailscale/pull/5379,
// where CapabilityVersion 39 is introduced to indicate #4323 was merged.
//
// See also https://github.com/tailscale/tailscale/blob/main/tailcfg/tailcfg.go
NoiseCapabilityVersion = 39
)
// KeyHandler provides the Headscale pub key
// Listens in /key.
func (h *Headscale) KeyHandler(
writer http.ResponseWriter,
req *http.Request,
) {
// New Tailscale clients send a 'v' parameter to indicate the CurrentCapabilityVersion
clientCapabilityStr := req.URL.Query().Get("v")
if clientCapabilityStr != "" {
log.Debug().
Str("handler", "/key").
Str("v", clientCapabilityStr).
Msg("New noise client")
clientCapabilityVersion, err := strconv.Atoi(clientCapabilityStr)
if err != nil {
writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
writer.WriteHeader(http.StatusBadRequest)
_, err := writer.Write([]byte("Wrong params"))
if err != nil {
log.Error().
Caller().
Err(err).
Msg("Failed to write response")
}
return
}
// TS2021 (Tailscale v2 protocol) requires to have a different key
if clientCapabilityVersion >= NoiseCapabilityVersion {
resp := tailcfg.OverTLSPublicKeyResponse{
LegacyPublicKey: h.privateKey.Public(),
PublicKey: h.noisePrivateKey.Public(),
}
writer.Header().Set("Content-Type", "application/json")
writer.WriteHeader(http.StatusOK)
err = json.NewEncoder(writer).Encode(resp)
if err != nil {
log.Error().
Caller().
Err(err).
Msg("Failed to write response")
}
return
}
}
log.Debug().
Str("handler", "/key").
Msg("New legacy client")
// Old clients don't send a 'v' parameter, so we send the legacy public key
writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
writer.WriteHeader(http.StatusOK)
_, err := writer.Write([]byte(MachinePublicKeyStripPrefix(h.privateKey.Public())))
if err != nil {
log.Error().
Caller().
Err(err).
Msg("Failed to write response")
}
}
// handleRegisterCommon is the common logic for registering a client in the legacy and Noise protocols
//
// When using Noise, the machineKey is Zero.
func (h *Headscale) handleRegisterCommon(
writer http.ResponseWriter,
req *http.Request,
registerRequest tailcfg.RegisterRequest,
machineKey key.MachinePublic,
) {
now := time.Now().UTC()
machine, err := h.GetMachineByAnyNodeKey(registerRequest.NodeKey, registerRequest.OldNodeKey)
if errors.Is(err, gorm.ErrRecordNotFound) {
// If the machine has AuthKey set, handle registration via PreAuthKeys
if registerRequest.Auth.AuthKey != "" {
h.handleAuthKeyCommon(writer, registerRequest, machineKey)
return
}
// Check if the node is waiting for interactive login.
//
// TODO(juan): We could use this field to improve our protocol implementation,
// and hold the request until the client closes it, or the interactive
// login is completed (i.e., the user registers the machine).
// This is not implemented yet, as it is no strictly required. The only side-effect
// is that the client will hammer headscale with requests until it gets a
// successful RegisterResponse.
if registerRequest.Followup != "" {
if _, ok := h.registrationCache.Get(NodePublicKeyStripPrefix(registerRequest.NodeKey)); ok {
log.Debug().
Caller().
Str("machine", registerRequest.Hostinfo.Hostname).
Str("node_key", registerRequest.NodeKey.ShortString()).
Str("node_key_old", registerRequest.OldNodeKey.ShortString()).
Str("follow_up", registerRequest.Followup).
Bool("noise", machineKey.IsZero()).
Msg("Machine is waiting for interactive login")
ticker := time.NewTicker(registrationHoldoff)
select {
case <-req.Context().Done():
return
case <-ticker.C:
h.handleNewMachineCommon(writer, registerRequest, machineKey)
return
}
}
}
log.Info().
Caller().
Str("machine", registerRequest.Hostinfo.Hostname).
Str("node_key", registerRequest.NodeKey.ShortString()).
Str("node_key_old", registerRequest.OldNodeKey.ShortString()).
Str("follow_up", registerRequest.Followup).
Bool("noise", machineKey.IsZero()).
Msg("New machine not yet in the database")
givenName, err := h.GenerateGivenName(registerRequest.Hostinfo.Hostname)
if err != nil {
log.Error().
Caller().
Str("func", "RegistrationHandler").
Str("hostinfo.name", registerRequest.Hostinfo.Hostname).
Err(err)
return
}
// The machine did not have a key to authenticate, which means
// that we rely on a method that calls back some how (OpenID or CLI)
// We create the machine and then keep it around until a callback
// happens
newMachine := Machine{
MachineKey: MachinePublicKeyStripPrefix(machineKey),
Hostname: registerRequest.Hostinfo.Hostname,
GivenName: givenName,
NodeKey: NodePublicKeyStripPrefix(registerRequest.NodeKey),
LastSeen: &now,
Expiry: &time.Time{},
}
if !registerRequest.Expiry.IsZero() {
log.Trace().
Caller().
Bool("noise", machineKey.IsZero()).
Str("machine", registerRequest.Hostinfo.Hostname).
Time("expiry", registerRequest.Expiry).
Msg("Non-zero expiry time requested")
newMachine.Expiry = &registerRequest.Expiry
}
h.registrationCache.Set(
newMachine.NodeKey,
newMachine,
registerCacheExpiration,
)
h.handleNewMachineCommon(writer, registerRequest, machineKey)
return
}
// The machine is already registered, so we need to pass through reauth or key update.
if machine != nil {
// If the NodeKey stored in headscale is the same as the key presented in a registration
// request, then we have a node that is either:
// - Trying to log out (sending a expiry in the past)
// - A valid, registered machine, looking for the node map
// - Expired machine wanting to reauthenticate
if machine.NodeKey == NodePublicKeyStripPrefix(registerRequest.NodeKey) {
// The client sends an Expiry in the past if the client is requesting to expire the key (aka logout)
// https://github.com/tailscale/tailscale/blob/main/tailcfg/tailcfg.go#L648
if !registerRequest.Expiry.IsZero() &&
registerRequest.Expiry.UTC().Before(now) {
h.handleMachineLogOutCommon(writer, *machine, machineKey)
return
}
// If machine is not expired, and is register, we have a already accepted this machine,
// let it proceed with a valid registration
if !machine.isExpired() {
h.handleMachineValidRegistrationCommon(writer, *machine, machineKey)
return
}
}
// The NodeKey we have matches OldNodeKey, which means this is a refresh after a key expiration
if machine.NodeKey == NodePublicKeyStripPrefix(registerRequest.OldNodeKey) &&
!machine.isExpired() {
h.handleMachineRefreshKeyCommon(
writer,
registerRequest,
*machine,
machineKey,
)
return
}
// The machine has expired
h.handleMachineExpiredCommon(writer, registerRequest, *machine, machineKey)
machine.Expiry = &time.Time{}
h.registrationCache.Set(
NodePublicKeyStripPrefix(registerRequest.NodeKey),
*machine,
registerCacheExpiration,
)
return
}
}
// handleAuthKeyCommon contains the logic to manage auth key client registration
// It is used both by the legacy and the new Noise protocol.
// When using Noise, the machineKey is Zero.
//
// TODO: check if any locks are needed around IP allocation.
func (h *Headscale) handleAuthKeyCommon(
writer http.ResponseWriter,
registerRequest tailcfg.RegisterRequest,
machineKey key.MachinePublic,
) {
log.Debug().
Str("func", "handleAuthKeyCommon").
Str("machine", registerRequest.Hostinfo.Hostname).
Bool("noise", machineKey.IsZero()).
Msgf("Processing auth key for %s", registerRequest.Hostinfo.Hostname)
resp := tailcfg.RegisterResponse{}
pak, err := h.checkKeyValidity(registerRequest.Auth.AuthKey)
if err != nil {
log.Error().
Caller().
Str("func", "handleAuthKeyCommon").
Bool("noise", machineKey.IsZero()).
Str("machine", registerRequest.Hostinfo.Hostname).
Err(err).
Msg("Failed authentication via AuthKey")
resp.MachineAuthorized = false
respBody, err := h.marshalResponse(resp, machineKey)
if err != nil {
log.Error().
Caller().
Str("func", "handleAuthKeyCommon").
Bool("noise", machineKey.IsZero()).
Str("machine", registerRequest.Hostinfo.Hostname).
Err(err).
Msg("Cannot encode message")
http.Error(writer, "Internal server error", http.StatusInternalServerError)
machineRegistrations.WithLabelValues("new", RegisterMethodAuthKey, "error", pak.Namespace.Name).
Inc()
return
}
writer.Header().Set("Content-Type", "application/json; charset=utf-8")
writer.WriteHeader(http.StatusUnauthorized)
_, err = writer.Write(respBody)
if err != nil {
log.Error().
Caller().
Bool("noise", machineKey.IsZero()).
Err(err).
Msg("Failed to write response")
}
log.Error().
Caller().
Str("func", "handleAuthKeyCommon").
Bool("noise", machineKey.IsZero()).
Str("machine", registerRequest.Hostinfo.Hostname).
Msg("Failed authentication via AuthKey")
if pak != nil {
machineRegistrations.WithLabelValues("new", RegisterMethodAuthKey, "error", pak.Namespace.Name).
Inc()
} else {
machineRegistrations.WithLabelValues("new", RegisterMethodAuthKey, "error", "unknown").Inc()
}
return
}
log.Debug().
Str("func", "handleAuthKeyCommon").
Bool("noise", machineKey.IsZero()).
Str("machine", registerRequest.Hostinfo.Hostname).
Msg("Authentication key was valid, proceeding to acquire IP addresses")
nodeKey := NodePublicKeyStripPrefix(registerRequest.NodeKey)
// retrieve machine information if it exist
// The error is not important, because if it does not
// exist, then this is a new machine and we will move
// on to registration.
machine, _ := h.GetMachineByAnyNodeKey(registerRequest.NodeKey, registerRequest.OldNodeKey)
if machine != nil {
log.Trace().
Caller().
Bool("noise", machineKey.IsZero()).
Str("machine", machine.Hostname).
Msg("machine was already registered before, refreshing with new auth key")
machine.NodeKey = nodeKey
machine.AuthKeyID = uint(pak.ID)
err := h.RefreshMachine(machine, registerRequest.Expiry)
if err != nil {
log.Error().
Caller().
Bool("noise", machineKey.IsZero()).
Str("machine", machine.Hostname).
Err(err).
Msg("Failed to refresh machine")
return
}
} else {
now := time.Now().UTC()
givenName, err := h.GenerateGivenName(registerRequest.Hostinfo.Hostname)
if err != nil {
log.Error().
Caller().
Bool("noise", machineKey.IsZero()).
Str("func", "RegistrationHandler").
Str("hostinfo.name", registerRequest.Hostinfo.Hostname).
Err(err)
return
}
machineToRegister := Machine{
Hostname: registerRequest.Hostinfo.Hostname,
GivenName: givenName,
NamespaceID: pak.Namespace.ID,
MachineKey: MachinePublicKeyStripPrefix(machineKey),
RegisterMethod: RegisterMethodAuthKey,
Expiry: &registerRequest.Expiry,
NodeKey: nodeKey,
LastSeen: &now,
AuthKeyID: uint(pak.ID),
}
machine, err = h.RegisterMachine(
machineToRegister,
)
if err != nil {
log.Error().
Caller().
Bool("noise", machineKey.IsZero()).
Err(err).
Msg("could not register machine")
machineRegistrations.WithLabelValues("new", RegisterMethodAuthKey, "error", pak.Namespace.Name).
Inc()
http.Error(writer, "Internal server error", http.StatusInternalServerError)
return
}
}
err = h.UsePreAuthKey(pak)
if err != nil {
log.Error().
Caller().
Bool("noise", machineKey.IsZero()).
Err(err).
Msg("Failed to use pre-auth key")
machineRegistrations.WithLabelValues("new", RegisterMethodAuthKey, "error", pak.Namespace.Name).
Inc()
http.Error(writer, "Internal server error", http.StatusInternalServerError)
return
}
resp.MachineAuthorized = true
resp.User = *pak.Namespace.toUser()
respBody, err := h.marshalResponse(resp, machineKey)
if err != nil {
log.Error().
Caller().
Bool("noise", machineKey.IsZero()).
Str("func", "handleAuthKeyCommon").
Str("machine", registerRequest.Hostinfo.Hostname).
Err(err).
Msg("Cannot encode message")
machineRegistrations.WithLabelValues("new", RegisterMethodAuthKey, "error", pak.Namespace.Name).
Inc()
http.Error(writer, "Internal server error", http.StatusInternalServerError)
return
}
machineRegistrations.WithLabelValues("new", RegisterMethodAuthKey, "success", pak.Namespace.Name).
Inc()
writer.Header().Set("Content-Type", "application/json; charset=utf-8")
writer.WriteHeader(http.StatusOK)
_, err = writer.Write(respBody)
if err != nil {
log.Error().
Caller().
Bool("noise", machineKey.IsZero()).
Err(err).
Msg("Failed to write response")
}
log.Info().
Str("func", "handleAuthKeyCommon").
Bool("noise", machineKey.IsZero()).
Str("machine", registerRequest.Hostinfo.Hostname).
Str("ips", strings.Join(machine.IPAddresses.ToStringSlice(), ", ")).
Msg("Successfully authenticated via AuthKey")
}
// handleNewMachineCommon exposes for both legacy and Noise the functionality to get a URL
// for authorizing the machine. This url is then showed to the user by the local Tailscale client.
func (h *Headscale) handleNewMachineCommon(
writer http.ResponseWriter,
registerRequest tailcfg.RegisterRequest,
machineKey key.MachinePublic,
) {
resp := tailcfg.RegisterResponse{}
// The machine registration is new, redirect the client to the registration URL
log.Debug().
Caller().
Bool("noise", machineKey.IsZero()).
Str("machine", registerRequest.Hostinfo.Hostname).
Msg("The node seems to be new, sending auth url")
if h.cfg.OIDC.Issuer != "" {
resp.AuthURL = fmt.Sprintf(
"%s/oidc/register/%s",
strings.TrimSuffix(h.cfg.ServerURL, "/"),
NodePublicKeyStripPrefix(registerRequest.NodeKey),
)
} else {
resp.AuthURL = fmt.Sprintf("%s/register/%s",
strings.TrimSuffix(h.cfg.ServerURL, "/"),
NodePublicKeyStripPrefix(registerRequest.NodeKey))
}
respBody, err := h.marshalResponse(resp, machineKey)
if err != nil {
log.Error().
Caller().
Bool("noise", machineKey.IsZero()).
Err(err).
Msg("Cannot encode message")
http.Error(writer, "Internal server error", http.StatusInternalServerError)
return
}
writer.Header().Set("Content-Type", "application/json; charset=utf-8")
writer.WriteHeader(http.StatusOK)
_, err = writer.Write(respBody)
if err != nil {
log.Error().
Bool("noise", machineKey.IsZero()).
Caller().
Err(err).
Msg("Failed to write response")
}
log.Info().
Caller().
Bool("noise", machineKey.IsZero()).
Str("machine", registerRequest.Hostinfo.Hostname).
Msg("Successfully sent auth url")
}
func (h *Headscale) handleMachineLogOutCommon(
writer http.ResponseWriter,
machine Machine,
machineKey key.MachinePublic,
) {
resp := tailcfg.RegisterResponse{}
log.Info().
Bool("noise", machineKey.IsZero()).
Str("machine", machine.Hostname).
Msg("Client requested logout")
err := h.ExpireMachine(&machine)
if err != nil {
log.Error().
Caller().
Bool("noise", machineKey.IsZero()).
Str("func", "handleMachineLogOutCommon").
Err(err).
Msg("Failed to expire machine")
http.Error(writer, "Internal server error", http.StatusInternalServerError)
return
}
resp.AuthURL = ""
resp.MachineAuthorized = false
resp.User = *machine.Namespace.toUser()
respBody, err := h.marshalResponse(resp, machineKey)
if err != nil {
log.Error().
Caller().
Bool("noise", machineKey.IsZero()).
Err(err).
Msg("Cannot encode message")
http.Error(writer, "Internal server error", http.StatusInternalServerError)
return
}
writer.Header().Set("Content-Type", "application/json; charset=utf-8")
writer.WriteHeader(http.StatusOK)
_, err = writer.Write(respBody)
if err != nil {
log.Error().
Bool("noise", machineKey.IsZero()).
Caller().
Err(err).
Msg("Failed to write response")
}
log.Info().
Caller().
Bool("noise", machineKey.IsZero()).
Str("machine", machine.Hostname).
Msg("Successfully logged out")
}
func (h *Headscale) handleMachineValidRegistrationCommon(
writer http.ResponseWriter,
machine Machine,
machineKey key.MachinePublic,
) {
resp := tailcfg.RegisterResponse{}
// The machine registration is valid, respond with redirect to /map
log.Debug().
Caller().
Bool("noise", machineKey.IsZero()).
Str("machine", machine.Hostname).
Msg("Client is registered and we have the current NodeKey. All clear to /map")
resp.AuthURL = ""
resp.MachineAuthorized = true
resp.User = *machine.Namespace.toUser()
resp.Login = *machine.Namespace.toLogin()
respBody, err := h.marshalResponse(resp, machineKey)
if err != nil {
log.Error().
Caller().
Bool("noise", machineKey.IsZero()).
Err(err).
Msg("Cannot encode message")
machineRegistrations.WithLabelValues("update", "web", "error", machine.Namespace.Name).
Inc()
http.Error(writer, "Internal server error", http.StatusInternalServerError)
return
}
machineRegistrations.WithLabelValues("update", "web", "success", machine.Namespace.Name).
Inc()
writer.Header().Set("Content-Type", "application/json; charset=utf-8")
writer.WriteHeader(http.StatusOK)
_, err = writer.Write(respBody)
if err != nil {
log.Error().
Caller().
Bool("noise", machineKey.IsZero()).
Err(err).
Msg("Failed to write response")
}
log.Info().
Caller().
Bool("noise", machineKey.IsZero()).
Str("machine", machine.Hostname).
Msg("Machine successfully authorized")
}
func (h *Headscale) handleMachineRefreshKeyCommon(
writer http.ResponseWriter,
registerRequest tailcfg.RegisterRequest,
machine Machine,
machineKey key.MachinePublic,
) {
resp := tailcfg.RegisterResponse{}
log.Debug().
Caller().
Bool("noise", machineKey.IsZero()).
Str("machine", machine.Hostname).
Msg("We have the OldNodeKey in the database. This is a key refresh")
machine.NodeKey = NodePublicKeyStripPrefix(registerRequest.NodeKey)
if err := h.db.Save(&machine).Error; err != nil {
log.Error().
Caller().
Err(err).
Msg("Failed to update machine key in the database")
http.Error(writer, "Internal server error", http.StatusInternalServerError)
return
}
resp.AuthURL = ""
resp.User = *machine.Namespace.toUser()
respBody, err := h.marshalResponse(resp, machineKey)
if err != nil {
log.Error().
Caller().
Bool("noise", machineKey.IsZero()).
Err(err).
Msg("Cannot encode message")
http.Error(writer, "Internal server error", http.StatusInternalServerError)
return
}
writer.Header().Set("Content-Type", "application/json; charset=utf-8")
writer.WriteHeader(http.StatusOK)
_, err = writer.Write(respBody)
if err != nil {
log.Error().
Caller().
Bool("noise", machineKey.IsZero()).
Err(err).
Msg("Failed to write response")
}
log.Info().
Caller().
Bool("noise", machineKey.IsZero()).
Str("node_key", registerRequest.NodeKey.ShortString()).
Str("old_node_key", registerRequest.OldNodeKey.ShortString()).
Str("machine", machine.Hostname).
Msg("Machine successfully refreshed")
}
func (h *Headscale) handleMachineExpiredCommon(
writer http.ResponseWriter,
registerRequest tailcfg.RegisterRequest,
machine Machine,
machineKey key.MachinePublic,
) {
resp := tailcfg.RegisterResponse{}
// The client has registered before, but has expired
log.Debug().
Caller().
Bool("noise", machineKey.IsZero()).
Str("machine", machine.Hostname).
Msg("Machine registration has expired. Sending a authurl to register")
if registerRequest.Auth.AuthKey != "" {
h.handleAuthKeyCommon(writer, registerRequest, machineKey)
return
}
if h.cfg.OIDC.Issuer != "" {
resp.AuthURL = fmt.Sprintf("%s/oidc/register/%s",
strings.TrimSuffix(h.cfg.ServerURL, "/"),
NodePublicKeyStripPrefix(registerRequest.NodeKey))
} else {
resp.AuthURL = fmt.Sprintf("%s/register/%s",
strings.TrimSuffix(h.cfg.ServerURL, "/"),
NodePublicKeyStripPrefix(registerRequest.NodeKey))
}
respBody, err := h.marshalResponse(resp, machineKey)
if err != nil {
log.Error().
Caller().
Bool("noise", machineKey.IsZero()).
Err(err).
Msg("Cannot encode message")
machineRegistrations.WithLabelValues("reauth", "web", "error", machine.Namespace.Name).
Inc()
http.Error(writer, "Internal server error", http.StatusInternalServerError)
return
}
machineRegistrations.WithLabelValues("reauth", "web", "success", machine.Namespace.Name).
Inc()
writer.Header().Set("Content-Type", "application/json; charset=utf-8")
writer.WriteHeader(http.StatusOK)
_, err = writer.Write(respBody)
if err != nil {
log.Error().
Caller().
Bool("noise", machineKey.IsZero()).
Err(err).
Msg("Failed to write response")
}
log.Info().
Caller().
Bool("noise", machineKey.IsZero()).
Str("machine", machine.Hostname).
Msg("Auth URL for reauthenticate successfully sent")
}

View File

@@ -1,122 +0,0 @@
package headscale
import (
"encoding/binary"
"encoding/json"
"github.com/klauspost/compress/zstd"
"github.com/rs/zerolog/log"
"tailscale.com/tailcfg"
"tailscale.com/types/key"
)
func (h *Headscale) getMapResponseData(
mapRequest tailcfg.MapRequest,
machine *Machine,
isNoise bool,
) ([]byte, error) {
mapResponse, err := h.generateMapResponse(mapRequest, machine)
if err != nil {
return nil, err
}
if isNoise {
return h.marshalMapResponse(mapResponse, key.MachinePublic{}, mapRequest.Compress)
}
var machineKey key.MachinePublic
err = machineKey.UnmarshalText([]byte(MachinePublicKeyEnsurePrefix(machine.MachineKey)))
if err != nil {
log.Error().
Caller().
Err(err).
Msg("Cannot parse client key")
return nil, err
}
return h.marshalMapResponse(mapResponse, machineKey, mapRequest.Compress)
}
func (h *Headscale) getMapKeepAliveResponseData(
mapRequest tailcfg.MapRequest,
machine *Machine,
isNoise bool,
) ([]byte, error) {
keepAliveResponse := tailcfg.MapResponse{
KeepAlive: true,
}
if isNoise {
return h.marshalMapResponse(keepAliveResponse, key.MachinePublic{}, mapRequest.Compress)
}
var machineKey key.MachinePublic
err := machineKey.UnmarshalText([]byte(MachinePublicKeyEnsurePrefix(machine.MachineKey)))
if err != nil {
log.Error().
Caller().
Err(err).
Msg("Cannot parse client key")
return nil, err
}
return h.marshalMapResponse(keepAliveResponse, machineKey, mapRequest.Compress)
}
func (h *Headscale) marshalResponse(
resp interface{},
machineKey key.MachinePublic,
) ([]byte, error) {
jsonBody, err := json.Marshal(resp)
if err != nil {
log.Error().
Caller().
Err(err).
Msg("Cannot marshal response")
return nil, err
}
if machineKey.IsZero() { // if Noise
return jsonBody, nil
}
return h.privateKey.SealTo(machineKey, jsonBody), nil
}
func (h *Headscale) marshalMapResponse(
resp interface{},
machineKey key.MachinePublic,
compression string,
) ([]byte, error) {
jsonBody, err := json.Marshal(resp)
if err != nil {
log.Error().
Caller().
Err(err).
Msg("Cannot marshal map response")
}
var respBody []byte
if compression == ZstdCompression {
encoder, _ := zstd.NewWriter(nil)
respBody = encoder.EncodeAll(jsonBody, nil)
if !machineKey.IsZero() { // if legacy protocol
respBody = h.privateKey.SealTo(machineKey, respBody)
}
} else {
if !machineKey.IsZero() { // if legacy protocol
respBody = h.privateKey.SealTo(machineKey, jsonBody)
} else {
respBody = jsonBody
}
}
data := make([]byte, reservedResponseHeaderSize)
binary.LittleEndian.PutUint32(data, uint32(len(respBody)))
data = append(data, respBody...)
return data, nil
}

View File

@@ -1,58 +0,0 @@
package headscale
import (
"io"
"net/http"
"github.com/gorilla/mux"
"github.com/rs/zerolog/log"
"tailscale.com/tailcfg"
"tailscale.com/types/key"
)
// RegistrationHandler handles the actual registration process of a machine
// Endpoint /machine/:mkey.
func (h *Headscale) RegistrationHandler(
writer http.ResponseWriter,
req *http.Request,
) {
vars := mux.Vars(req)
machineKeyStr, ok := vars["mkey"]
if !ok || machineKeyStr == "" {
log.Error().
Str("handler", "RegistrationHandler").
Msg("No machine ID in request")
http.Error(writer, "No machine ID in request", http.StatusBadRequest)
return
}
body, _ := io.ReadAll(req.Body)
var machineKey key.MachinePublic
err := machineKey.UnmarshalText([]byte(MachinePublicKeyEnsurePrefix(machineKeyStr)))
if err != nil {
log.Error().
Caller().
Err(err).
Msg("Cannot parse machine key")
machineRegistrations.WithLabelValues("unknown", "web", "error", "unknown").Inc()
http.Error(writer, "Cannot parse machine key", http.StatusBadRequest)
return
}
registerRequest := tailcfg.RegisterRequest{}
err = decode(body, &registerRequest, &machineKey, h.privateKey)
if err != nil {
log.Error().
Caller().
Err(err).
Msg("Cannot decode message")
machineRegistrations.WithLabelValues("unknown", "web", "error", "unknown").Inc()
http.Error(writer, "Cannot decode message", http.StatusBadRequest)
return
}
h.handleRegisterCommon(writer, req, registerRequest, machineKey)
}

View File

@@ -1,94 +0,0 @@
package headscale
import (
"errors"
"io"
"net/http"
"github.com/gorilla/mux"
"github.com/rs/zerolog/log"
"gorm.io/gorm"
"tailscale.com/tailcfg"
"tailscale.com/types/key"
)
// PollNetMapHandler takes care of /machine/:id/map
//
// This is the busiest endpoint, as it keeps the HTTP long poll that updates
// the clients when something in the network changes.
//
// The clients POST stuff like HostInfo and their Endpoints here, but
// only after their first request (marked with the ReadOnly field).
//
// At this moment the updates are sent in a quite horrendous way, but they kinda work.
func (h *Headscale) PollNetMapHandler(
writer http.ResponseWriter,
req *http.Request,
) {
vars := mux.Vars(req)
machineKeyStr, ok := vars["mkey"]
if !ok || machineKeyStr == "" {
log.Error().
Str("handler", "PollNetMap").
Msg("No machine key in request")
http.Error(writer, "No machine key in request", http.StatusBadRequest)
return
}
log.Trace().
Str("handler", "PollNetMap").
Str("id", machineKeyStr).
Msg("PollNetMapHandler called")
body, _ := io.ReadAll(req.Body)
var machineKey key.MachinePublic
err := machineKey.UnmarshalText([]byte(MachinePublicKeyEnsurePrefix(machineKeyStr)))
if err != nil {
log.Error().
Str("handler", "PollNetMap").
Err(err).
Msg("Cannot parse client key")
http.Error(writer, "Cannot parse client key", http.StatusBadRequest)
return
}
mapRequest := tailcfg.MapRequest{}
err = decode(body, &mapRequest, &machineKey, h.privateKey)
if err != nil {
log.Error().
Str("handler", "PollNetMap").
Err(err).
Msg("Cannot decode message")
http.Error(writer, "Cannot decode message", http.StatusBadRequest)
return
}
machine, err := h.GetMachineByMachineKey(machineKey)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
log.Warn().
Str("handler", "PollNetMap").
Msgf("Ignoring request, cannot find machine with key %s", machineKey.String())
http.Error(writer, "", http.StatusUnauthorized)
return
}
log.Error().
Str("handler", "PollNetMap").
Msgf("Failed to fetch machine from the database with Machine key: %s", machineKey.String())
http.Error(writer, "", http.StatusInternalServerError)
return
}
log.Trace().
Str("handler", "PollNetMap").
Str("id", machineKeyStr).
Str("machine", machine.Hostname).
Msg("A machine is entering polling via the legacy protocol")
h.handlePollCommon(writer, req.Context(), machine, mapRequest, false)
}

View File

@@ -1,38 +0,0 @@
package headscale
import (
"encoding/json"
"io"
"net/http"
"github.com/rs/zerolog/log"
"tailscale.com/tailcfg"
"tailscale.com/types/key"
)
// // NoiseRegistrationHandler handles the actual registration process of a machine.
func (h *Headscale) NoiseRegistrationHandler(
writer http.ResponseWriter,
req *http.Request,
) {
log.Trace().Caller().Msgf("Noise registration handler for client %s", req.RemoteAddr)
if req.Method != http.MethodPost {
http.Error(writer, "Wrong method", http.StatusMethodNotAllowed)
return
}
body, _ := io.ReadAll(req.Body)
registerRequest := tailcfg.RegisterRequest{}
if err := json.Unmarshal(body, &registerRequest); err != nil {
log.Error().
Caller().
Err(err).
Msg("Cannot parse RegisterRequest")
machineRegistrations.WithLabelValues("unknown", "web", "error", "unknown").Inc()
http.Error(writer, "Internal error", http.StatusInternalServerError)
return
}
h.handleRegisterCommon(writer, req, registerRequest, key.MachinePublic{})
}

View File

@@ -1,67 +0,0 @@
package headscale
import (
"encoding/json"
"errors"
"io"
"net/http"
"github.com/rs/zerolog/log"
"gorm.io/gorm"
"tailscale.com/tailcfg"
"tailscale.com/types/key"
)
// NoisePollNetMapHandler takes care of /machine/:id/map using the Noise protocol
//
// This is the busiest endpoint, as it keeps the HTTP long poll that updates
// the clients when something in the network changes.
//
// The clients POST stuff like HostInfo and their Endpoints here, but
// only after their first request (marked with the ReadOnly field).
//
// At this moment the updates are sent in a quite horrendous way, but they kinda work.
func (h *Headscale) NoisePollNetMapHandler(
writer http.ResponseWriter,
req *http.Request,
) {
log.Trace().
Str("handler", "NoisePollNetMap").
Msg("PollNetMapHandler called")
body, _ := io.ReadAll(req.Body)
mapRequest := tailcfg.MapRequest{}
if err := json.Unmarshal(body, &mapRequest); err != nil {
log.Error().
Caller().
Err(err).
Msg("Cannot parse MapRequest")
http.Error(writer, "Internal error", http.StatusInternalServerError)
return
}
machine, err := h.GetMachineByAnyNodeKey(mapRequest.NodeKey, key.NodePublic{})
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
log.Warn().
Str("handler", "NoisePollNetMap").
Msgf("Ignoring request, cannot find machine with key %s", mapRequest.NodeKey.String())
http.Error(writer, "Internal error", http.StatusNotFound)
return
}
log.Error().
Str("handler", "NoisePollNetMap").
Msgf("Failed to fetch machine from the database with node key: %s", mapRequest.NodeKey.String())
http.Error(writer, "Internal error", http.StatusInternalServerError)
return
}
log.Debug().
Str("handler", "NoisePollNetMap").
Str("machine", machine.Hostname).
Msg("A machine is entering polling via the Noise protocol")
h.handlePollCommon(writer, req.Context(), machine, mapRequest, true)
}

View File

@@ -2,11 +2,12 @@ package headscale
import (
"fmt"
"net/netip"
"inet.af/netaddr"
)
const (
ErrRouteIsNotAvailable = Error("route is not available")
errRouteIsNotAvailable = Error("route is not available")
)
// Deprecated: use machine function instead
@@ -15,7 +16,7 @@ const (
func (h *Headscale) GetAdvertisedNodeRoutes(
namespace string,
nodeName string,
) (*[]netip.Prefix, error) {
) (*[]netaddr.IPPrefix, error) {
machine, err := h.GetMachine(namespace, nodeName)
if err != nil {
return nil, err
@@ -30,7 +31,7 @@ func (h *Headscale) GetAdvertisedNodeRoutes(
func (h *Headscale) GetEnabledNodeRoutes(
namespace string,
nodeName string,
) ([]netip.Prefix, error) {
) ([]netaddr.IPPrefix, error) {
machine, err := h.GetMachine(namespace, nodeName)
if err != nil {
return nil, err
@@ -46,7 +47,7 @@ func (h *Headscale) IsNodeRouteEnabled(
nodeName string,
routeStr string,
) bool {
route, err := netip.ParsePrefix(routeStr)
route, err := netaddr.ParseIPPrefix(routeStr)
if err != nil {
return false
}
@@ -78,7 +79,7 @@ func (h *Headscale) EnableNodeRoute(
return err
}
route, err := netip.ParsePrefix(routeStr)
route, err := netaddr.ParseIPPrefix(routeStr)
if err != nil {
return err
}
@@ -105,7 +106,7 @@ func (h *Headscale) EnableNodeRoute(
}
if !available {
return ErrRouteIsNotAvailable
return errRouteIsNotAvailable
}
machine.EnabledRoutes = enabledRoutes

View File

@@ -1,9 +1,8 @@
package headscale
import (
"net/netip"
"gopkg.in/check.v1"
"inet.af/netaddr"
"tailscale.com/tailcfg"
)
@@ -17,11 +16,11 @@ func (s *Suite) TestGetRoutes(c *check.C) {
_, err = app.GetMachine("test", "test_get_route_machine")
c.Assert(err, check.NotNil)
route, err := netip.ParsePrefix("10.0.0.0/24")
route, err := netaddr.ParseIPPrefix("10.0.0.0/24")
c.Assert(err, check.IsNil)
hostInfo := tailcfg.Hostinfo{
RoutableIPs: []netip.Prefix{route},
RoutableIPs: []netaddr.IPPrefix{route},
}
machine := Machine{
@@ -61,18 +60,18 @@ func (s *Suite) TestGetEnableRoutes(c *check.C) {
_, err = app.GetMachine("test", "test_enable_route_machine")
c.Assert(err, check.NotNil)
route, err := netip.ParsePrefix(
route, err := netaddr.ParseIPPrefix(
"10.0.0.0/24",
)
c.Assert(err, check.IsNil)
route2, err := netip.ParsePrefix(
route2, err := netaddr.ParseIPPrefix(
"150.0.10.0/25",
)
c.Assert(err, check.IsNil)
hostInfo := tailcfg.Hostinfo{
RoutableIPs: []netip.Prefix{route, route2},
RoutableIPs: []netaddr.IPPrefix{route, route2},
}
machine := Machine{

View File

@@ -83,7 +83,7 @@ func SwaggerAPIv1(
writer http.ResponseWriter,
req *http.Request,
) {
writer.Header().Set("Content-Type", "application/json; charset=utf-8")
writer.Header().Set("Content-Type", "application/json; charset=utf-88")
writer.WriteHeader(http.StatusOK)
if _, err := writer.Write(apiV1JSON); err != nil {
log.Error().

View File

@@ -13,7 +13,6 @@ import (
"fmt"
"io/fs"
"net"
"net/netip"
"os"
"path/filepath"
"reflect"
@@ -22,14 +21,14 @@ import (
"github.com/rs/zerolog/log"
"github.com/spf13/viper"
"go4.org/netipx"
"inet.af/netaddr"
"tailscale.com/tailcfg"
"tailscale.com/types/key"
)
const (
ErrCannotDecryptResponse = Error("cannot decrypt response")
ErrCouldNotAllocateIP = Error("could not find any suitable IP")
errCannotDecryptReponse = Error("cannot decrypt response")
errCouldNotAllocateIP = Error("could not find any suitable IP")
// These constants are copied from the upstream tailscale.com/types/key
// library, because they are not exported.
@@ -60,8 +59,6 @@ const (
privateHexPrefix = "privkey:"
PermissionFallback = 0o700
ZstdCompression = "zstd"
)
func MachinePublicKeyStripPrefix(machineKey key.MachinePublic) string {
@@ -119,14 +116,11 @@ func decode(
pubKey *key.MachinePublic,
privKey *key.MachinePrivate,
) error {
log.Trace().
Str("pubkey", pubKey.ShortString()).
Int("length", len(msg)).
Msg("Trying to decrypt")
log.Trace().Int("length", len(msg)).Msg("Trying to decrypt")
decrypted, ok := privKey.OpenFrom(*pubKey, msg)
if !ok {
return ErrCannotDecryptResponse
return errCannotDecryptReponse
}
if err := json.Unmarshal(decrypted, output); err != nil {
@@ -136,12 +130,25 @@ func decode(
return nil
}
func encode(
v interface{},
pubKey *key.MachinePublic,
privKey *key.MachinePrivate,
) ([]byte, error) {
b, err := json.Marshal(v)
if err != nil {
return nil, err
}
return privKey.SealTo(*pubKey, b), nil
}
func (h *Headscale) getAvailableIPs() (MachineAddresses, error) {
var ips MachineAddresses
var err error
ipPrefixes := h.cfg.IPPrefixes
for _, ipPrefix := range ipPrefixes {
var ip *netip.Addr
var ip *netaddr.IP
ip, err = h.getAvailableIP(ipPrefix)
if err != nil {
return ips, err
@@ -152,16 +159,16 @@ func (h *Headscale) getAvailableIPs() (MachineAddresses, error) {
return ips, err
}
func GetIPPrefixEndpoints(na netip.Prefix) (netip.Addr, netip.Addr) {
var network, broadcast netip.Addr
ipRange := netipx.RangeOfPrefix(na)
func GetIPPrefixEndpoints(na netaddr.IPPrefix) (netaddr.IP, netaddr.IP) {
var network, broadcast netaddr.IP
ipRange := na.Range()
network = ipRange.From()
broadcast = ipRange.To()
return network, broadcast
}
func (h *Headscale) getAvailableIP(ipPrefix netip.Prefix) (*netip.Addr, error) {
func (h *Headscale) getAvailableIP(ipPrefix netaddr.IPPrefix) (*netaddr.IP, error) {
usedIps, err := h.getUsedIPs()
if err != nil {
return nil, err
@@ -174,7 +181,7 @@ func (h *Headscale) getAvailableIP(ipPrefix netip.Prefix) (*netip.Addr, error) {
for {
if !ipPrefix.Contains(ip) {
return nil, ErrCouldNotAllocateIP
return nil, errCouldNotAllocateIP
}
switch {
@@ -182,7 +189,7 @@ func (h *Headscale) getAvailableIP(ipPrefix netip.Prefix) (*netip.Addr, error) {
fallthrough
case usedIps.Contains(ip):
fallthrough
case ip == netip.Addr{} || ip.IsLoopback():
case ip.IsZero() || ip.IsLoopback():
ip = ip.Next()
continue
@@ -193,19 +200,19 @@ func (h *Headscale) getAvailableIP(ipPrefix netip.Prefix) (*netip.Addr, error) {
}
}
func (h *Headscale) getUsedIPs() (*netipx.IPSet, error) {
func (h *Headscale) getUsedIPs() (*netaddr.IPSet, error) {
// FIXME: This really deserves a better data model,
// but this was quick to get running and it should be enough
// to begin experimenting with a dual stack tailnet.
var addressesSlices []string
h.db.Model(&Machine{}).Pluck("ip_addresses", &addressesSlices)
var ips netipx.IPSetBuilder
var ips netaddr.IPSetBuilder
for _, slice := range addressesSlices {
var machineAddresses MachineAddresses
err := machineAddresses.Scan(slice)
if err != nil {
return &netipx.IPSet{}, fmt.Errorf(
return &netaddr.IPSet{}, fmt.Errorf(
"failed to read ip from database: %w",
err,
)
@@ -218,7 +225,7 @@ func (h *Headscale) getUsedIPs() (*netipx.IPSet, error) {
ipSet, err := ips.IPSet()
if err != nil {
return &netipx.IPSet{}, fmt.Errorf(
return &netaddr.IPSet{}, fmt.Errorf(
"failed to build IP Set: %w",
err,
)
@@ -251,7 +258,7 @@ func GrpcSocketDialer(ctx context.Context, addr string) (net.Conn, error) {
return d.DialContext(ctx, "unix", addr)
}
func ipPrefixToString(prefixes []netip.Prefix) []string {
func ipPrefixToString(prefixes []netaddr.IPPrefix) []string {
result := make([]string, len(prefixes))
for index, prefix := range prefixes {
@@ -261,13 +268,13 @@ func ipPrefixToString(prefixes []netip.Prefix) []string {
return result
}
func stringToIPPrefix(prefixes []string) ([]netip.Prefix, error) {
result := make([]netip.Prefix, len(prefixes))
func stringToIPPrefix(prefixes []string) ([]netaddr.IPPrefix, error) {
result := make([]netaddr.IPPrefix, len(prefixes))
for index, prefixStr := range prefixes {
prefix, err := netip.ParsePrefix(prefixStr)
prefix, err := netaddr.ParseIPPrefix(prefixStr)
if err != nil {
return []netip.Prefix{}, err
return []netaddr.IPPrefix{}, err
}
result[index] = prefix
@@ -276,7 +283,7 @@ func stringToIPPrefix(prefixes []string) ([]netip.Prefix, error) {
return result, nil
}
func contains[T string | netip.Prefix](ts []T, t T) bool {
func contains[T string | netaddr.IPPrefix](ts []T, t T) bool {
for _, v := range ts {
if reflect.DeepEqual(v, t) {
return true

View File

@@ -1,10 +1,8 @@
package headscale
import (
"net/netip"
"go4.org/netipx"
"gopkg.in/check.v1"
"inet.af/netaddr"
)
func (s *Suite) TestGetAvailableIp(c *check.C) {
@@ -12,7 +10,7 @@ func (s *Suite) TestGetAvailableIp(c *check.C) {
c.Assert(err, check.IsNil)
expected := netip.MustParseAddr("10.27.0.1")
expected := netaddr.MustParseIP("10.27.0.1")
c.Assert(len(ips), check.Equals, 1)
c.Assert(ips[0].String(), check.Equals, expected.String())
@@ -48,8 +46,8 @@ func (s *Suite) TestGetUsedIps(c *check.C) {
c.Assert(err, check.IsNil)
expected := netip.MustParseAddr("10.27.0.1")
expectedIPSetBuilder := netipx.IPSetBuilder{}
expected := netaddr.MustParseIP("10.27.0.1")
expectedIPSetBuilder := netaddr.IPSetBuilder{}
expectedIPSetBuilder.Add(expected)
expectedIPSet, _ := expectedIPSetBuilder.IPSet()
@@ -98,11 +96,11 @@ func (s *Suite) TestGetMultiIp(c *check.C) {
usedIps, err := app.getUsedIPs()
c.Assert(err, check.IsNil)
expected0 := netip.MustParseAddr("10.27.0.1")
expected9 := netip.MustParseAddr("10.27.0.10")
expected300 := netip.MustParseAddr("10.27.0.45")
expected0 := netaddr.MustParseIP("10.27.0.1")
expected9 := netaddr.MustParseIP("10.27.0.10")
expected300 := netaddr.MustParseIP("10.27.0.45")
notExpectedIPSetBuilder := netipx.IPSetBuilder{}
notExpectedIPSetBuilder := netaddr.IPSetBuilder{}
notExpectedIPSetBuilder.Add(expected0)
notExpectedIPSetBuilder.Add(expected9)
notExpectedIPSetBuilder.Add(expected300)
@@ -123,7 +121,7 @@ func (s *Suite) TestGetMultiIp(c *check.C) {
c.Assert(
machine1.IPAddresses[0],
check.Equals,
netip.MustParseAddr("10.27.0.1"),
netaddr.MustParseIP("10.27.0.1"),
)
machine50, err := app.GetMachineByID(50)
@@ -132,10 +130,10 @@ func (s *Suite) TestGetMultiIp(c *check.C) {
c.Assert(
machine50.IPAddresses[0],
check.Equals,
netip.MustParseAddr("10.27.0.50"),
netaddr.MustParseIP("10.27.0.50"),
)
expectedNextIP := netip.MustParseAddr("10.27.1.95")
expectedNextIP := netaddr.MustParseIP("10.27.1.95")
nextIP, err := app.getAvailableIPs()
c.Assert(err, check.IsNil)
@@ -155,7 +153,7 @@ func (s *Suite) TestGetAvailableIpMachineWithoutIP(c *check.C) {
ips, err := app.getAvailableIPs()
c.Assert(err, check.IsNil)
expected := netip.MustParseAddr("10.27.0.1")
expected := netaddr.MustParseIP("10.27.0.1")
c.Assert(len(ips), check.Equals, 1)
c.Assert(ips[0].String(), check.Equals, expected.String())