Compare commits

...

64 Commits

Author SHA1 Message Date
Juan Font
e27753e46e Merge pull request #103 from juanfont/shared-nodes
Add support for sharing nodes across namespaces
2021-09-11 23:31:37 +02:00
Juan Font
11fbef4bf0 Added extra timeout 2021-09-11 23:21:45 +02:00
Juan Font
c4e6ad1ec7 Fixed some typos 2021-09-10 00:52:08 +02:00
Juan Font
263a3f1983 Merge branch 'main' into shared-nodes 2021-09-10 00:49:50 +02:00
Juan Font
8acaea0fbe Increased timeout 2021-09-10 00:44:27 +02:00
Juan Font
bd6adfaec6 Changes a few more variables 2021-09-10 00:37:01 +02:00
Juan Font
4b4a5a4b93 Update sharing.go
Co-authored-by: Kristoffer Dalby <kradalby@kradalby.no>
2021-09-10 00:32:42 +02:00
Juan Font
b098d84557 Apply suggestions from code review
Changed more variable names

Co-authored-by: Kristoffer Dalby <kradalby@kradalby.no>
2021-09-10 00:32:06 +02:00
Juan Font
b937f9b762 Update machine.go
Added comment on toNode
2021-09-10 00:30:02 +02:00
Juan Font
55f3e07bd4 Apply suggestions from code review
Removed one letter variables

Co-authored-by: Kristoffer Dalby <kradalby@kradalby.no>
2021-09-10 00:26:46 +02:00
Juan Font
2780623076 Renamed SharedNode to SharedMachine 2021-09-06 14:43:43 +02:00
Juan Font
75a342f96e Renamed files 2021-09-06 14:40:37 +02:00
Juan Font
729cd54401 Renamed sharing function 2021-09-06 14:39:52 +02:00
Juan Font
a023f51971 Merge pull request #101 from SilverBut/main
fix: check last seen time without possible null pointer
2021-09-03 10:35:49 +02:00
Juan Font
5076eb9215 Merge pull request #102 from SilverBut/patch-1
docs: add notes on how to build own DERP server
2021-09-03 10:24:32 +02:00
Juan Font
7edd0cd14c Added add node cli 2021-09-03 10:23:45 +02:00
Juan Font
7ce4738d8a Preload namespace so the name can be shown 2021-09-03 10:23:26 +02:00
Juan Font
7287e0259c Minor linting issues 2021-09-02 17:08:39 +02:00
Juan Font
d86de68b40 Show namespace in node list table 2021-09-02 17:06:47 +02:00
Juan Font
4ba107a765 README updated 2021-09-02 17:00:46 +02:00
Juan Font
187b016d09 Added helper function to get list of shared nodes 2021-09-02 16:59:50 +02:00
Juan Font
7010f5afad Added unit tests on sharing nodes 2021-09-02 16:59:12 +02:00
Juan Font
48b73fa12f Implement node sharing functionality 2021-09-02 16:59:03 +02:00
Juan Font
1ecd0d7ca4 Added DB SharedNode model to support sharing nodes 2021-09-02 16:57:26 +02:00
Silver Bullet
6faaae0c5f docs: add notes on how to build own DERP server
The official doc is hidden under a bunch of issues. Add a doc link here and hope it could be helpful.
2021-09-02 06:08:12 +08:00
Silver Bullet
e4ef65be76 fix: check last seen time without possible null pointer 2021-09-02 05:44:42 +08:00
Juan Font
39c661d408 Merge pull request #99 from juanfont/explicit-ubuntu-version
Use explicit version in Dockerfile
2021-08-26 21:18:16 +02:00
Juan Font
91a48d6a43 Update Dockerfile
Use explicit version in Dockerfile (addresses #95)
2021-08-26 10:23:45 +02:00
Kristoffer Dalby
123f0fa185 Merge pull request #98 from kradalby/initial-dns-server-exit-node 2021-08-25 22:58:25 +01:00
Kristoffer Dalby
ba3dffecbf Update readme 2021-08-25 19:05:10 +01:00
Kristoffer Dalby
8735e5675c Add a test for the getdnsconfig function 2021-08-25 19:03:04 +01:00
Kristoffer Dalby
3f5e06a0f8 Dont add the portnumber to the ip 2021-08-25 18:43:13 +01:00
Juan Font
ba40a40b73 Merge pull request #96 from qbit/version_fix
Fix setting of version
2021-08-25 12:34:34 +02:00
Kristoffer Dalby
b3732e7fb9 Add nameserver as resolver aswell 2021-08-25 07:04:48 +01:00
Aaron Bieber
104776ee84 fix setting of version 2021-08-24 07:49:15 -06:00
Kristoffer Dalby
01e781e546 Pass DNSConfig to nodes in MapResponse 2021-08-24 07:11:45 +01:00
Kristoffer Dalby
e77c16b55a Add DNSConfig to example and setup test 2021-08-24 07:10:09 +01:00
Kristoffer Dalby
987bbee1db Add DNSConfig field to configuration 2021-08-24 07:09:47 +01:00
Juan Font
74d2fe1baa Merge pull request #84 from kradalby/integration-tests-ci
Improve logic to keep nodes up to date with the network state
2021-08-23 09:42:07 +02:00
Kristoffer Dalby
98e63d5561 Merge pull request #94 from kradalby/split-lint-test
Split lint and test CI files
2021-08-23 07:46:11 +01:00
Kristoffer Dalby
059f13fc9d Add missing comment for stream function 2021-08-23 07:38:14 +01:00
Kristoffer Dalby
ebd27b46af Add comment to updatemachine 2021-08-23 07:35:44 +01:00
Juan Font
ca8d814918 Merge pull request #92 from kradalby/enhance-route-command
Enhance route command with ptables and multiple routes
2021-08-22 12:48:30 +02:00
Kristoffer Dalby
0aeeaac361 Always load machine object from DB before save/modify
We are currently holding Machine objects in memory for a long time,
while waiting for stream/longpoll, this might make us end up with stale
objects, that we just call save on, potentially overwriting stuff in
the database.

A typical scenario would be someone changing something from the CLI,
e.g. enabling routes, which in turn is overwritten again by the stale
object in the longpolling function.

The code has been left with TODO's and a discussion is available in #93.
2021-08-21 16:52:19 +01:00
Kristoffer Dalby
28ed8a5742 Actually rename lint 2021-08-21 15:42:23 +01:00
Kristoffer Dalby
f749be1490 Split lint and test CI files
This commit splits the lint and test steps into two different jobs in
github actions.

Consider this a suggestion, the idea is that when we look at PRs we will
see explicitly which one of the two types of checks fails without having
to open Github actions.
2021-08-21 15:40:27 +01:00
Kristoffer Dalby
693bce1b10 Update test machine name properly 2021-08-21 15:35:26 +01:00
Kristoffer Dalby
4f97e077db Add --all flag to routes enable command to enable all advertised routes 2021-08-21 15:04:30 +01:00
Kristoffer Dalby
c883e79884 Enhance route command with ptables and multiple routes
This commit rewrites the `routes list` command to use ptables to present
a slightly nicer list, including a new field if the route is enabled or
not (which is quite useful).

In addition, it reworks the enable command to support enabling multiple
routes (not only one route as per removed TODO). This allows users to
actually take advantage of exit-nodes and subnet relays.
2021-08-21 14:49:46 +01:00
Kristoffer Dalby
a054e2514a Keep tailscale count at 25 in integration tests 2021-08-21 09:26:18 +01:00
Kristoffer Dalby
c49fe26da7 Code clean up, loglevel debug for integration tests 2021-08-21 09:15:16 +01:00
Kristoffer Dalby
d93a7f2e02 Make Info default log level 2021-08-20 17:15:07 +01:00
Kristoffer Dalby
88d7ac04bf Account for racecondition in deleting/closing update channel
This commit tries to address the possible raceondition  that can happen
if a client closes its connection after we have fetched it from the
syncmap before sending the message.

To try to avoid introducing new dead lock conditions, all messages sent
to updateChannel has been moved into a function, which handles the
locking (instead of calling it all over the place)

The same lock is used around the delete/close function.
2021-08-20 16:52:34 +01:00
Kristoffer Dalby
1f422af1c8 Save headscale logs if jobs fail 2021-08-20 16:50:55 +01:00
Kristoffer Dalby
53168d54d8 Make http timeout 30s instead of 10s 2021-08-19 22:29:03 +01:00
Kristoffer Dalby
b0ec945dbb Make lastStateChange namespaced 2021-08-19 18:19:26 +01:00
Kristoffer Dalby
48ef6e5a6f Rename keepAlive function, as it now does more things 2021-08-19 18:06:57 +01:00
Kristoffer Dalby
8d1adaaef3 Move isOutdated logic to updateChan consumation 2021-08-19 18:05:33 +01:00
Kristoffer Dalby
dd8c0d1e9e Move most "poll" functionality to poll.go
This function migrates more poll functions (including keepalive) to
poll.go to keep it somehow in the same file.

In addition it makes changes to improve the stability and ensure nodes
get the appropriate updates from the headscale control and are not left
in an inconsistent state.

Two new additions is:

omitpeers=true will now trigger an update if the clients are not already up
to date

keepalive has been extended with a timer that will check every 120s if
all nodes are up to date.
2021-08-18 23:24:22 +01:00
Kristoffer Dalby
57b79aa852 Set timeout, add lastupdate field
This commit makes two reasonably major changes:

Set a default timeout for the go HTTP server (which gin uses), which
allows us to actually have broken long poll sessions fail so we can have
the client re-establish them.
The current 10s number is chosen randomly and we need more testing to
ensure that the feature work as intended.

The second is adding a last updated field to keep track of the last time
we had an update that needs to be propagated to all of our
clients/nodes. This will be used to keep track of our machines and if
they are up to date or need us to push an update.
2021-08-18 23:21:11 +01:00
Kristoffer Dalby
2f883410d2 Add lastUpdate field to machine, function issue message on update channel
This commit adds a new field to machine, lastSuccessfulUpdate which
tracks when we last was able to send a proper mapupdate to the node. The
purpose of this is to be able to compare to a "global" last updated time
and determine if we need to send an update map request to a node.

In addition it allows us to create a scheduled check to see if all known
nodes are up to date.

Also, add a helper function to send a message to the update channel of a
machine.
2021-08-18 23:17:38 +01:00
Kristoffer Dalby
6fa61380b2 Up client count, make arguments more explicit and clean up unused assignments 2021-08-18 23:17:09 +01:00
Kristoffer Dalby
7d1a5c00a0 Try with longer timeout 2021-08-13 16:56:28 +01:00
Kristoffer Dalby
036061664e initial integration test file 2021-08-13 16:12:01 +01:00
27 changed files with 1648 additions and 395 deletions

41
.github/workflows/lint.yml vendored Normal file
View File

@@ -0,0 +1,41 @@
name: CI
on: [push, pull_request]
jobs:
# The "build" workflow
lint:
# The type of runner that the job will run on
runs-on: ubuntu-latest
# Steps represent a sequence of tasks that will be executed as part of the job
steps:
# Checks-out your repository under $GITHUB_WORKSPACE, so your job can access it
- uses: actions/checkout@v2
# Install and run golangci-lint as a separate step, it's much faster this
# way because this action has caching. It'll get run again in `make lint`
# below, but it's still much faster in the end than installing
# golangci-lint manually in the `Run lint` step.
- uses: golangci/golangci-lint-action@v2
with:
args: --timeout 4m
# Setup Go
- name: Setup Go
uses: actions/setup-go@v2
with:
go-version: "1.16.3" # The Go version to download (if necessary) and use.
# Install all the dependencies
- name: Install dependencies
run: |
go version
go install golang.org/x/lint/golint@latest
sudo apt update
sudo apt install -y make
- name: Run lint
with:
args: --timeout 4m
run: make lint

23
.github/workflows/test-integration.yml vendored Normal file
View File

@@ -0,0 +1,23 @@
name: CI
on: [pull_request]
jobs:
# The "build" workflow
integration-test:
# The type of runner that the job will run on
runs-on: ubuntu-latest
# Steps represent a sequence of tasks that will be executed as part of the job
steps:
# Checks-out your repository under $GITHUB_WORKSPACE, so your job can access it
- uses: actions/checkout@v2
# Setup Go
- name: Setup Go
uses: actions/setup-go@v2
with:
go-version: "1.16.3"
- name: Run Integration tests
run: go test -tags integration -timeout 30m

View File

@@ -10,36 +10,24 @@ jobs:
# Steps represent a sequence of tasks that will be executed as part of the job
steps:
# Checks-out your repository under $GITHUB_WORKSPACE, so your job can access it
- uses: actions/checkout@v2
# Checks-out your repository under $GITHUB_WORKSPACE, so your job can access it
- uses: actions/checkout@v2
# Install and run golangci-lint as a separate step, it's much faster this
# way because this action has caching. It'll get run again in `make lint`
# below, but it's still much faster in the end than installing
# golangci-lint manually in the `Run lint` step.
- uses: golangci/golangci-lint-action@v2
with:
args: --timeout 2m
# Setup Go
- name: Setup Go
uses: actions/setup-go@v2
with:
go-version: '1.16.3' # The Go version to download (if necessary) and use.
# Setup Go
- name: Setup Go
uses: actions/setup-go@v2
with:
go-version: "1.16.3" # The Go version to download (if necessary) and use.
# Install all the dependencies
- name: Install dependencies
run: |
go version
go install golang.org/x/lint/golint@latest
sudo apt update
sudo apt install -y make
- name: Run tests
run: make test
# Install all the dependencies
- name: Install dependencies
run: |
go version
sudo apt update
sudo apt install -y make
- name: Run lint
run: make lint
- name: Run tests
run: make test
- name: Run build
run: make
- name: Run build
run: make

2
.gitignore vendored
View File

@@ -19,3 +19,5 @@ config.json
*.key
/db.sqlite
*.sqlite3
test_output/

View File

@@ -10,7 +10,7 @@ COPY . /go/src/headscale
RUN go install -a -ldflags="-extldflags=-static" -tags netgo,sqlite_omit_load_extension ./cmd/headscale
RUN test -e /go/bin/headscale
FROM ubuntu:latest
FROM ubuntu:20.04
COPY --from=build /go/bin/headscale /usr/local/bin/headscale
ENV TZ UTC

View File

@@ -2,7 +2,7 @@
version = $(shell ./scripts/version-at-commit.sh)
build:
go build -ldflags "-s -w -X main.version=$(version)" cmd/headscale/headscale.go
go build -ldflags "-s -w -X github.com/juanfont/headscale/cmd/headscale/cli.version=$(version)" cmd/headscale/headscale.go
dev: lint test build

View File

@@ -25,14 +25,13 @@ Headscale implements this coordination server.
- [X] JSON-formatted output
- [X] ACLs
- [X] Support for alternative IP ranges in the tailnets (default Tailscale's 100.64.0.0/10)
- [ ] Share nodes between ~~users~~ namespaces
- [ ] DNS
- [X] DNS (passing DNS servers to nodes)
- [X] Share nodes between ~~users~~ namespaces
- [ ] MagicDNS / Smart DNS
## Roadmap 🤷
We are now focusing on adding integration tests with the official clients.
Suggestions/PRs welcomed!

222
api.go
View File

@@ -13,9 +13,7 @@ import (
"github.com/gin-gonic/gin"
"github.com/klauspost/compress/zstd"
"gorm.io/datatypes"
"gorm.io/gorm"
"inet.af/netaddr"
"tailscale.com/tailcfg"
"tailscale.com/types/wgkey"
)
@@ -35,8 +33,6 @@ func (h *Headscale) RegisterWebAPI(c *gin.Context) {
return
}
// spew.Dump(c.Params)
c.Data(http.StatusOK, "text/html; charset=utf-8", []byte(fmt.Sprintf(`
<html>
<body>
@@ -82,14 +78,16 @@ func (h *Headscale) RegistrationHandler(c *gin.Context) {
return
}
now := time.Now().UTC()
var m Machine
if result := h.db.Preload("Namespace").First(&m, "machine_key = ?", mKey.HexString()); errors.Is(result.Error, gorm.ErrRecordNotFound) {
log.Info().Str("machine", req.Hostinfo.Hostname).Msg("New machine")
m = Machine{
Expiry: &req.Expiry,
MachineKey: mKey.HexString(),
Name: req.Hostinfo.Hostname,
NodeKey: wgkey.Key(req.NodeKey).HexString(),
Expiry: &req.Expiry,
MachineKey: mKey.HexString(),
Name: req.Hostinfo.Hostname,
NodeKey: wgkey.Key(req.NodeKey).HexString(),
LastSuccessfulUpdate: &now,
}
if err := h.db.Create(&m).Error; err != nil {
log.Error().
@@ -215,202 +213,12 @@ func (h *Headscale) RegistrationHandler(c *gin.Context) {
c.Data(200, "application/json; charset=utf-8", respBody)
}
// PollNetMapHandler takes care of /machine/:id/map
//
// This is the busiest endpoint, as it keeps the HTTP long poll that updates
// the clients when something in the network changes.
//
// The clients POST stuff like HostInfo and their Endpoints here, but
// only after their first request (marked with the ReadOnly field).
//
// At this moment the updates are sent in a quite horrendous way, but they kinda work.
func (h *Headscale) PollNetMapHandler(c *gin.Context) {
log.Trace().
Str("handler", "PollNetMap").
Str("id", c.Param("id")).
Msg("PollNetMapHandler called")
body, _ := io.ReadAll(c.Request.Body)
mKeyStr := c.Param("id")
mKey, err := wgkey.ParseHex(mKeyStr)
if err != nil {
log.Error().
Str("handler", "PollNetMap").
Err(err).
Msg("Cannot parse client key")
c.String(http.StatusBadRequest, "")
return
}
req := tailcfg.MapRequest{}
err = decode(body, &req, &mKey, h.privateKey)
if err != nil {
log.Error().
Str("handler", "PollNetMap").
Err(err).
Msg("Cannot decode message")
c.String(http.StatusBadRequest, "")
return
}
var m Machine
if result := h.db.Preload("Namespace").First(&m, "machine_key = ?", mKey.HexString()); errors.Is(result.Error, gorm.ErrRecordNotFound) {
log.Warn().
Str("handler", "PollNetMap").
Msgf("Ignoring request, cannot find machine with key %s", mKey.HexString())
c.String(http.StatusUnauthorized, "")
return
}
log.Trace().
Str("handler", "PollNetMap").
Str("id", c.Param("id")).
Str("machine", m.Name).
Msg("Found machine in database")
hostinfo, _ := json.Marshal(req.Hostinfo)
m.Name = req.Hostinfo.Hostname
m.HostInfo = datatypes.JSON(hostinfo)
m.DiscoKey = wgkey.Key(req.DiscoKey).HexString()
now := time.Now().UTC()
// From Tailscale client:
//
// ReadOnly is whether the client just wants to fetch the MapResponse,
// without updating their Endpoints. The Endpoints field will be ignored and
// LastSeen will not be updated and peers will not be notified of changes.
//
// The intended use is for clients to discover the DERP map at start-up
// before their first real endpoint update.
if !req.ReadOnly {
endpoints, _ := json.Marshal(req.Endpoints)
m.Endpoints = datatypes.JSON(endpoints)
m.LastSeen = &now
}
h.db.Save(&m)
data, err := h.getMapResponse(mKey, req, m)
if err != nil {
log.Error().
Str("handler", "PollNetMap").
Str("id", c.Param("id")).
Str("machine", m.Name).
Err(err).
Msg("Failed to get Map response")
c.String(http.StatusInternalServerError, ":(")
return
}
// We update our peers if the client is not sending ReadOnly in the MapRequest
// so we don't distribute its initial request (it comes with
// empty endpoints to peers)
// Details on the protocol can be found in https://github.com/tailscale/tailscale/blob/main/tailcfg/tailcfg.go#L696
log.Debug().
Str("handler", "PollNetMap").
Str("id", c.Param("id")).
Str("machine", m.Name).
Bool("readOnly", req.ReadOnly).
Bool("omitPeers", req.OmitPeers).
Bool("stream", req.Stream).
Msg("Client map request processed")
if req.ReadOnly {
log.Info().
Str("handler", "PollNetMap").
Str("machine", m.Name).
Msg("Client is starting up. Asking for DERP map")
c.Data(200, "application/json; charset=utf-8", *data)
return
}
if req.OmitPeers && !req.Stream {
log.Info().
Str("handler", "PollNetMap").
Str("machine", m.Name).
Msg("Client sent endpoint update and is ok with a response without peer list")
c.Data(200, "application/json; charset=utf-8", *data)
return
} else if req.OmitPeers && req.Stream {
log.Warn().
Str("handler", "PollNetMap").
Str("machine", m.Name).
Msg("Ignoring request, don't know how to handle it")
c.String(http.StatusBadRequest, "")
return
}
// Only create update channel if it has not been created
var update chan []byte
log.Trace().
Str("handler", "PollNetMap").
Str("id", c.Param("id")).
Str("machine", m.Name).
Msg("Creating or loading update channel")
if result, ok := h.clientsPolling.LoadOrStore(m.ID, make(chan []byte, 1)); ok {
update = result.(chan []byte)
}
pollData := make(chan []byte, 1)
defer close(pollData)
cancelKeepAlive := make(chan []byte, 1)
defer close(cancelKeepAlive)
log.Info().
Str("handler", "PollNetMap").
Str("machine", m.Name).
Msg("Client is ready to access the tailnet")
log.Info().
Str("handler", "PollNetMap").
Str("machine", m.Name).
Msg("Sending initial map")
pollData <- *data
log.Info().
Str("handler", "PollNetMap").
Str("machine", m.Name).
Msg("Notifying peers")
// TODO: Why does this block?
go h.notifyChangesToPeers(&m)
h.PollNetMapStream(c, m, req, mKey, pollData, update, cancelKeepAlive)
log.Trace().
Str("handler", "PollNetMap").
Str("id", c.Param("id")).
Str("machine", m.Name).
Msg("Finished stream, closing PollNetMap session")
}
func (h *Headscale) keepAlive(cancel chan []byte, pollData chan []byte, mKey wgkey.Key, req tailcfg.MapRequest, m Machine) {
for {
select {
case <-cancel:
return
default:
data, err := h.getMapKeepAliveResponse(mKey, req, m)
if err != nil {
log.Error().
Str("func", "keepAlive").
Err(err).
Msg("Error generating the keep alive msg")
return
}
log.Debug().
Str("func", "keepAlive").
Str("machine", m.Name).
Msg("Sending keepalive")
pollData <- *data
time.Sleep(60 * time.Second)
}
}
}
func (h *Headscale) getMapResponse(mKey wgkey.Key, req tailcfg.MapRequest, m Machine) (*[]byte, error) {
log.Trace().
Str("func", "getMapResponse").
Str("machine", req.Hostinfo.Hostname).
Msg("Creating Map response")
node, err := m.toNode()
node, err := m.toNode(true)
if err != nil {
log.Error().
Str("func", "getMapResponse").
@@ -434,10 +242,15 @@ func (h *Headscale) getMapResponse(mKey wgkey.Key, req tailcfg.MapRequest, m Mac
}
resp := tailcfg.MapResponse{
KeepAlive: false,
Node: node,
Peers: *peers,
DNS: []netaddr.IP{},
KeepAlive: false,
Node: node,
Peers: *peers,
//TODO(kradalby): As per tailscale docs, if DNSConfig is nil,
// it means its not updated, maybe we can have some logic
// to check and only pass updates when its updates.
// This is probably more relevant if we try to implement
// "MagicDNS"
DNSConfig: h.cfg.DNSConfig,
SearchPaths: []string{},
Domain: "headscale.net",
PacketFilter: *h.aclRules,
@@ -465,7 +278,6 @@ func (h *Headscale) getMapResponse(mKey wgkey.Key, req tailcfg.MapRequest, m Mac
return nil, err
}
}
// spew.Dump(resp)
// declare the incoming size on the first 4 bytes
data := make([]byte, 4)
binary.LittleEndian.PutUint32(data, uint32(len(respBody)))
@@ -542,7 +354,7 @@ func (h *Headscale) handleAuthKey(c *gin.Context, db *gorm.DB, idKey wgkey.Key,
Str("func", "handleAuthKey").
Str("machine", m.Name).
Str("ip", ip.String()).
Msgf("Assining %s to %s", ip, m.Name)
Msgf("Assigning %s to %s", ip, m.Name)
m.AuthKeyID = uint(pak.ID)
m.IPAddress = ip.String()

45
app.go
View File

@@ -43,6 +43,8 @@ type Config struct {
TLSCertPath string
TLSKeyPath string
DNSConfig *tailcfg.DNSConfig
}
// Headscale represents the base app of the service
@@ -58,7 +60,10 @@ type Headscale struct {
aclPolicy *ACLPolicy
aclRules *[]tailcfg.FilterRule
clientsPolling sync.Map
clientsUpdateChannels sync.Map
clientsUpdateChannelMutex sync.Mutex
lastStateChange sync.Map
}
// NewHeadscale returns the Headscale app
@@ -165,9 +170,18 @@ func (h *Headscale) Serve() error {
r.POST("/machine/:id", h.RegistrationHandler)
var err error
timeout := 30 * time.Second
go h.watchForKVUpdates(5000)
go h.expireEphemeralNodes(5000)
s := &http.Server{
Addr: h.cfg.Addr,
Handler: r,
ReadTimeout: timeout,
WriteTimeout: timeout,
}
if h.cfg.TLSLetsEncryptHostname != "" {
if !strings.HasPrefix(h.cfg.ServerURL, "https://") {
log.Warn().Msg("Listening with TLS but ServerURL does not start with https://")
@@ -179,9 +193,11 @@ func (h *Headscale) Serve() error {
Cache: autocert.DirCache(h.cfg.TLSLetsEncryptCacheDir),
}
s := &http.Server{
Addr: h.cfg.Addr,
TLSConfig: m.TLSConfig(),
Handler: r,
Addr: h.cfg.Addr,
TLSConfig: m.TLSConfig(),
Handler: r,
ReadTimeout: timeout,
WriteTimeout: timeout,
}
if h.cfg.TLSLetsEncryptChallengeType == "TLS-ALPN-01" {
// Configuration via autocert with TLS-ALPN-01 (https://tools.ietf.org/html/rfc8737)
@@ -206,12 +222,29 @@ func (h *Headscale) Serve() error {
if !strings.HasPrefix(h.cfg.ServerURL, "http://") {
log.Warn().Msg("Listening without TLS but ServerURL does not start with http://")
}
err = r.Run(h.cfg.Addr)
err = s.ListenAndServe()
} else {
if !strings.HasPrefix(h.cfg.ServerURL, "https://") {
log.Warn().Msg("Listening with TLS but ServerURL does not start with https://")
}
err = r.RunTLS(h.cfg.Addr, h.cfg.TLSCertPath, h.cfg.TLSKeyPath)
err = s.ListenAndServeTLS(h.cfg.TLSCertPath, h.cfg.TLSKeyPath)
}
return err
}
func (h *Headscale) setLastStateChangeToNow(namespace string) {
now := time.Now().UTC()
h.lastStateChange.Store(namespace, now)
}
func (h *Headscale) getLastStateChange(namespace string) time.Time {
if wrapped, ok := h.lastStateChange.Load(namespace); ok {
lastChange, _ := wrapped.(time.Time)
return lastChange
}
now := time.Now().UTC()
h.lastStateChange.Store(namespace, now)
return now
}

View File

@@ -25,6 +25,7 @@ func init() {
nodeCmd.AddCommand(listNodesCmd)
nodeCmd.AddCommand(registerNodeCmd)
nodeCmd.AddCommand(deleteNodeCmd)
nodeCmd.AddCommand(shareMachineCmd)
}
var nodeCmd = &cobra.Command{
@@ -79,9 +80,26 @@ var listNodesCmd = &cobra.Command{
if err != nil {
log.Fatalf("Error initializing: %s", err)
}
namespace, err := h.GetNamespace(n)
if err != nil {
log.Fatalf("Error fetching namespace: %s", err)
}
machines, err := h.ListMachinesInNamespace(n)
if err != nil {
log.Fatalf("Error fetching machines: %s", err)
}
sharedMachines, err := h.ListSharedMachinesInNamespace(n)
if err != nil {
log.Fatalf("Error fetching shared machines: %s", err)
}
allMachines := append(*machines, *sharedMachines...)
if strings.HasPrefix(o, "json") {
JsonOutput(machines, err, o)
JsonOutput(allMachines, err, o)
return
}
@@ -89,7 +107,7 @@ var listNodesCmd = &cobra.Command{
log.Fatalf("Error getting nodes: %s", err)
}
d, err := nodesToPtables(*machines)
d, err := nodesToPtables(*namespace, allMachines)
if err != nil {
log.Fatalf("Error converting to table: %s", err)
}
@@ -145,31 +163,94 @@ var deleteNodeCmd = &cobra.Command{
},
}
func nodesToPtables(m []headscale.Machine) (pterm.TableData, error) {
d := pterm.TableData{{"ID", "Name", "NodeKey", "IP address", "Ephemeral", "Last seen", "Online"}}
var shareMachineCmd = &cobra.Command{
Use: "share ID namespace",
Short: "Shares a node from the current namespace to the specified one",
Args: func(cmd *cobra.Command, args []string) error {
if len(args) < 2 {
return fmt.Errorf("missing parameters")
}
return nil
},
Run: func(cmd *cobra.Command, args []string) {
namespace, err := cmd.Flags().GetString("namespace")
if err != nil {
log.Fatalf("Error getting namespace: %s", err)
}
output, _ := cmd.Flags().GetString("output")
for _, m := range m {
h, err := getHeadscaleApp()
if err != nil {
log.Fatalf("Error initializing: %s", err)
}
_, err = h.GetNamespace(namespace)
if err != nil {
log.Fatalf("Error fetching origin namespace: %s", err)
}
destinationNamespace, err := h.GetNamespace(args[1])
if err != nil {
log.Fatalf("Error fetching destination namespace: %s", err)
}
id, err := strconv.Atoi(args[0])
if err != nil {
log.Fatalf("Error converting ID to integer: %s", err)
}
machine, err := h.GetMachineByID(uint64(id))
if err != nil {
log.Fatalf("Error getting node: %s", err)
}
err = h.AddSharedMachineToNamespace(machine, destinationNamespace)
if strings.HasPrefix(output, "json") {
JsonOutput(map[string]string{"Result": "Node shared"}, err, output)
return
}
if err != nil {
fmt.Printf("Error sharing node: %s\n", err)
return
}
fmt.Println("Node shared!")
},
}
func nodesToPtables(currentNamespace headscale.Namespace, machines []headscale.Machine) (pterm.TableData, error) {
d := pterm.TableData{{"ID", "Name", "NodeKey", "Namespace", "IP address", "Ephemeral", "Last seen", "Online"}}
for _, machine := range machines {
var ephemeral bool
if m.AuthKey != nil && m.AuthKey.Ephemeral {
if machine.AuthKey != nil && machine.AuthKey.Ephemeral {
ephemeral = true
}
var lastSeen time.Time
if m.LastSeen != nil {
lastSeen = *m.LastSeen
var lastSeenTime string
if machine.LastSeen != nil {
lastSeen = *machine.LastSeen
lastSeenTime = lastSeen.Format("2006-01-02 15:04:05")
}
nKey, err := wgkey.ParseHex(m.NodeKey)
nKey, err := wgkey.ParseHex(machine.NodeKey)
if err != nil {
return nil, err
}
nodeKey := tailcfg.NodeKey(nKey)
var online string
if m.LastSeen.After(time.Now().Add(-5 * time.Minute)) { // TODO: Find a better way to reliably show if online
if lastSeen.After(time.Now().Add(-5 * time.Minute)) { // TODO: Find a better way to reliably show if online
online = pterm.LightGreen("true")
} else {
online = pterm.LightRed("false")
}
d = append(d, []string{strconv.FormatUint(m.ID, 10), m.Name, nodeKey.ShortString(), m.IPAddress, strconv.FormatBool(ephemeral), lastSeen.Format("2006-01-02 15:04:05"), online})
var namespace string
if currentNamespace.ID == machine.NamespaceID {
namespace = pterm.LightMagenta(machine.Namespace.Name)
} else {
namespace = pterm.LightYellow(machine.Namespace.Name)
}
d = append(d, []string{strconv.FormatUint(machine.ID, 10), machine.Name, nodeKey.ShortString(), namespace, machine.IPAddress, strconv.FormatBool(ephemeral), lastSeenTime, online})
}
return d, nil
}

View File

@@ -5,6 +5,7 @@ import (
"log"
"strings"
"github.com/pterm/pterm"
"github.com/spf13/cobra"
)
@@ -15,6 +16,9 @@ func init() {
if err != nil {
log.Fatalf(err.Error())
}
enableRouteCmd.Flags().BoolP("all", "a", false, "Enable all routes advertised by the node")
routesCmd.AddCommand(listRoutesCmd)
routesCmd.AddCommand(enableRouteCmd)
}
@@ -44,19 +48,25 @@ var listRoutesCmd = &cobra.Command{
if err != nil {
log.Fatalf("Error initializing: %s", err)
}
routes, err := h.GetNodeRoutes(n, args[0])
if strings.HasPrefix(o, "json") {
JsonOutput(routes, err, o)
return
}
availableRoutes, err := h.GetAdvertisedNodeRoutes(n, args[0])
if err != nil {
fmt.Println(err)
return
}
fmt.Println(routes)
if strings.HasPrefix(o, "json") {
// TODO: Add enable/disabled information to this interface
JsonOutput(availableRoutes, err, o)
return
}
d := h.RoutesToPtables(n, args[0], *availableRoutes)
err = pterm.DefaultTable.WithHasHeader().WithData(d).Render()
if err != nil {
log.Fatal(err)
}
},
}
@@ -64,32 +74,74 @@ var enableRouteCmd = &cobra.Command{
Use: "enable node-name route",
Short: "Allows exposing a route declared by this node to the rest of the nodes",
Args: func(cmd *cobra.Command, args []string) error {
if len(args) < 2 {
return fmt.Errorf("Missing parameters")
all, err := cmd.Flags().GetBool("all")
if err != nil {
log.Fatalf("Error getting namespace: %s", err)
}
if all {
if len(args) < 1 {
return fmt.Errorf("Missing parameters")
}
return nil
} else {
if len(args) < 2 {
return fmt.Errorf("Missing parameters")
}
return nil
}
return nil
},
Run: func(cmd *cobra.Command, args []string) {
n, err := cmd.Flags().GetString("namespace")
if err != nil {
log.Fatalf("Error getting namespace: %s", err)
}
o, _ := cmd.Flags().GetString("output")
all, err := cmd.Flags().GetBool("all")
if err != nil {
log.Fatalf("Error getting namespace: %s", err)
}
h, err := getHeadscaleApp()
if err != nil {
log.Fatalf("Error initializing: %s", err)
}
route, err := h.EnableNodeRoute(n, args[0], args[1])
if strings.HasPrefix(o, "json") {
JsonOutput(route, err, o)
return
}
if err != nil {
fmt.Println(err)
return
if all {
availableRoutes, err := h.GetAdvertisedNodeRoutes(n, args[0])
if err != nil {
fmt.Println(err)
return
}
for _, availableRoute := range *availableRoutes {
err = h.EnableNodeRoute(n, args[0], availableRoute.String())
if err != nil {
fmt.Println(err)
return
}
if strings.HasPrefix(o, "json") {
JsonOutput(availableRoute, err, o)
} else {
fmt.Printf("Enabled route %s\n", availableRoute)
}
}
} else {
err = h.EnableNodeRoute(n, args[0], args[1])
if strings.HasPrefix(o, "json") {
JsonOutput(args[1], err, o)
return
}
if err != nil {
fmt.Println(err)
return
}
fmt.Printf("Enabled route %s\n", args[1])
}
fmt.Printf("Enabled route %s\n", route)
},
}

View File

@@ -39,7 +39,9 @@ func LoadConfig(path string) error {
viper.SetDefault("ip_prefix", "100.64.0.0/10")
viper.SetDefault("log_level", "debug")
viper.SetDefault("log_level", "info")
viper.SetDefault("dns_config", nil)
err := viper.ReadInConfig()
if err != nil {
@@ -70,6 +72,45 @@ func LoadConfig(path string) error {
} else {
return nil
}
}
func GetDNSConfig() *tailcfg.DNSConfig {
if viper.IsSet("dns_config") {
dnsConfig := &tailcfg.DNSConfig{}
if viper.IsSet("dns_config.nameservers") {
nameserversStr := viper.GetStringSlice("dns_config.nameservers")
nameservers := make([]netaddr.IP, len(nameserversStr))
resolvers := make([]tailcfg.DNSResolver, len(nameserversStr))
for index, nameserverStr := range nameserversStr {
nameserver, err := netaddr.ParseIP(nameserverStr)
if err != nil {
log.Error().
Str("func", "getDNSConfig").
Err(err).
Msgf("Could not parse nameserver IP: %s", nameserverStr)
}
nameservers[index] = nameserver
resolvers[index] = tailcfg.DNSResolver{
Addr: nameserver.String(),
}
}
dnsConfig.Nameservers = nameservers
dnsConfig.Resolvers = resolvers
}
if viper.IsSet("dns_config.domains") {
dnsConfig.Domains = viper.GetStringSlice("dns_config.domains")
}
return dnsConfig
}
return nil
}
func absPath(path string) string {
@@ -126,6 +167,8 @@ func getHeadscaleApp() (*headscale.Headscale, error) {
TLSCertPath: absPath(viper.GetString("tls_cert_path")),
TLSKeyPath: absPath(viper.GetString("tls_key_path")),
DNSConfig: GetDNSConfig(),
}
h, err := headscale.NewHeadscale(cfg)

View File

@@ -58,7 +58,7 @@ func (*Suite) TestPostgresConfigLoading(c *check.C) {
c.Assert(viper.GetString("db_port"), check.Equals, "5432")
c.Assert(viper.GetString("tls_letsencrypt_hostname"), check.Equals, "")
c.Assert(viper.GetString("tls_letsencrypt_listen"), check.Equals, ":http")
c.Assert(viper.GetString("tls_letsencrypt_challenge_type"), check.Equals, "HTTP-01")
c.Assert(viper.GetStringSlice("dns_config.nameservers")[0], check.Equals, "1.1.1.1")
}
func (*Suite) TestSqliteConfigLoading(c *check.C) {
@@ -92,6 +92,37 @@ func (*Suite) TestSqliteConfigLoading(c *check.C) {
c.Assert(viper.GetString("tls_letsencrypt_hostname"), check.Equals, "")
c.Assert(viper.GetString("tls_letsencrypt_listen"), check.Equals, ":http")
c.Assert(viper.GetString("tls_letsencrypt_challenge_type"), check.Equals, "HTTP-01")
c.Assert(viper.GetStringSlice("dns_config.nameservers")[0], check.Equals, "1.1.1.1")
}
func (*Suite) TestDNSConfigLoading(c *check.C) {
tmpDir, err := ioutil.TempDir("", "headscale")
if err != nil {
c.Fatal(err)
}
defer os.RemoveAll(tmpDir)
path, err := os.Getwd()
if err != nil {
c.Fatal(err)
}
// Symlink the example config file
err = os.Symlink(filepath.Clean(path+"/../../config.json.sqlite.example"), filepath.Join(tmpDir, "config.json"))
if err != nil {
c.Fatal(err)
}
// Load example config, it should load without validation errors
err = cli.LoadConfig(tmpDir)
c.Assert(err, check.IsNil)
dnsConfig := cli.GetDNSConfig()
fmt.Println(dnsConfig)
c.Assert(dnsConfig.Nameservers[0].String(), check.Equals, "1.1.1.1")
c.Assert(dnsConfig.Resolvers[0].Addr, check.Equals, "1.1.1.1")
}
func writeConfig(c *check.C, tmpDir string, configYaml []byte) {

View File

@@ -16,5 +16,10 @@
"tls_letsencrypt_challenge_type": "HTTP-01",
"tls_cert_path": "",
"tls_key_path": "",
"acl_policy_path": ""
"acl_policy_path": "",
"dns_config": {
"nameservers": [
"1.1.1.1"
]
}
}

View File

@@ -12,5 +12,10 @@
"tls_letsencrypt_challenge_type": "HTTP-01",
"tls_cert_path": "",
"tls_key_path": "",
"acl_policy_path": ""
"acl_policy_path": "",
"dns_config": {
"nameservers": [
"1.1.1.1"
]
}
}

5
db.go
View File

@@ -44,6 +44,11 @@ func (h *Headscale) initDB() error {
return err
}
err = db.AutoMigrate(&SharedMachine{})
if err != nil {
return err
}
err = h.setValue("db_version", dbVersion)
return err
}

View File

@@ -1,7 +1,7 @@
# This file contains some of the official Tailscale DERP servers,
# shamelessly taken from https://github.com/tailscale/tailscale/blob/main/net/dnsfallback/dns-fallback-servers.json
#
# If you plan to somehow use headscale, please deploy your own DERP infra
# If you plan to somehow use headscale, please deploy your own DERP infra: https://tailscale.com/kb/1118/custom-derp-servers/
regions:
1:
regionid: 1

View File

@@ -4,10 +4,13 @@ package headscale
import (
"bytes"
"context"
"fmt"
"io/ioutil"
"log"
"net/http"
"os"
"path"
"strings"
"testing"
"time"
@@ -20,23 +23,48 @@ import (
"inet.af/netaddr"
)
type IntegrationTestSuite struct {
suite.Suite
}
func TestIntegrationTestSuite(t *testing.T) {
suite.Run(t, new(IntegrationTestSuite))
}
var integrationTmpDir string
var ih Headscale
var pool dockertest.Pool
var network dockertest.Network
var headscale dockertest.Resource
var tailscaleCount int = 5
var tailscaleCount int = 25
var tailscales map[string]dockertest.Resource
type IntegrationTestSuite struct {
suite.Suite
stats *suite.SuiteInformation
}
func TestIntegrationTestSuite(t *testing.T) {
s := new(IntegrationTestSuite)
suite.Run(t, s)
// HandleStats, which allows us to check if we passed and save logs
// is called after TearDown, so we cannot tear down containers before
// we have potentially saved the logs.
for _, tailscale := range tailscales {
if err := pool.Purge(&tailscale); err != nil {
log.Printf("Could not purge resource: %s\n", err)
}
}
if !s.stats.Passed() {
err := saveLog(&headscale, "test_output")
if err != nil {
log.Printf("Could not save log: %s\n", err)
}
}
if err := pool.Purge(&headscale); err != nil {
log.Printf("Could not purge resource: %s\n", err)
}
if err := network.Close(); err != nil {
log.Printf("Could not close network: %s\n", err)
}
}
func executeCommand(resource *dockertest.Resource, cmd []string) (string, error) {
var stdout bytes.Buffer
var stderr bytes.Buffer
@@ -62,6 +90,48 @@ func executeCommand(resource *dockertest.Resource, cmd []string) (string, error)
return stdout.String(), nil
}
func saveLog(resource *dockertest.Resource, basePath string) error {
err := os.MkdirAll(basePath, os.ModePerm)
if err != nil {
return err
}
var stdout bytes.Buffer
var stderr bytes.Buffer
err = pool.Client.Logs(
docker.LogsOptions{
Context: context.TODO(),
Container: resource.Container.ID,
OutputStream: &stdout,
ErrorStream: &stderr,
Tail: "all",
RawTerminal: false,
Stdout: true,
Stderr: true,
Follow: false,
Timestamps: false,
},
)
if err != nil {
return err
}
fmt.Printf("Saving logs for %s to %s\n", resource.Container.Name, basePath)
err = ioutil.WriteFile(path.Join(basePath, resource.Container.Name+".stdout.log"), []byte(stdout.String()), 0644)
if err != nil {
return err
}
err = ioutil.WriteFile(path.Join(basePath, resource.Container.Name+".stderr.log"), []byte(stdout.String()), 0644)
if err != nil {
return err
}
return nil
}
func dockerRestartPolicy(config *docker.HostConfig) {
// set AutoRemove to true so that stopped container goes away by itself
config.AutoRemove = true
@@ -115,7 +185,6 @@ func (s *IntegrationTestSuite) SetupSuite() {
PortBindings: map[docker.Port][]docker.PortBinding{
"8080/tcp": []docker.PortBinding{{HostPort: "8080"}},
},
Env: []string{},
}
fmt.Println("Creating headscale container")
@@ -134,7 +203,6 @@ func (s *IntegrationTestSuite) SetupSuite() {
Name: hostname,
Networks: []*dockertest.Network{&network},
Cmd: []string{"tailscaled", "--tun=userspace-networking", "--socks5-server=localhost:1055"},
Env: []string{},
}
if pts, err := pool.BuildAndRunWithBuildOptions(tailscaleBuildOptions, tailscaleOptions, dockerRestartPolicy); err == nil {
@@ -145,7 +213,6 @@ func (s *IntegrationTestSuite) SetupSuite() {
fmt.Printf("Created %s container\n", hostname)
}
// TODO: Replace this logic with something that can be detected on Github Actions
fmt.Println("Waiting for headscale to be ready")
hostEndpoint := fmt.Sprintf("localhost:%s", headscale.GetPort("8080/tcp"))
@@ -197,23 +264,14 @@ func (s *IntegrationTestSuite) SetupSuite() {
// The nodes need a bit of time to get their updated maps from headscale
// TODO: See if we can have a more deterministic wait here.
time.Sleep(20 * time.Second)
time.Sleep(60 * time.Second)
}
func (s *IntegrationTestSuite) TearDownSuite() {
if err := pool.Purge(&headscale); err != nil {
log.Printf("Could not purge resource: %s\n", err)
}
}
for _, tailscale := range tailscales {
if err := pool.Purge(&tailscale); err != nil {
log.Printf("Could not purge resource: %s\n", err)
}
}
if err := network.Close(); err != nil {
log.Printf("Could not close network: %s\n", err)
}
func (s *IntegrationTestSuite) HandleStats(suiteName string, stats *suite.SuiteInformation) {
s.stats = stats
}
func (s *IntegrationTestSuite) TestListNodes() {
@@ -295,7 +353,15 @@ func (s *IntegrationTestSuite) TestPingAllPeers() {
s.T().Run(fmt.Sprintf("%s-%s", hostname, peername), func(t *testing.T) {
// We currently cant ping ourselves, so skip that.
if peername != hostname {
command := []string{"tailscale", "ping", "--timeout=1s", "--c=1", ip.String()}
// We are only interested in "direct ping" which means what we
// might need a couple of more attempts before reaching the node.
command := []string{
"tailscale", "ping",
"--timeout=1s",
"--c=20",
"--until-direct=true",
ip.String(),
}
fmt.Printf("Pinging from %s (%s) to %s (%s)\n", hostname, ips[hostname], peername, ip)
result, err := executeCommand(

View File

@@ -7,5 +7,5 @@
"db_type": "sqlite3",
"db_path": "/tmp/integration_test_db.sqlite3",
"acl_policy_path": "",
"log_level": "trace"
"log_level": "debug"
}

View File

@@ -2,6 +2,7 @@ package headscale
import (
"encoding/json"
"errors"
"fmt"
"sort"
"strconv"
@@ -31,8 +32,9 @@ type Machine struct {
AuthKeyID uint
AuthKey *PreAuthKey
LastSeen *time.Time
Expiry *time.Time
LastSeen *time.Time
LastSuccessfulUpdate *time.Time
Expiry *time.Time
HostInfo datatypes.JSON
Endpoints datatypes.JSON
@@ -48,7 +50,9 @@ func (m Machine) isAlreadyRegistered() bool {
return m.Registered
}
func (m Machine) toNode() (*tailcfg.Node, error) {
// toNode converts a Machine into a Tailscale Node. includeRoutes is false for shared nodes
// as per the expected behaviour in the official SaaS
func (m Machine) toNode(includeRoutes bool) (*tailcfg.Node, error) {
nKey, err := wgkey.ParseHex(m.NodeKey)
if err != nil {
return nil, err
@@ -83,24 +87,26 @@ func (m Machine) toNode() (*tailcfg.Node, error) {
allowedIPs := []netaddr.IPPrefix{}
allowedIPs = append(allowedIPs, ip) // we append the node own IP, as it is required by the clients
routesStr := []string{}
if len(m.EnabledRoutes) != 0 {
allwIps, err := m.EnabledRoutes.MarshalJSON()
if err != nil {
return nil, err
if includeRoutes {
routesStr := []string{}
if len(m.EnabledRoutes) != 0 {
allwIps, err := m.EnabledRoutes.MarshalJSON()
if err != nil {
return nil, err
}
err = json.Unmarshal(allwIps, &routesStr)
if err != nil {
return nil, err
}
}
err = json.Unmarshal(allwIps, &routesStr)
if err != nil {
return nil, err
}
}
for _, aip := range routesStr {
ip, err := netaddr.ParseIPPrefix(aip)
if err != nil {
return nil, err
for _, routeStr := range routesStr {
ip, err := netaddr.ParseIPPrefix(routeStr)
if err != nil {
return nil, err
}
allowedIPs = append(allowedIPs, ip)
}
allowedIPs = append(allowedIPs, ip)
}
endpoints := []string{}
@@ -134,13 +140,20 @@ func (m Machine) toNode() (*tailcfg.Node, error) {
derp = "127.3.3.40:0" // Zero means disconnected or unknown.
}
var keyExpiry time.Time
if m.Expiry != nil {
keyExpiry = *m.Expiry
} else {
keyExpiry = time.Time{}
}
n := tailcfg.Node{
ID: tailcfg.NodeID(m.ID), // this is the actual ID
StableID: tailcfg.StableNodeID(strconv.FormatUint(m.ID, 10)), // in headscale, unlike tailcontrol server, IDs are permanent
Name: hostinfo.Hostname,
User: tailcfg.UserID(m.NamespaceID),
Key: tailcfg.NodeKey(nKey),
KeyExpiry: *m.Expiry,
KeyExpiry: keyExpiry,
Machine: tailcfg.MachineKey(mKey),
DiscoKey: discoKey,
Addresses: addrs,
@@ -163,6 +176,7 @@ func (h *Headscale) getPeers(m Machine) (*[]*tailcfg.Node, error) {
Str("func", "getPeers").
Str("machine", m.Name).
Msg("Finding peers")
machines := []Machine{}
if err := h.db.Where("namespace_id = ? AND machine_key <> ? AND registered",
m.NamespaceID, m.MachineKey).Find(&machines).Error; err != nil {
@@ -170,9 +184,23 @@ func (h *Headscale) getPeers(m Machine) (*[]*tailcfg.Node, error) {
return nil, err
}
// We fetch here machines that are shared to the `Namespace` of the machine we are getting peers for
sharedMachines := []SharedMachine{}
if err := h.db.Preload("Namespace").Preload("Machine").Where("namespace_id = ?",
m.NamespaceID).Find(&sharedMachines).Error; err != nil {
return nil, err
}
peers := []*tailcfg.Node{}
for _, mn := range machines {
peer, err := mn.toNode()
peer, err := mn.toNode(true)
if err != nil {
return nil, err
}
peers = append(peers, peer)
}
for _, sharedMachine := range sharedMachines {
peer, err := sharedMachine.Machine.toNode(false) // shared nodes do not expose their routes
if err != nil {
return nil, err
}
@@ -199,18 +227,27 @@ func (h *Headscale) GetMachine(namespace string, name string) (*Machine, error)
return &m, nil
}
}
return nil, fmt.Errorf("not found")
return nil, fmt.Errorf("machine not found")
}
// GetMachineByID finds a Machine by ID and returns the Machine struct
func (h *Headscale) GetMachineByID(id uint64) (*Machine, error) {
m := Machine{}
if result := h.db.Find(&Machine{ID: id}).First(&m); result.Error != nil {
if result := h.db.Preload("Namespace").Find(&Machine{ID: id}).First(&m); result.Error != nil {
return nil, result.Error
}
return &m, nil
}
// UpdateMachine takes a Machine struct pointer (typically already loaded from database
// and updates it with the latest data from the database.
func (h *Headscale) UpdateMachine(m *Machine) error {
if result := h.db.Find(m).First(&m); result.Error != nil {
return result.Error
}
return nil
}
// DeleteMachine softs deletes a Machine from the database
func (h *Headscale) DeleteMachine(m *Machine) error {
m.Registered = false
@@ -249,23 +286,119 @@ func (m *Machine) GetHostInfo() (*tailcfg.Hostinfo, error) {
}
func (h *Headscale) notifyChangesToPeers(m *Machine) {
peers, _ := h.getPeers(*m)
peers, err := h.getPeers(*m)
if err != nil {
log.Error().
Str("func", "notifyChangesToPeers").
Str("machine", m.Name).
Msgf("Error getting peers: %s", err)
return
}
for _, p := range *peers {
pUp, ok := h.clientsPolling.Load(uint64(p.ID))
if ok {
log.Info().
Str("func", "notifyChangesToPeers").
Str("machine", m.Name).
Str("peer", p.Name).
Str("address", p.Addresses[0].String()).
Msgf("Notifying peer %s (%s)", p.Name, p.Addresses[0])
err := h.sendRequestOnUpdateChannel(p)
if err != nil {
log.Info().
Str("func", "notifyChangesToPeers").
Str("machine", m.Name).
Str("peer", m.Name).
Str("address", p.Addresses[0].String()).
Msgf("Notifying peer %s (%s)", p.Name, p.Addresses[0])
pUp.(chan []byte) <- []byte{}
} else {
log.Info().
Str("func", "notifyChangesToPeers").
Str("machine", m.Name).
Str("peer", m.Name).
Str("peer", p.Name).
Msgf("Peer %s does not appear to be polling", p.Name)
}
log.Trace().
Str("func", "notifyChangesToPeers").
Str("machine", m.Name).
Str("peer", p.Name).
Str("address", p.Addresses[0].String()).
Msgf("Notified peer %s (%s)", p.Name, p.Addresses[0])
}
}
func (h *Headscale) getOrOpenUpdateChannel(m *Machine) <-chan struct{} {
var updateChan chan struct{}
if storedChan, ok := h.clientsUpdateChannels.Load(m.ID); ok {
if unwrapped, ok := storedChan.(chan struct{}); ok {
updateChan = unwrapped
} else {
log.Error().
Str("handler", "openUpdateChannel").
Str("machine", m.Name).
Msg("Failed to convert update channel to struct{}")
}
} else {
log.Debug().
Str("handler", "openUpdateChannel").
Str("machine", m.Name).
Msg("Update channel not found, creating")
updateChan = make(chan struct{})
h.clientsUpdateChannels.Store(m.ID, updateChan)
}
return updateChan
}
func (h *Headscale) closeUpdateChannel(m *Machine) {
h.clientsUpdateChannelMutex.Lock()
defer h.clientsUpdateChannelMutex.Unlock()
if storedChan, ok := h.clientsUpdateChannels.Load(m.ID); ok {
if unwrapped, ok := storedChan.(chan struct{}); ok {
close(unwrapped)
}
}
h.clientsUpdateChannels.Delete(m.ID)
}
func (h *Headscale) sendRequestOnUpdateChannel(m *tailcfg.Node) error {
h.clientsUpdateChannelMutex.Lock()
defer h.clientsUpdateChannelMutex.Unlock()
pUp, ok := h.clientsUpdateChannels.Load(uint64(m.ID))
if ok {
log.Info().
Str("func", "requestUpdate").
Str("machine", m.Name).
Msgf("Notifying peer %s", m.Name)
if update, ok := pUp.(chan struct{}); ok {
log.Trace().
Str("func", "requestUpdate").
Str("machine", m.Name).
Msgf("Update channel is %#v", update)
update <- struct{}{}
log.Trace().
Str("func", "requestUpdate").
Str("machine", m.Name).
Msgf("Notified machine %s", m.Name)
}
} else {
log.Info().
Str("func", "requestUpdate").
Str("machine", m.Name).
Msgf("Machine %s does not appear to be polling", m.Name)
return errors.New("machine does not seem to be polling")
}
return nil
}
func (h *Headscale) isOutdated(m *Machine) bool {
err := h.UpdateMachine(m)
if err != nil {
return true
}
lastChange := h.getLastStateChange(m.Namespace.Name)
log.Trace().
Str("func", "keepAlive").
Str("machine", m.Name).
Time("last_successful_update", *m.LastSuccessfulUpdate).
Time("last_state_change", lastChange).
Msgf("Checking if %s is missing updates", m.Name)
return m.LastSuccessfulUpdate.Before(lastChange)
}

View File

@@ -91,12 +91,34 @@ func (h *Headscale) ListMachinesInNamespace(name string) (*[]Machine, error) {
}
machines := []Machine{}
if err := h.db.Preload("AuthKey").Where(&Machine{NamespaceID: n.ID}).Find(&machines).Error; err != nil {
if err := h.db.Preload("AuthKey").Preload("Namespace").Where(&Machine{NamespaceID: n.ID}).Find(&machines).Error; err != nil {
return nil, err
}
return &machines, nil
}
// ListSharedMachinesInNamespace returns all the machines that are shared to the specified namespace
func (h *Headscale) ListSharedMachinesInNamespace(name string) (*[]Machine, error) {
namespace, err := h.GetNamespace(name)
if err != nil {
return nil, err
}
sharedMachines := []SharedMachine{}
if err := h.db.Preload("Namespace").Where(&SharedMachine{NamespaceID: namespace.ID}).Find(&sharedMachines).Error; err != nil {
return nil, err
}
machines := []Machine{}
for _, sharedMachine := range sharedMachines {
machine, err := h.GetMachineByID(sharedMachine.MachineID) // otherwise not everything comes filled
if err != nil {
return nil, err
}
machines = append(machines, *machine)
}
return &machines, nil
}
// SetMachineNamespace assigns a Machine to a namespace
func (h *Headscale) SetMachineNamespace(m *Machine, namespaceName string) error {
n, err := h.GetNamespace(namespaceName)

404
poll.go
View File

@@ -1,38 +1,225 @@
package headscale
import (
"encoding/json"
"errors"
"io"
"net/http"
"time"
"github.com/gin-gonic/gin"
"github.com/rs/zerolog/log"
"gorm.io/datatypes"
"gorm.io/gorm"
"tailscale.com/tailcfg"
"tailscale.com/types/wgkey"
)
// PollNetMapHandler takes care of /machine/:id/map
//
// This is the busiest endpoint, as it keeps the HTTP long poll that updates
// the clients when something in the network changes.
//
// The clients POST stuff like HostInfo and their Endpoints here, but
// only after their first request (marked with the ReadOnly field).
//
// At this moment the updates are sent in a quite horrendous way, but they kinda work.
func (h *Headscale) PollNetMapHandler(c *gin.Context) {
log.Trace().
Str("handler", "PollNetMap").
Str("id", c.Param("id")).
Msg("PollNetMapHandler called")
body, _ := io.ReadAll(c.Request.Body)
mKeyStr := c.Param("id")
mKey, err := wgkey.ParseHex(mKeyStr)
if err != nil {
log.Error().
Str("handler", "PollNetMap").
Err(err).
Msg("Cannot parse client key")
c.String(http.StatusBadRequest, "")
return
}
req := tailcfg.MapRequest{}
err = decode(body, &req, &mKey, h.privateKey)
if err != nil {
log.Error().
Str("handler", "PollNetMap").
Err(err).
Msg("Cannot decode message")
c.String(http.StatusBadRequest, "")
return
}
var m Machine
if result := h.db.Preload("Namespace").First(&m, "machine_key = ?", mKey.HexString()); errors.Is(result.Error, gorm.ErrRecordNotFound) {
log.Warn().
Str("handler", "PollNetMap").
Msgf("Ignoring request, cannot find machine with key %s", mKey.HexString())
c.String(http.StatusUnauthorized, "")
return
}
log.Trace().
Str("handler", "PollNetMap").
Str("id", c.Param("id")).
Str("machine", m.Name).
Msg("Found machine in database")
hostinfo, _ := json.Marshal(req.Hostinfo)
m.Name = req.Hostinfo.Hostname
m.HostInfo = datatypes.JSON(hostinfo)
m.DiscoKey = wgkey.Key(req.DiscoKey).HexString()
now := time.Now().UTC()
// From Tailscale client:
//
// ReadOnly is whether the client just wants to fetch the MapResponse,
// without updating their Endpoints. The Endpoints field will be ignored and
// LastSeen will not be updated and peers will not be notified of changes.
//
// The intended use is for clients to discover the DERP map at start-up
// before their first real endpoint update.
if !req.ReadOnly {
endpoints, _ := json.Marshal(req.Endpoints)
m.Endpoints = datatypes.JSON(endpoints)
m.LastSeen = &now
}
h.db.Save(&m)
data, err := h.getMapResponse(mKey, req, m)
if err != nil {
log.Error().
Str("handler", "PollNetMap").
Str("id", c.Param("id")).
Str("machine", m.Name).
Err(err).
Msg("Failed to get Map response")
c.String(http.StatusInternalServerError, ":(")
return
}
// We update our peers if the client is not sending ReadOnly in the MapRequest
// so we don't distribute its initial request (it comes with
// empty endpoints to peers)
// Details on the protocol can be found in https://github.com/tailscale/tailscale/blob/main/tailcfg/tailcfg.go#L696
log.Debug().
Str("handler", "PollNetMap").
Str("id", c.Param("id")).
Str("machine", m.Name).
Bool("readOnly", req.ReadOnly).
Bool("omitPeers", req.OmitPeers).
Bool("stream", req.Stream).
Msg("Client map request processed")
if req.ReadOnly {
log.Info().
Str("handler", "PollNetMap").
Str("machine", m.Name).
Msg("Client is starting up. Probably interested in a DERP map")
c.Data(200, "application/json; charset=utf-8", *data)
return
}
// There has been an update to _any_ of the nodes that the other nodes would
// need to know about
h.setLastStateChangeToNow(m.Namespace.Name)
// The request is not ReadOnly, so we need to set up channels for updating
// peers via longpoll
// Only create update channel if it has not been created
log.Trace().
Str("handler", "PollNetMap").
Str("id", c.Param("id")).
Str("machine", m.Name).
Msg("Loading or creating update channel")
updateChan := h.getOrOpenUpdateChannel(&m)
pollDataChan := make(chan []byte)
// defer close(pollData)
keepAliveChan := make(chan []byte)
cancelKeepAlive := make(chan struct{})
defer close(cancelKeepAlive)
if req.OmitPeers && !req.Stream {
log.Info().
Str("handler", "PollNetMap").
Str("machine", m.Name).
Msg("Client sent endpoint update and is ok with a response without peer list")
c.Data(200, "application/json; charset=utf-8", *data)
// It sounds like we should update the nodes when we have received a endpoint update
// even tho the comments in the tailscale code dont explicitly say so.
go h.notifyChangesToPeers(&m)
return
} else if req.OmitPeers && req.Stream {
log.Warn().
Str("handler", "PollNetMap").
Str("machine", m.Name).
Msg("Ignoring request, don't know how to handle it")
c.String(http.StatusBadRequest, "")
return
}
log.Info().
Str("handler", "PollNetMap").
Str("machine", m.Name).
Msg("Client is ready to access the tailnet")
log.Info().
Str("handler", "PollNetMap").
Str("machine", m.Name).
Msg("Sending initial map")
go func() { pollDataChan <- *data }()
log.Info().
Str("handler", "PollNetMap").
Str("machine", m.Name).
Msg("Notifying peers")
go h.notifyChangesToPeers(&m)
h.PollNetMapStream(c, m, req, mKey, pollDataChan, keepAliveChan, updateChan, cancelKeepAlive)
log.Trace().
Str("handler", "PollNetMap").
Str("id", c.Param("id")).
Str("machine", m.Name).
Msg("Finished stream, closing PollNetMap session")
}
// PollNetMapStream takes care of /machine/:id/map
// stream logic, ensuring we communicate updates and data
// to the connected clients.
func (h *Headscale) PollNetMapStream(
c *gin.Context,
m Machine,
req tailcfg.MapRequest,
mKey wgkey.Key,
pollData chan []byte,
update chan []byte,
cancelKeepAlive chan []byte,
pollDataChan chan []byte,
keepAliveChan chan []byte,
updateChan <-chan struct{},
cancelKeepAlive chan struct{},
) {
go h.keepAlive(cancelKeepAlive, pollData, mKey, req, m)
go h.scheduledPollWorker(cancelKeepAlive, keepAliveChan, mKey, req, m)
c.Stream(func(w io.Writer) bool {
log.Trace().
Str("handler", "PollNetMapStream").
Str("machine", m.Name).
Msg("Waiting for data to stream...")
select {
case data := <-pollData:
log.Trace().
Str("handler", "PollNetMapStream").
Str("machine", m.Name).
Msgf("pollData is %#v, keepAliveChan is %#v, updateChan is %#v", pollDataChan, keepAliveChan, updateChan)
select {
case data := <-pollDataChan:
log.Trace().
Str("handler", "PollNetMapStream").
Str("machine", m.Name).
Str("channel", "pollData").
Int("bytes", len(data)).
Msg("Sending data received via pollData channel")
_, err := w.Write(data)
@@ -40,44 +227,148 @@ func (h *Headscale) PollNetMapStream(
log.Error().
Str("handler", "PollNetMapStream").
Str("machine", m.Name).
Str("channel", "pollData").
Err(err).
Msg("Cannot write data")
}
log.Trace().
Str("handler", "PollNetMapStream").
Str("machine", m.Name).
Str("channel", "pollData").
Int("bytes", len(data)).
Msg("Data from pollData channel written successfully")
// TODO: Abstract away all the database calls, this can cause race conditions
// when an outdated machine object is kept alive, e.g. db is update from
// command line, but then overwritten.
err = h.UpdateMachine(&m)
if err != nil {
log.Error().
Str("handler", "PollNetMapStream").
Str("machine", m.Name).
Str("channel", "pollData").
Err(err).
Msg("Cannot update machine from database")
}
now := time.Now().UTC()
m.LastSeen = &now
m.LastSuccessfulUpdate = &now
h.db.Save(&m)
log.Trace().
Str("handler", "PollNetMapStream").
Str("machine", m.Name).
Str("channel", "pollData").
Int("bytes", len(data)).
Msg("Machine updated successfully after sending pollData")
return true
case data := <-keepAliveChan:
log.Trace().
Str("handler", "PollNetMapStream").
Str("machine", m.Name).
Str("channel", "keepAlive").
Int("bytes", len(data)).
Msg("Sending keep alive message")
_, err := w.Write(data)
if err != nil {
log.Error().
Str("handler", "PollNetMapStream").
Str("machine", m.Name).
Str("channel", "keepAlive").
Err(err).
Msg("Cannot write keep alive message")
}
log.Trace().
Str("handler", "PollNetMapStream").
Str("machine", m.Name).
Str("channel", "keepAlive").
Int("bytes", len(data)).
Msg("Keep alive sent successfully")
// TODO: Abstract away all the database calls, this can cause race conditions
// when an outdated machine object is kept alive, e.g. db is update from
// command line, but then overwritten.
err = h.UpdateMachine(&m)
if err != nil {
log.Error().
Str("handler", "PollNetMapStream").
Str("machine", m.Name).
Str("channel", "keepAlive").
Err(err).
Msg("Cannot update machine from database")
}
now := time.Now().UTC()
m.LastSeen = &now
h.db.Save(&m)
log.Trace().
Str("handler", "PollNetMapStream").
Str("machine", m.Name).
Str("channel", "keepAlive").
Int("bytes", len(data)).
Msg("Machine updated successfully after sending pollData")
Msg("Machine updated successfully after sending keep alive")
return true
case <-update:
log.Debug().
case <-updateChan:
log.Trace().
Str("handler", "PollNetMapStream").
Str("machine", m.Name).
Str("channel", "update").
Msg("Received a request for update")
data, err := h.getMapResponse(mKey, req, m)
if err != nil {
log.Error().
if h.isOutdated(&m) {
log.Debug().
Str("handler", "PollNetMapStream").
Str("machine", m.Name).
Err(err).
Msg("Could not get the map update")
}
_, err = w.Write(*data)
if err != nil {
log.Error().
Time("last_successful_update", *m.LastSuccessfulUpdate).
Time("last_state_change", h.getLastStateChange(m.Namespace.Name)).
Msgf("There has been updates since the last successful update to %s", m.Name)
data, err := h.getMapResponse(mKey, req, m)
if err != nil {
log.Error().
Str("handler", "PollNetMapStream").
Str("machine", m.Name).
Str("channel", "update").
Err(err).
Msg("Could not get the map update")
}
_, err = w.Write(*data)
if err != nil {
log.Error().
Str("handler", "PollNetMapStream").
Str("machine", m.Name).
Str("channel", "update").
Err(err).
Msg("Could not write the map response")
}
log.Trace().
Str("handler", "PollNetMapStream").
Str("machine", m.Name).
Err(err).
Msg("Could not write the map response")
Str("channel", "update").
Msg("Updated Map has been sent")
// Keep track of the last successful update,
// we sometimes end in a state were the update
// is not picked up by a client and we use this
// to determine if we should "force" an update.
// TODO: Abstract away all the database calls, this can cause race conditions
// when an outdated machine object is kept alive, e.g. db is update from
// command line, but then overwritten.
err = h.UpdateMachine(&m)
if err != nil {
log.Error().
Str("handler", "PollNetMapStream").
Str("machine", m.Name).
Str("channel", "update").
Err(err).
Msg("Cannot update machine from database")
}
now := time.Now().UTC()
m.LastSuccessfulUpdate = &now
h.db.Save(&m)
} else {
log.Trace().
Str("handler", "PollNetMapStream").
Str("machine", m.Name).
Time("last_successful_update", *m.LastSuccessfulUpdate).
Time("last_state_change", h.getLastStateChange(m.Namespace.Name)).
Msgf("%s is up to date", m.Name)
}
return true
@@ -86,13 +377,78 @@ func (h *Headscale) PollNetMapStream(
Str("handler", "PollNetMapStream").
Str("machine", m.Name).
Msg("The client has closed the connection")
// TODO: Abstract away all the database calls, this can cause race conditions
// when an outdated machine object is kept alive, e.g. db is update from
// command line, but then overwritten.
err := h.UpdateMachine(&m)
if err != nil {
log.Error().
Str("handler", "PollNetMapStream").
Str("machine", m.Name).
Str("channel", "Done").
Err(err).
Msg("Cannot update machine from database")
}
now := time.Now().UTC()
m.LastSeen = &now
h.db.Save(&m)
cancelKeepAlive <- []byte{}
h.clientsPolling.Delete(m.ID)
close(update)
cancelKeepAlive <- struct{}{}
h.closeUpdateChannel(&m)
close(pollDataChan)
close(keepAliveChan)
return false
}
})
}
func (h *Headscale) scheduledPollWorker(
cancelChan <-chan struct{},
keepAliveChan chan<- []byte,
mKey wgkey.Key,
req tailcfg.MapRequest,
m Machine,
) {
keepAliveTicker := time.NewTicker(60 * time.Second)
updateCheckerTicker := time.NewTicker(30 * time.Second)
for {
select {
case <-cancelChan:
return
case <-keepAliveTicker.C:
data, err := h.getMapKeepAliveResponse(mKey, req, m)
if err != nil {
log.Error().
Str("func", "keepAlive").
Err(err).
Msg("Error generating the keep alive msg")
return
}
log.Debug().
Str("func", "keepAlive").
Str("machine", m.Name).
Msg("Sending keepalive")
keepAliveChan <- *data
case <-updateCheckerTicker.C:
// Send an update request regardless of outdated or not, if data is sent
// to the node is determined in the updateChan consumer block
n, _ := m.toNode(true)
err := h.sendRequestOnUpdateChannel(n)
if err != nil {
log.Error().
Str("func", "keepAlive").
Str("machine", m.Name).
Err(err).
Msgf("Failed to send update request to %s", m.Name)
}
}
}
}

131
routes.go
View File

@@ -2,55 +2,142 @@ package headscale
import (
"encoding/json"
"errors"
"fmt"
"strconv"
"github.com/pterm/pterm"
"gorm.io/datatypes"
"inet.af/netaddr"
)
// GetNodeRoutes returns the subnet routes advertised by a node (identified by
// GetAdvertisedNodeRoutes returns the subnet routes advertised by a node (identified by
// namespace and node name)
func (h *Headscale) GetNodeRoutes(namespace string, nodeName string) (*[]netaddr.IPPrefix, error) {
func (h *Headscale) GetAdvertisedNodeRoutes(namespace string, nodeName string) (*[]netaddr.IPPrefix, error) {
m, err := h.GetMachine(namespace, nodeName)
if err != nil {
return nil, err
}
hi, err := m.GetHostInfo()
hostInfo, err := m.GetHostInfo()
if err != nil {
return nil, err
}
return &hi.RoutableIPs, nil
return &hostInfo.RoutableIPs, nil
}
// GetEnabledNodeRoutes returns the subnet routes enabled by a node (identified by
// namespace and node name)
func (h *Headscale) GetEnabledNodeRoutes(namespace string, nodeName string) ([]netaddr.IPPrefix, error) {
m, err := h.GetMachine(namespace, nodeName)
if err != nil {
return nil, err
}
data, err := m.EnabledRoutes.MarshalJSON()
if err != nil {
return nil, err
}
routesStr := []string{}
err = json.Unmarshal(data, &routesStr)
if err != nil {
return nil, err
}
routes := make([]netaddr.IPPrefix, len(routesStr))
for index, routeStr := range routesStr {
route, err := netaddr.ParseIPPrefix(routeStr)
if err != nil {
return nil, err
}
routes[index] = route
}
return routes, nil
}
// IsNodeRouteEnabled checks if a certain route has been enabled
func (h *Headscale) IsNodeRouteEnabled(namespace string, nodeName string, routeStr string) bool {
route, err := netaddr.ParseIPPrefix(routeStr)
if err != nil {
return false
}
enabledRoutes, err := h.GetEnabledNodeRoutes(namespace, nodeName)
if err != nil {
return false
}
for _, enabledRoute := range enabledRoutes {
if route == enabledRoute {
return true
}
}
return false
}
// EnableNodeRoute enables a subnet route advertised by a node (identified by
// namespace and node name)
func (h *Headscale) EnableNodeRoute(namespace string, nodeName string, routeStr string) (*netaddr.IPPrefix, error) {
func (h *Headscale) EnableNodeRoute(namespace string, nodeName string, routeStr string) error {
m, err := h.GetMachine(namespace, nodeName)
if err != nil {
return nil, err
}
hi, err := m.GetHostInfo()
if err != nil {
return nil, err
return err
}
route, err := netaddr.ParseIPPrefix(routeStr)
if err != nil {
return nil, err
return err
}
for _, rIP := range hi.RoutableIPs {
if rIP == route {
routes, _ := json.Marshal([]string{routeStr}) // TODO: only one for the time being, so overwriting the rest
m.EnabledRoutes = datatypes.JSON(routes)
h.db.Save(&m)
availableRoutes, err := h.GetAdvertisedNodeRoutes(namespace, nodeName)
if err != nil {
return err
}
err = h.RequestMapUpdates(m.NamespaceID)
if err != nil {
return nil, err
enabledRoutes, err := h.GetEnabledNodeRoutes(namespace, nodeName)
if err != nil {
return err
}
available := false
for _, availableRoute := range *availableRoutes {
// If the route is available, and not yet enabled, add it to the new routing table
if route == availableRoute {
available = true
if !h.IsNodeRouteEnabled(namespace, nodeName, routeStr) {
enabledRoutes = append(enabledRoutes, route)
}
return &rIP, nil
}
}
return nil, errors.New("could not find routable range")
if !available {
return fmt.Errorf("route (%s) is not available on node %s", nodeName, routeStr)
}
routes, err := json.Marshal(enabledRoutes)
if err != nil {
return err
}
m.EnabledRoutes = datatypes.JSON(routes)
h.db.Save(&m)
err = h.RequestMapUpdates(m.NamespaceID)
if err != nil {
return err
}
return nil
}
// RoutesToPtables converts the list of routes to a nice table
func (h *Headscale) RoutesToPtables(namespace string, nodeName string, availableRoutes []netaddr.IPPrefix) pterm.TableData {
d := pterm.TableData{{"Route", "Enabled"}}
for _, route := range availableRoutes {
enabled := h.IsNodeRouteEnabled(namespace, nodeName, route.String())
d = append(d, []string{route.String(), strconv.FormatBool(enabled)})
}
return d
}

View File

@@ -16,7 +16,7 @@ func (s *Suite) TestGetRoutes(c *check.C) {
pak, err := h.CreatePreAuthKey(n.Name, false, false, nil)
c.Assert(err, check.IsNil)
_, err = h.GetMachine("test", "testmachine")
_, err = h.GetMachine("test", "test_get_route_machine")
c.Assert(err, check.NotNil)
route, err := netaddr.ParseIPPrefix("10.0.0.0/24")
@@ -33,7 +33,7 @@ func (s *Suite) TestGetRoutes(c *check.C) {
MachineKey: "foo",
NodeKey: "bar",
DiscoKey: "faa",
Name: "testmachine",
Name: "test_get_route_machine",
NamespaceID: n.ID,
Registered: true,
RegisterMethod: "authKey",
@@ -42,14 +42,87 @@ func (s *Suite) TestGetRoutes(c *check.C) {
}
h.db.Save(&m)
r, err := h.GetNodeRoutes("test", "testmachine")
r, err := h.GetAdvertisedNodeRoutes("test", "test_get_route_machine")
c.Assert(err, check.IsNil)
c.Assert(len(*r), check.Equals, 1)
_, err = h.EnableNodeRoute("test", "testmachine", "192.168.0.0/24")
err = h.EnableNodeRoute("test", "test_get_route_machine", "192.168.0.0/24")
c.Assert(err, check.NotNil)
_, err = h.EnableNodeRoute("test", "testmachine", "10.0.0.0/24")
err = h.EnableNodeRoute("test", "test_get_route_machine", "10.0.0.0/24")
c.Assert(err, check.IsNil)
}
func (s *Suite) TestGetEnableRoutes(c *check.C) {
n, err := h.CreateNamespace("test")
c.Assert(err, check.IsNil)
pak, err := h.CreatePreAuthKey(n.Name, false, false, nil)
c.Assert(err, check.IsNil)
_, err = h.GetMachine("test", "test_enable_route_machine")
c.Assert(err, check.NotNil)
route, err := netaddr.ParseIPPrefix(
"10.0.0.0/24",
)
c.Assert(err, check.IsNil)
route2, err := netaddr.ParseIPPrefix(
"150.0.10.0/25",
)
c.Assert(err, check.IsNil)
hi := tailcfg.Hostinfo{
RoutableIPs: []netaddr.IPPrefix{route, route2},
}
hostinfo, err := json.Marshal(hi)
c.Assert(err, check.IsNil)
m := Machine{
ID: 0,
MachineKey: "foo",
NodeKey: "bar",
DiscoKey: "faa",
Name: "test_enable_route_machine",
NamespaceID: n.ID,
Registered: true,
RegisterMethod: "authKey",
AuthKeyID: uint(pak.ID),
HostInfo: datatypes.JSON(hostinfo),
}
h.db.Save(&m)
availableRoutes, err := h.GetAdvertisedNodeRoutes("test", "test_enable_route_machine")
c.Assert(err, check.IsNil)
c.Assert(len(*availableRoutes), check.Equals, 2)
enabledRoutes, err := h.GetEnabledNodeRoutes("test", "test_enable_route_machine")
c.Assert(err, check.IsNil)
c.Assert(len(enabledRoutes), check.Equals, 0)
err = h.EnableNodeRoute("test", "test_enable_route_machine", "192.168.0.0/24")
c.Assert(err, check.NotNil)
err = h.EnableNodeRoute("test", "test_enable_route_machine", "10.0.0.0/24")
c.Assert(err, check.IsNil)
enabledRoutes1, err := h.GetEnabledNodeRoutes("test", "test_enable_route_machine")
c.Assert(err, check.IsNil)
c.Assert(len(enabledRoutes1), check.Equals, 1)
// Adding it twice will just let it pass through
err = h.EnableNodeRoute("test", "test_enable_route_machine", "10.0.0.0/24")
c.Assert(err, check.IsNil)
enabledRoutes2, err := h.GetEnabledNodeRoutes("test", "test_enable_route_machine")
c.Assert(err, check.IsNil)
c.Assert(len(enabledRoutes2), check.Equals, 1)
err = h.EnableNodeRoute("test", "test_enable_route_machine", "150.0.10.0/25")
c.Assert(err, check.IsNil)
enabledRoutes3, err := h.GetEnabledNodeRoutes("test", "test_enable_route_machine")
c.Assert(err, check.IsNil)
c.Assert(len(enabledRoutes3), check.Equals, 2)
}

View File

@@ -1,4 +1,4 @@
#!/bin/bash
#!/usr/bin/env bash
set -e -o pipefail
commit="$1"

37
sharing.go Normal file
View File

@@ -0,0 +1,37 @@
package headscale
import "gorm.io/gorm"
const errorSameNamespace = Error("Destination namespace same as origin")
const errorMachineAlreadyShared = Error("Node already shared to this namespace")
// SharedMachine is a join table to support sharing nodes between namespaces
type SharedMachine struct {
gorm.Model
MachineID uint64
Machine Machine
NamespaceID uint
Namespace Namespace
}
// AddSharedMachineToNamespace adds a machine as a shared node to a namespace
func (h *Headscale) AddSharedMachineToNamespace(m *Machine, ns *Namespace) error {
if m.NamespaceID == ns.ID {
return errorSameNamespace
}
sharedMachine := SharedMachine{}
if err := h.db.Where("machine_id = ? AND namespace_id", m.ID, ns.ID).First(&sharedMachine).Error; err == nil {
return errorMachineAlreadyShared
}
sharedMachine = SharedMachine{
MachineID: m.ID,
Machine: *m,
NamespaceID: ns.ID,
Namespace: *ns,
}
h.db.Save(&sharedMachine)
return nil
}

359
sharing_test.go Normal file
View File

@@ -0,0 +1,359 @@
package headscale
import (
"gopkg.in/check.v1"
"tailscale.com/tailcfg"
)
func (s *Suite) TestBasicSharedNodesInNamespace(c *check.C) {
n1, err := h.CreateNamespace("shared1")
c.Assert(err, check.IsNil)
n2, err := h.CreateNamespace("shared2")
c.Assert(err, check.IsNil)
pak1, err := h.CreatePreAuthKey(n1.Name, false, false, nil)
c.Assert(err, check.IsNil)
pak2, err := h.CreatePreAuthKey(n2.Name, false, false, nil)
c.Assert(err, check.IsNil)
_, err = h.GetMachine(n1.Name, "test_get_shared_nodes_1")
c.Assert(err, check.NotNil)
m1 := Machine{
ID: 0,
MachineKey: "686824e749f3b7f2a5927ee6c1e422aee5292592d9179a271ed7b3e659b44a66",
NodeKey: "686824e749f3b7f2a5927ee6c1e422aee5292592d9179a271ed7b3e659b44a66",
DiscoKey: "686824e749f3b7f2a5927ee6c1e422aee5292592d9179a271ed7b3e659b44a66",
Name: "test_get_shared_nodes_1",
NamespaceID: n1.ID,
Registered: true,
RegisterMethod: "authKey",
IPAddress: "100.64.0.1",
AuthKeyID: uint(pak1.ID),
}
h.db.Save(&m1)
_, err = h.GetMachine(n1.Name, m1.Name)
c.Assert(err, check.IsNil)
m2 := Machine{
ID: 1,
MachineKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
NodeKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
DiscoKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
Name: "test_get_shared_nodes_2",
NamespaceID: n2.ID,
Registered: true,
RegisterMethod: "authKey",
IPAddress: "100.64.0.2",
AuthKeyID: uint(pak2.ID),
}
h.db.Save(&m2)
_, err = h.GetMachine(n2.Name, m2.Name)
c.Assert(err, check.IsNil)
p1s, err := h.getPeers(m1)
c.Assert(err, check.IsNil)
c.Assert(len(*p1s), check.Equals, 0)
err = h.AddSharedMachineToNamespace(&m2, n1)
c.Assert(err, check.IsNil)
p1sAfter, err := h.getPeers(m1)
c.Assert(err, check.IsNil)
c.Assert(len(*p1sAfter), check.Equals, 1)
c.Assert((*p1sAfter)[0].ID, check.Equals, tailcfg.NodeID(m2.ID))
}
func (s *Suite) TestSameNamespace(c *check.C) {
n1, err := h.CreateNamespace("shared1")
c.Assert(err, check.IsNil)
n2, err := h.CreateNamespace("shared2")
c.Assert(err, check.IsNil)
pak1, err := h.CreatePreAuthKey(n1.Name, false, false, nil)
c.Assert(err, check.IsNil)
pak2, err := h.CreatePreAuthKey(n2.Name, false, false, nil)
c.Assert(err, check.IsNil)
_, err = h.GetMachine(n1.Name, "test_get_shared_nodes_1")
c.Assert(err, check.NotNil)
m1 := Machine{
ID: 0,
MachineKey: "686824e749f3b7f2a5927ee6c1e422aee5292592d9179a271ed7b3e659b44a66",
NodeKey: "686824e749f3b7f2a5927ee6c1e422aee5292592d9179a271ed7b3e659b44a66",
DiscoKey: "686824e749f3b7f2a5927ee6c1e422aee5292592d9179a271ed7b3e659b44a66",
Name: "test_get_shared_nodes_1",
NamespaceID: n1.ID,
Registered: true,
RegisterMethod: "authKey",
IPAddress: "100.64.0.1",
AuthKeyID: uint(pak1.ID),
}
h.db.Save(&m1)
_, err = h.GetMachine(n1.Name, m1.Name)
c.Assert(err, check.IsNil)
m2 := Machine{
ID: 1,
MachineKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
NodeKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
DiscoKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
Name: "test_get_shared_nodes_2",
NamespaceID: n2.ID,
Registered: true,
RegisterMethod: "authKey",
IPAddress: "100.64.0.2",
AuthKeyID: uint(pak2.ID),
}
h.db.Save(&m2)
_, err = h.GetMachine(n2.Name, m2.Name)
c.Assert(err, check.IsNil)
p1s, err := h.getPeers(m1)
c.Assert(err, check.IsNil)
c.Assert(len(*p1s), check.Equals, 0)
err = h.AddSharedMachineToNamespace(&m1, n1)
c.Assert(err, check.Equals, errorSameNamespace)
}
func (s *Suite) TestAlreadyShared(c *check.C) {
n1, err := h.CreateNamespace("shared1")
c.Assert(err, check.IsNil)
n2, err := h.CreateNamespace("shared2")
c.Assert(err, check.IsNil)
pak1, err := h.CreatePreAuthKey(n1.Name, false, false, nil)
c.Assert(err, check.IsNil)
pak2, err := h.CreatePreAuthKey(n2.Name, false, false, nil)
c.Assert(err, check.IsNil)
_, err = h.GetMachine(n1.Name, "test_get_shared_nodes_1")
c.Assert(err, check.NotNil)
m1 := Machine{
ID: 0,
MachineKey: "686824e749f3b7f2a5927ee6c1e422aee5292592d9179a271ed7b3e659b44a66",
NodeKey: "686824e749f3b7f2a5927ee6c1e422aee5292592d9179a271ed7b3e659b44a66",
DiscoKey: "686824e749f3b7f2a5927ee6c1e422aee5292592d9179a271ed7b3e659b44a66",
Name: "test_get_shared_nodes_1",
NamespaceID: n1.ID,
Registered: true,
RegisterMethod: "authKey",
IPAddress: "100.64.0.1",
AuthKeyID: uint(pak1.ID),
}
h.db.Save(&m1)
_, err = h.GetMachine(n1.Name, m1.Name)
c.Assert(err, check.IsNil)
m2 := Machine{
ID: 1,
MachineKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
NodeKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
DiscoKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
Name: "test_get_shared_nodes_2",
NamespaceID: n2.ID,
Registered: true,
RegisterMethod: "authKey",
IPAddress: "100.64.0.2",
AuthKeyID: uint(pak2.ID),
}
h.db.Save(&m2)
_, err = h.GetMachine(n2.Name, m2.Name)
c.Assert(err, check.IsNil)
p1s, err := h.getPeers(m1)
c.Assert(err, check.IsNil)
c.Assert(len(*p1s), check.Equals, 0)
err = h.AddSharedMachineToNamespace(&m2, n1)
c.Assert(err, check.IsNil)
err = h.AddSharedMachineToNamespace(&m2, n1)
c.Assert(err, check.Equals, errorMachineAlreadyShared)
}
func (s *Suite) TestDoNotIncludeRoutesOnShared(c *check.C) {
n1, err := h.CreateNamespace("shared1")
c.Assert(err, check.IsNil)
n2, err := h.CreateNamespace("shared2")
c.Assert(err, check.IsNil)
pak1, err := h.CreatePreAuthKey(n1.Name, false, false, nil)
c.Assert(err, check.IsNil)
pak2, err := h.CreatePreAuthKey(n2.Name, false, false, nil)
c.Assert(err, check.IsNil)
_, err = h.GetMachine(n1.Name, "test_get_shared_nodes_1")
c.Assert(err, check.NotNil)
m1 := Machine{
ID: 0,
MachineKey: "686824e749f3b7f2a5927ee6c1e422aee5292592d9179a271ed7b3e659b44a66",
NodeKey: "686824e749f3b7f2a5927ee6c1e422aee5292592d9179a271ed7b3e659b44a66",
DiscoKey: "686824e749f3b7f2a5927ee6c1e422aee5292592d9179a271ed7b3e659b44a66",
Name: "test_get_shared_nodes_1",
NamespaceID: n1.ID,
Registered: true,
RegisterMethod: "authKey",
IPAddress: "100.64.0.1",
AuthKeyID: uint(pak1.ID),
}
h.db.Save(&m1)
_, err = h.GetMachine(n1.Name, m1.Name)
c.Assert(err, check.IsNil)
m2 := Machine{
ID: 1,
MachineKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
NodeKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
DiscoKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
Name: "test_get_shared_nodes_2",
NamespaceID: n2.ID,
Registered: true,
RegisterMethod: "authKey",
IPAddress: "100.64.0.2",
AuthKeyID: uint(pak2.ID),
}
h.db.Save(&m2)
_, err = h.GetMachine(n2.Name, m2.Name)
c.Assert(err, check.IsNil)
p1s, err := h.getPeers(m1)
c.Assert(err, check.IsNil)
c.Assert(len(*p1s), check.Equals, 0)
err = h.AddSharedMachineToNamespace(&m2, n1)
c.Assert(err, check.IsNil)
p1sAfter, err := h.getPeers(m1)
c.Assert(err, check.IsNil)
c.Assert(len(*p1sAfter), check.Equals, 1)
c.Assert(len((*p1sAfter)[0].AllowedIPs), check.Equals, 1)
}
func (s *Suite) TestComplexSharingAcrossNamespaces(c *check.C) {
n1, err := h.CreateNamespace("shared1")
c.Assert(err, check.IsNil)
n2, err := h.CreateNamespace("shared2")
c.Assert(err, check.IsNil)
n3, err := h.CreateNamespace("shared3")
c.Assert(err, check.IsNil)
pak1, err := h.CreatePreAuthKey(n1.Name, false, false, nil)
c.Assert(err, check.IsNil)
pak2, err := h.CreatePreAuthKey(n2.Name, false, false, nil)
c.Assert(err, check.IsNil)
pak3, err := h.CreatePreAuthKey(n3.Name, false, false, nil)
c.Assert(err, check.IsNil)
pak4, err := h.CreatePreAuthKey(n1.Name, false, false, nil)
c.Assert(err, check.IsNil)
_, err = h.GetMachine(n1.Name, "test_get_shared_nodes_1")
c.Assert(err, check.NotNil)
m1 := Machine{
ID: 0,
MachineKey: "686824e749f3b7f2a5927ee6c1e422aee5292592d9179a271ed7b3e659b44a66",
NodeKey: "686824e749f3b7f2a5927ee6c1e422aee5292592d9179a271ed7b3e659b44a66",
DiscoKey: "686824e749f3b7f2a5927ee6c1e422aee5292592d9179a271ed7b3e659b44a66",
Name: "test_get_shared_nodes_1",
NamespaceID: n1.ID,
Registered: true,
RegisterMethod: "authKey",
IPAddress: "100.64.0.1",
AuthKeyID: uint(pak1.ID),
}
h.db.Save(&m1)
_, err = h.GetMachine(n1.Name, m1.Name)
c.Assert(err, check.IsNil)
m2 := Machine{
ID: 1,
MachineKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
NodeKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
DiscoKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
Name: "test_get_shared_nodes_2",
NamespaceID: n2.ID,
Registered: true,
RegisterMethod: "authKey",
IPAddress: "100.64.0.2",
AuthKeyID: uint(pak2.ID),
}
h.db.Save(&m2)
_, err = h.GetMachine(n2.Name, m2.Name)
c.Assert(err, check.IsNil)
m3 := Machine{
ID: 2,
MachineKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
NodeKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
DiscoKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
Name: "test_get_shared_nodes_3",
NamespaceID: n3.ID,
Registered: true,
RegisterMethod: "authKey",
IPAddress: "100.64.0.3",
AuthKeyID: uint(pak3.ID),
}
h.db.Save(&m3)
_, err = h.GetMachine(n3.Name, m3.Name)
c.Assert(err, check.IsNil)
m4 := Machine{
ID: 3,
MachineKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
NodeKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
DiscoKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
Name: "test_get_shared_nodes_4",
NamespaceID: n1.ID,
Registered: true,
RegisterMethod: "authKey",
IPAddress: "100.64.0.4",
AuthKeyID: uint(pak4.ID),
}
h.db.Save(&m4)
_, err = h.GetMachine(n1.Name, m4.Name)
c.Assert(err, check.IsNil)
p1s, err := h.getPeers(m1)
c.Assert(err, check.IsNil)
c.Assert(len(*p1s), check.Equals, 1) // nodes 1 and 4
err = h.AddSharedMachineToNamespace(&m2, n1)
c.Assert(err, check.IsNil)
p1sAfter, err := h.getPeers(m1)
c.Assert(err, check.IsNil)
c.Assert(len(*p1sAfter), check.Equals, 2) // nodes 1, 2, 4
pAlone, err := h.getPeers(m3)
c.Assert(err, check.IsNil)
c.Assert(len(*pAlone), check.Equals, 0) // node 3 is alone
}