Compare commits


36 Commits

Author SHA1 Message Date
Kristoffer Dalby
123f0fa185 Merge pull request #98 from kradalby/initial-dns-server-exit-node 2021-08-25 22:58:25 +01:00
Kristoffer Dalby
ba3dffecbf Update readme 2021-08-25 19:05:10 +01:00
Kristoffer Dalby
8735e5675c Add a test for the GetDNSConfig function 2021-08-25 19:03:04 +01:00
Kristoffer Dalby
3f5e06a0f8 Don't add the port number to the IP 2021-08-25 18:43:13 +01:00
Juan Font
ba40a40b73 Merge pull request #96 from qbit/version_fix
Fix setting of version
2021-08-25 12:34:34 +02:00
Kristoffer Dalby
b3732e7fb9 Add nameserver as resolver as well 2021-08-25 07:04:48 +01:00
Aaron Bieber
104776ee84 fix setting of version 2021-08-24 07:49:15 -06:00
Kristoffer Dalby
01e781e546 Pass DNSConfig to nodes in MapResponse 2021-08-24 07:11:45 +01:00
Kristoffer Dalby
e77c16b55a Add DNSConfig to example and setup test 2021-08-24 07:10:09 +01:00
Kristoffer Dalby
987bbee1db Add DNSConfig field to configuration 2021-08-24 07:09:47 +01:00
Juan Font
74d2fe1baa Merge pull request #84 from kradalby/integration-tests-ci
Improve logic to keep nodes up to date with the network state
2021-08-23 09:42:07 +02:00
Kristoffer Dalby
98e63d5561 Merge pull request #94 from kradalby/split-lint-test
Split lint and test CI files
2021-08-23 07:46:11 +01:00
Kristoffer Dalby
059f13fc9d Add missing comment for stream function 2021-08-23 07:38:14 +01:00
Kristoffer Dalby
ebd27b46af Add comment to UpdateMachine 2021-08-23 07:35:44 +01:00
Juan Font
ca8d814918 Merge pull request #92 from kradalby/enhance-route-command
Enhance route command with ptables and multiple routes
2021-08-22 12:48:30 +02:00
Kristoffer Dalby
0aeeaac361 Always load machine object from DB before save/modify
We currently hold Machine objects in memory for a long time while waiting
for stream/longpoll, which can leave us with stale objects that we then
call save on, potentially overwriting newer data in the database.

A typical scenario would be someone changing something from the CLI,
e.g. enabling routes, which is then overwritten again by the stale
object in the longpolling function.

The code has been left with TODOs, and a discussion is available in #93.
2021-08-21 16:52:19 +01:00
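
A minimal sketch of the reload-before-save pattern this commit moves towards (not the repository's exact code; `refreshAndSave` is a hypothetical helper, while `Machine` and `h.db` are the repo's gorm-backed types): re-fetch the row by primary key right before mutating it, so a long-lived in-memory copy cannot clobber newer database state.

```go
// Hypothetical helper sketching the reload-before-save pattern.
// Machine is assumed to be a gorm model with an ID primary key.
func (h *Headscale) refreshAndSave(m *Machine, mutate func(*Machine)) error {
	// Reload the latest row for this machine before touching it,
	// discarding whatever stale state the long poll has been holding.
	if err := h.db.First(m, m.ID).Error; err != nil {
		return err
	}
	mutate(m)
	return h.db.Save(m).Error
}
```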
Kristoffer Dalby
28ed8a5742 Actually rename lint 2021-08-21 15:42:23 +01:00
Kristoffer Dalby
f749be1490 Split lint and test CI files
This commit splits the lint and test steps into two different jobs in
GitHub Actions.

Consider this a suggestion; the idea is that when we look at PRs we can
see explicitly which of the two types of checks failed without having
to open GitHub Actions.
2021-08-21 15:40:27 +01:00
Kristoffer Dalby
693bce1b10 Update test machine name properly 2021-08-21 15:35:26 +01:00
Kristoffer Dalby
4f97e077db Add --all flag to routes enable command to enable all advertised routes 2021-08-21 15:04:30 +01:00
Kristoffer Dalby
c883e79884 Enhance route command with ptables and multiple routes
This commit rewrites the `routes list` command to use ptables (pterm
tables) to present a nicer list, including a new field showing whether
each route is enabled (which is quite useful).

In addition, it reworks the enable command to support enabling multiple
routes (not just a single route, as per the removed TODO). This allows users to
actually take advantage of exit nodes and subnet relays.
2021-08-21 14:49:46 +01:00
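
For reference, a minimal standalone sketch of the pterm table rendering that the new `routes list` output relies on (the route data here is made up):

```go
package main

import (
	"log"

	"github.com/pterm/pterm"
)

func main() {
	// Build a table with a header row, mirroring RoutesToPtables.
	d := pterm.TableData{{"Route", "Enabled"}}
	d = append(d, []string{"10.0.0.0/24", "true"})
	d = append(d, []string{"192.168.0.0/24", "false"})

	if err := pterm.DefaultTable.WithHasHeader().WithData(d).Render(); err != nil {
		log.Fatal(err)
	}
}
```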
Kristoffer Dalby
a054e2514a Keep tailscale count at 25 in integration tests 2021-08-21 09:26:18 +01:00
Kristoffer Dalby
c49fe26da7 Code clean up, loglevel debug for integration tests 2021-08-21 09:15:16 +01:00
Kristoffer Dalby
d93a7f2e02 Make Info default log level 2021-08-20 17:15:07 +01:00
Kristoffer Dalby
88d7ac04bf Account for race condition in deleting/closing update channel
This commit tries to address the possible race condition that can happen
if a client closes its connection after we have fetched its channel from the
sync.Map but before we send the message.

To avoid introducing new deadlock conditions, all sends
to the update channel have been moved into a single function that handles the
locking (instead of doing it all over the place).

The same lock is used around the delete/close function.
2021-08-20 16:52:34 +01:00
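
A minimal sketch of the locking scheme described above (names are illustrative, not the exact ones in the codebase): every send on an update channel and every delete/close goes through the same mutex, so a channel cannot be closed between the lookup and the send.

```go
package notify

import (
	"errors"
	"sync"
)

type notifier struct {
	mu       sync.Mutex
	channels sync.Map // machine ID -> chan struct{}
}

// send delivers an update request, holding the lock so the channel
// cannot be closed concurrently by closeAndDelete. Note the send
// blocks until the polling goroutine receives, which is why the
// commit message worries about introducing new deadlocks.
func (n *notifier) send(id uint64) error {
	n.mu.Lock()
	defer n.mu.Unlock()
	if c, ok := n.channels.Load(id); ok {
		c.(chan struct{}) <- struct{}{}
		return nil
	}
	return errors.New("machine does not seem to be polling")
}

// closeAndDelete tears down a client's channel under the same lock.
func (n *notifier) closeAndDelete(id uint64) {
	n.mu.Lock()
	defer n.mu.Unlock()
	if c, ok := n.channels.Load(id); ok {
		close(c.(chan struct{}))
	}
	n.channels.Delete(id)
}
```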
Kristoffer Dalby
1f422af1c8 Save headscale logs if jobs fail 2021-08-20 16:50:55 +01:00
Kristoffer Dalby
53168d54d8 Make http timeout 30s instead of 10s 2021-08-19 22:29:03 +01:00
Kristoffer Dalby
b0ec945dbb Make lastStateChange namespaced 2021-08-19 18:19:26 +01:00
Kristoffer Dalby
48ef6e5a6f Rename keepAlive function, as it now does more things 2021-08-19 18:06:57 +01:00
Kristoffer Dalby
8d1adaaef3 Move isOutdated logic to updateChan consumption 2021-08-19 18:05:33 +01:00
Kristoffer Dalby
dd8c0d1e9e Move most "poll" functionality to poll.go
This commit moves more poll functions (including keepalive) into
poll.go to keep related code in the same file.

In addition, it makes changes to improve stability and ensure nodes
get the appropriate updates from the headscale control server and are not left
in an inconsistent state.

Two new additions:

omitpeers=true will now trigger an update if the client is not already up
to date

keepalive has been extended with a timer that will check every 120s whether
all nodes are up to date.
2021-08-18 23:24:22 +01:00
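
A minimal sketch of the per-client worker described above, assuming the 120s figure from the commit message (the diff further down uses a 30s update-check ticker); the channel payload and the `requestUpdate` callback are illustrative, not the repo's exact signatures.

```go
package poller

import "time"

// pollWorker sends periodic keepalives and, on a second ticker,
// unconditionally requests an update check; whether data is actually
// sent is decided by the updateChan consumer.
func pollWorker(cancel <-chan struct{}, keepAlive chan<- []byte, requestUpdate func()) {
	keepAliveTicker := time.NewTicker(60 * time.Second)
	updateCheckerTicker := time.NewTicker(120 * time.Second)
	defer keepAliveTicker.Stop()
	defer updateCheckerTicker.Stop()

	for {
		select {
		case <-cancel:
			return
		case <-keepAliveTicker.C:
			keepAlive <- []byte(`{"KeepAlive":true}`)
		case <-updateCheckerTicker.C:
			requestUpdate()
		}
	}
}
```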
Kristoffer Dalby
57b79aa852 Set timeout, add lastupdate field
This commit makes two reasonably major changes:

Set a default timeout for the Go HTTP server (which gin uses), so
broken long-poll sessions actually fail and the client can re-establish them.
The current 10s value was chosen arbitrarily, and we need more testing to
ensure the feature works as intended.

The second change adds a last-updated field to keep track of the last time
we had an update that needs to be propagated to all of our
clients/nodes. This will be used to determine whether machines are up to
date or need us to push an update.
2021-08-18 23:21:11 +01:00
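
A minimal sketch of the change, assuming a plain gin engine: hand the router to an explicit net/http server with read/write timeouts instead of calling gin's own Run, so dead long-poll connections are eventually torn down (the address and 30s values here are illustrative).

```go
package main

import (
	"net/http"
	"time"

	"github.com/gin-gonic/gin"
)

func main() {
	r := gin.Default()

	// gin.Engine satisfies http.Handler, so it can be wrapped in a
	// net/http server that enforces timeouts on long polls.
	s := &http.Server{
		Addr:         ":8080",
		Handler:      r,
		ReadTimeout:  30 * time.Second,
		WriteTimeout: 30 * time.Second,
	}
	if err := s.ListenAndServe(); err != nil {
		panic(err)
	}
}
```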
Kristoffer Dalby
2f883410d2 Add lastUpdate field to machine, function to issue message on update channel
This commit adds a new field to Machine, lastSuccessfulUpdate, which
tracks when we were last able to send a proper map update to the node. The
purpose is to be able to compare against a "global" last-updated time
and determine whether we need to send a map update request to a node.

In addition, it allows us to create a scheduled check to see if all known
nodes are up to date.

Also, add a helper function to send a message to the update channel of a
machine.
2021-08-18 23:17:38 +01:00
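
The comparison this enables can be sketched in a few lines (a simplified standalone version of what the diff implements as isOutdated; the nil handling is an assumption):

```go
package poller

import "time"

// isOutdated reports whether a machine has missed a state change:
// its last successful map update predates the namespace's last change.
// A nil timestamp means the node has never been updated successfully.
func isOutdated(lastSuccessfulUpdate *time.Time, lastStateChange time.Time) bool {
	if lastSuccessfulUpdate == nil {
		return true
	}
	return lastSuccessfulUpdate.Before(lastStateChange)
}
```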
Kristoffer Dalby
6fa61380b2 Up client count, make arguments more explicit and clean up unused assignments 2021-08-18 23:17:09 +01:00
Kristoffer Dalby
7d1a5c00a0 Try with longer timeout 2021-08-13 16:56:28 +01:00
Kristoffer Dalby
036061664e initial integration test file 2021-08-13 16:12:01 +01:00
20 changed files with 1070 additions and 353 deletions

.github/workflows/lint.yml (new file, 39 lines)

@@ -0,0 +1,39 @@
name: CI

on: [push, pull_request]

jobs:
  # The "build" workflow
  lint:
    # The type of runner that the job will run on
    runs-on: ubuntu-latest

    # Steps represent a sequence of tasks that will be executed as part of the job
    steps:
      # Checks-out your repository under $GITHUB_WORKSPACE, so your job can access it
      - uses: actions/checkout@v2

      # Install and run golangci-lint as a separate step, it's much faster this
      # way because this action has caching. It'll get run again in `make lint`
      # below, but it's still much faster in the end than installing
      # golangci-lint manually in the `Run lint` step.
      - uses: golangci/golangci-lint-action@v2
        with:
          args: --timeout 2m

      # Setup Go
      - name: Setup Go
        uses: actions/setup-go@v2
        with:
          go-version: "1.16.3" # The Go version to download (if necessary) and use.

      # Install all the dependencies
      - name: Install dependencies
        run: |
          go version
          go install golang.org/x/lint/golint@latest
          sudo apt update
          sudo apt install -y make

      - name: Run lint
        run: make lint

.github/workflows/test-integration.yml (new file, 23 lines)

@@ -0,0 +1,23 @@
name: CI

on: [pull_request]

jobs:
  # The "build" workflow
  integration-test:
    # The type of runner that the job will run on
    runs-on: ubuntu-latest

    # Steps represent a sequence of tasks that will be executed as part of the job
    steps:
      # Checks-out your repository under $GITHUB_WORKSPACE, so your job can access it
      - uses: actions/checkout@v2

      # Setup Go
      - name: Setup Go
        uses: actions/setup-go@v2
        with:
          go-version: "1.16.3"

      - name: Run Integration tests
        run: go test -tags integration -timeout 30m


@@ -10,36 +10,24 @@ jobs:
     # Steps represent a sequence of tasks that will be executed as part of the job
     steps:
       # Checks-out your repository under $GITHUB_WORKSPACE, so your job can access it
       - uses: actions/checkout@v2
-      # Install and run golangci-lint as a separate step, it's much faster this
-      # way because this action has caching. It'll get run again in `make lint`
-      # below, but it's still much faster in the end than installing
-      # golangci-lint manually in the `Run lint` step.
-      - uses: golangci/golangci-lint-action@v2
-        with:
-          args: --timeout 2m
       # Setup Go
       - name: Setup Go
         uses: actions/setup-go@v2
         with:
-          go-version: '1.16.3' # The Go version to download (if necessary) and use.
+          go-version: "1.16.3" # The Go version to download (if necessary) and use.
       # Install all the dependencies
       - name: Install dependencies
         run: |
           go version
-          go install golang.org/x/lint/golint@latest
           sudo apt update
           sudo apt install -y make
-      - name: Run lint
-        run: make lint
       - name: Run tests
         run: make test
       - name: Run build
         run: make

.gitignore (2 lines changed)

@@ -19,3 +19,5 @@ config.json
*.key
/db.sqlite
*.sqlite3
test_output/


@@ -2,7 +2,7 @@
version = $(shell ./scripts/version-at-commit.sh)
build:
go build -ldflags "-s -w -X main.version=$(version)" cmd/headscale/headscale.go
go build -ldflags "-s -w -X github.com/juanfont/headscale/cmd/headscale/cli.version=$(version)" cmd/headscale/headscale.go
dev: lint test build


@@ -25,8 +25,9 @@ Headscale implements this coordination server.
- [X] JSON-formatted output
- [X] ACLs
- [X] Support for alternative IP ranges in the tailnets (default Tailscale's 100.64.0.0/10)
- [X] DNS (passing DNS servers to nodes)
- [ ] Share nodes between ~~users~~ namespaces
- [ ] DNS
- [ ] MagicDNS / Smart DNS
## Roadmap 🤷

api.go (217 lines changed)

@@ -13,9 +13,7 @@ import (
"github.com/gin-gonic/gin"
"github.com/klauspost/compress/zstd"
"gorm.io/datatypes"
"gorm.io/gorm"
"inet.af/netaddr"
"tailscale.com/tailcfg"
"tailscale.com/types/wgkey"
)
@@ -82,14 +80,16 @@ func (h *Headscale) RegistrationHandler(c *gin.Context) {
return
}
now := time.Now().UTC()
var m Machine
if result := h.db.Preload("Namespace").First(&m, "machine_key = ?", mKey.HexString()); errors.Is(result.Error, gorm.ErrRecordNotFound) {
log.Info().Str("machine", req.Hostinfo.Hostname).Msg("New machine")
m = Machine{
Expiry: &req.Expiry,
MachineKey: mKey.HexString(),
Name: req.Hostinfo.Hostname,
NodeKey: wgkey.Key(req.NodeKey).HexString(),
Expiry: &req.Expiry,
MachineKey: mKey.HexString(),
Name: req.Hostinfo.Hostname,
NodeKey: wgkey.Key(req.NodeKey).HexString(),
LastSuccessfulUpdate: &now,
}
if err := h.db.Create(&m).Error; err != nil {
log.Error().
@@ -215,196 +215,6 @@ func (h *Headscale) RegistrationHandler(c *gin.Context) {
c.Data(200, "application/json; charset=utf-8", respBody)
}
// PollNetMapHandler takes care of /machine/:id/map
//
// This is the busiest endpoint, as it keeps the HTTP long poll that updates
// the clients when something in the network changes.
//
// The clients POST stuff like HostInfo and their Endpoints here, but
// only after their first request (marked with the ReadOnly field).
//
// At this moment the updates are sent in a quite horrendous way, but they kinda work.
func (h *Headscale) PollNetMapHandler(c *gin.Context) {
log.Trace().
Str("handler", "PollNetMap").
Str("id", c.Param("id")).
Msg("PollNetMapHandler called")
body, _ := io.ReadAll(c.Request.Body)
mKeyStr := c.Param("id")
mKey, err := wgkey.ParseHex(mKeyStr)
if err != nil {
log.Error().
Str("handler", "PollNetMap").
Err(err).
Msg("Cannot parse client key")
c.String(http.StatusBadRequest, "")
return
}
req := tailcfg.MapRequest{}
err = decode(body, &req, &mKey, h.privateKey)
if err != nil {
log.Error().
Str("handler", "PollNetMap").
Err(err).
Msg("Cannot decode message")
c.String(http.StatusBadRequest, "")
return
}
var m Machine
if result := h.db.Preload("Namespace").First(&m, "machine_key = ?", mKey.HexString()); errors.Is(result.Error, gorm.ErrRecordNotFound) {
log.Warn().
Str("handler", "PollNetMap").
Msgf("Ignoring request, cannot find machine with key %s", mKey.HexString())
c.String(http.StatusUnauthorized, "")
return
}
log.Trace().
Str("handler", "PollNetMap").
Str("id", c.Param("id")).
Str("machine", m.Name).
Msg("Found machine in database")
hostinfo, _ := json.Marshal(req.Hostinfo)
m.Name = req.Hostinfo.Hostname
m.HostInfo = datatypes.JSON(hostinfo)
m.DiscoKey = wgkey.Key(req.DiscoKey).HexString()
now := time.Now().UTC()
// From Tailscale client:
//
// ReadOnly is whether the client just wants to fetch the MapResponse,
// without updating their Endpoints. The Endpoints field will be ignored and
// LastSeen will not be updated and peers will not be notified of changes.
//
// The intended use is for clients to discover the DERP map at start-up
// before their first real endpoint update.
if !req.ReadOnly {
endpoints, _ := json.Marshal(req.Endpoints)
m.Endpoints = datatypes.JSON(endpoints)
m.LastSeen = &now
}
h.db.Save(&m)
data, err := h.getMapResponse(mKey, req, m)
if err != nil {
log.Error().
Str("handler", "PollNetMap").
Str("id", c.Param("id")).
Str("machine", m.Name).
Err(err).
Msg("Failed to get Map response")
c.String(http.StatusInternalServerError, ":(")
return
}
// We update our peers if the client is not sending ReadOnly in the MapRequest
// so we don't distribute its initial request (it comes with
// empty endpoints to peers)
// Details on the protocol can be found in https://github.com/tailscale/tailscale/blob/main/tailcfg/tailcfg.go#L696
log.Debug().
Str("handler", "PollNetMap").
Str("id", c.Param("id")).
Str("machine", m.Name).
Bool("readOnly", req.ReadOnly).
Bool("omitPeers", req.OmitPeers).
Bool("stream", req.Stream).
Msg("Client map request processed")
if req.ReadOnly {
log.Info().
Str("handler", "PollNetMap").
Str("machine", m.Name).
Msg("Client is starting up. Asking for DERP map")
c.Data(200, "application/json; charset=utf-8", *data)
return
}
if req.OmitPeers && !req.Stream {
log.Info().
Str("handler", "PollNetMap").
Str("machine", m.Name).
Msg("Client sent endpoint update and is ok with a response without peer list")
c.Data(200, "application/json; charset=utf-8", *data)
return
} else if req.OmitPeers && req.Stream {
log.Warn().
Str("handler", "PollNetMap").
Str("machine", m.Name).
Msg("Ignoring request, don't know how to handle it")
c.String(http.StatusBadRequest, "")
return
}
// Only create update channel if it has not been created
var update chan []byte
log.Trace().
Str("handler", "PollNetMap").
Str("id", c.Param("id")).
Str("machine", m.Name).
Msg("Creating or loading update channel")
if result, ok := h.clientsPolling.LoadOrStore(m.ID, make(chan []byte, 1)); ok {
update = result.(chan []byte)
}
pollData := make(chan []byte, 1)
defer close(pollData)
cancelKeepAlive := make(chan []byte, 1)
defer close(cancelKeepAlive)
log.Info().
Str("handler", "PollNetMap").
Str("machine", m.Name).
Msg("Client is ready to access the tailnet")
log.Info().
Str("handler", "PollNetMap").
Str("machine", m.Name).
Msg("Sending initial map")
pollData <- *data
log.Info().
Str("handler", "PollNetMap").
Str("machine", m.Name).
Msg("Notifying peers")
// TODO: Why does this block?
go h.notifyChangesToPeers(&m)
h.PollNetMapStream(c, m, req, mKey, pollData, update, cancelKeepAlive)
log.Trace().
Str("handler", "PollNetMap").
Str("id", c.Param("id")).
Str("machine", m.Name).
Msg("Finished stream, closing PollNetMap session")
}
func (h *Headscale) keepAlive(cancel chan []byte, pollData chan []byte, mKey wgkey.Key, req tailcfg.MapRequest, m Machine) {
for {
select {
case <-cancel:
return
default:
data, err := h.getMapKeepAliveResponse(mKey, req, m)
if err != nil {
log.Error().
Str("func", "keepAlive").
Err(err).
Msg("Error generating the keep alive msg")
return
}
log.Debug().
Str("func", "keepAlive").
Str("machine", m.Name).
Msg("Sending keepalive")
pollData <- *data
time.Sleep(60 * time.Second)
}
}
}
func (h *Headscale) getMapResponse(mKey wgkey.Key, req tailcfg.MapRequest, m Machine) (*[]byte, error) {
log.Trace().
Str("func", "getMapResponse").
@@ -434,10 +244,15 @@ func (h *Headscale) getMapResponse(mKey wgkey.Key, req tailcfg.MapRequest, m Mac
}
resp := tailcfg.MapResponse{
KeepAlive: false,
Node: node,
Peers: *peers,
DNS: []netaddr.IP{},
KeepAlive: false,
Node: node,
Peers: *peers,
//TODO(kradalby): As per tailscale docs, if DNSConfig is nil,
// it means its not updated, maybe we can have some logic
// to check and only pass updates when its updates.
// This is probably more relevant if we try to implement
// "MagicDNS"
DNSConfig: h.cfg.DNSConfig,
SearchPaths: []string{},
Domain: "headscale.net",
PacketFilter: *h.aclRules,
@@ -542,7 +357,7 @@ func (h *Headscale) handleAuthKey(c *gin.Context, db *gorm.DB, idKey wgkey.Key,
Str("func", "handleAuthKey").
Str("machine", m.Name).
Str("ip", ip.String()).
Msgf("Assining %s to %s", ip, m.Name)
Msgf("Assigning %s to %s", ip, m.Name)
m.AuthKeyID = uint(pak.ID)
m.IPAddress = ip.String()

app.go (45 lines changed)

@@ -43,6 +43,8 @@ type Config struct {
TLSCertPath string
TLSKeyPath string
DNSConfig *tailcfg.DNSConfig
}
// Headscale represents the base app of the service
@@ -58,7 +60,10 @@ type Headscale struct {
aclPolicy *ACLPolicy
aclRules *[]tailcfg.FilterRule
clientsPolling sync.Map
clientsUpdateChannels sync.Map
clientsUpdateChannelMutex sync.Mutex
lastStateChange sync.Map
}
// NewHeadscale returns the Headscale app
@@ -165,9 +170,18 @@ func (h *Headscale) Serve() error {
r.POST("/machine/:id", h.RegistrationHandler)
var err error
timeout := 30 * time.Second
go h.watchForKVUpdates(5000)
go h.expireEphemeralNodes(5000)
s := &http.Server{
Addr: h.cfg.Addr,
Handler: r,
ReadTimeout: timeout,
WriteTimeout: timeout,
}
if h.cfg.TLSLetsEncryptHostname != "" {
if !strings.HasPrefix(h.cfg.ServerURL, "https://") {
log.Warn().Msg("Listening with TLS but ServerURL does not start with https://")
@@ -179,9 +193,11 @@ func (h *Headscale) Serve() error {
Cache: autocert.DirCache(h.cfg.TLSLetsEncryptCacheDir),
}
s := &http.Server{
Addr: h.cfg.Addr,
TLSConfig: m.TLSConfig(),
Handler: r,
Addr: h.cfg.Addr,
TLSConfig: m.TLSConfig(),
Handler: r,
ReadTimeout: timeout,
WriteTimeout: timeout,
}
if h.cfg.TLSLetsEncryptChallengeType == "TLS-ALPN-01" {
// Configuration via autocert with TLS-ALPN-01 (https://tools.ietf.org/html/rfc8737)
@@ -206,12 +222,29 @@ func (h *Headscale) Serve() error {
if !strings.HasPrefix(h.cfg.ServerURL, "http://") {
log.Warn().Msg("Listening without TLS but ServerURL does not start with http://")
}
err = r.Run(h.cfg.Addr)
err = s.ListenAndServe()
} else {
if !strings.HasPrefix(h.cfg.ServerURL, "https://") {
log.Warn().Msg("Listening with TLS but ServerURL does not start with https://")
}
err = r.RunTLS(h.cfg.Addr, h.cfg.TLSCertPath, h.cfg.TLSKeyPath)
err = s.ListenAndServeTLS(h.cfg.TLSCertPath, h.cfg.TLSKeyPath)
}
return err
}
func (h *Headscale) setLastStateChangeToNow(namespace string) {
now := time.Now().UTC()
h.lastStateChange.Store(namespace, now)
}
func (h *Headscale) getLastStateChange(namespace string) time.Time {
if wrapped, ok := h.lastStateChange.Load(namespace); ok {
lastChange, _ := wrapped.(time.Time)
return lastChange
}
now := time.Now().UTC()
h.lastStateChange.Store(namespace, now)
return now
}


@@ -5,6 +5,7 @@ import (
"log"
"strings"
"github.com/pterm/pterm"
"github.com/spf13/cobra"
)
@@ -15,6 +16,9 @@ func init() {
if err != nil {
log.Fatalf(err.Error())
}
enableRouteCmd.Flags().BoolP("all", "a", false, "Enable all routes advertised by the node")
routesCmd.AddCommand(listRoutesCmd)
routesCmd.AddCommand(enableRouteCmd)
}
@@ -44,19 +48,25 @@ var listRoutesCmd = &cobra.Command{
if err != nil {
log.Fatalf("Error initializing: %s", err)
}
routes, err := h.GetNodeRoutes(n, args[0])
if strings.HasPrefix(o, "json") {
JsonOutput(routes, err, o)
return
}
availableRoutes, err := h.GetAdvertisedNodeRoutes(n, args[0])
if err != nil {
fmt.Println(err)
return
}
fmt.Println(routes)
if strings.HasPrefix(o, "json") {
// TODO: Add enable/disabled information to this interface
JsonOutput(availableRoutes, err, o)
return
}
d := h.RoutesToPtables(n, args[0], *availableRoutes)
err = pterm.DefaultTable.WithHasHeader().WithData(d).Render()
if err != nil {
log.Fatal(err)
}
},
}
@@ -64,32 +74,74 @@ var enableRouteCmd = &cobra.Command{
Use: "enable node-name route",
Short: "Allows exposing a route declared by this node to the rest of the nodes",
Args: func(cmd *cobra.Command, args []string) error {
if len(args) < 2 {
return fmt.Errorf("Missing parameters")
all, err := cmd.Flags().GetBool("all")
if err != nil {
log.Fatalf("Error getting namespace: %s", err)
}
if all {
if len(args) < 1 {
return fmt.Errorf("Missing parameters")
}
return nil
} else {
if len(args) < 2 {
return fmt.Errorf("Missing parameters")
}
return nil
}
return nil
},
Run: func(cmd *cobra.Command, args []string) {
n, err := cmd.Flags().GetString("namespace")
if err != nil {
log.Fatalf("Error getting namespace: %s", err)
}
o, _ := cmd.Flags().GetString("output")
all, err := cmd.Flags().GetBool("all")
if err != nil {
log.Fatalf("Error getting namespace: %s", err)
}
h, err := getHeadscaleApp()
if err != nil {
log.Fatalf("Error initializing: %s", err)
}
route, err := h.EnableNodeRoute(n, args[0], args[1])
if strings.HasPrefix(o, "json") {
JsonOutput(route, err, o)
return
}
if err != nil {
fmt.Println(err)
return
if all {
availableRoutes, err := h.GetAdvertisedNodeRoutes(n, args[0])
if err != nil {
fmt.Println(err)
return
}
for _, availableRoute := range *availableRoutes {
err = h.EnableNodeRoute(n, args[0], availableRoute.String())
if err != nil {
fmt.Println(err)
return
}
if strings.HasPrefix(o, "json") {
JsonOutput(availableRoute, err, o)
} else {
fmt.Printf("Enabled route %s\n", availableRoute)
}
}
} else {
err = h.EnableNodeRoute(n, args[0], args[1])
if strings.HasPrefix(o, "json") {
JsonOutput(args[1], err, o)
return
}
if err != nil {
fmt.Println(err)
return
}
fmt.Printf("Enabled route %s\n", args[1])
}
fmt.Printf("Enabled route %s\n", route)
},
}


@@ -39,7 +39,9 @@ func LoadConfig(path string) error {
viper.SetDefault("ip_prefix", "100.64.0.0/10")
viper.SetDefault("log_level", "debug")
viper.SetDefault("log_level", "info")
viper.SetDefault("dns_config", nil)
err := viper.ReadInConfig()
if err != nil {
@@ -70,6 +72,45 @@ func LoadConfig(path string) error {
} else {
return nil
}
}
func GetDNSConfig() *tailcfg.DNSConfig {
if viper.IsSet("dns_config") {
dnsConfig := &tailcfg.DNSConfig{}
if viper.IsSet("dns_config.nameservers") {
nameserversStr := viper.GetStringSlice("dns_config.nameservers")
nameservers := make([]netaddr.IP, len(nameserversStr))
resolvers := make([]tailcfg.DNSResolver, len(nameserversStr))
for index, nameserverStr := range nameserversStr {
nameserver, err := netaddr.ParseIP(nameserverStr)
if err != nil {
log.Error().
Str("func", "getDNSConfig").
Err(err).
Msgf("Could not parse nameserver IP: %s", nameserverStr)
}
nameservers[index] = nameserver
resolvers[index] = tailcfg.DNSResolver{
Addr: nameserver.String(),
}
}
dnsConfig.Nameservers = nameservers
dnsConfig.Resolvers = resolvers
}
if viper.IsSet("dns_config.domains") {
dnsConfig.Domains = viper.GetStringSlice("dns_config.domains")
}
return dnsConfig
}
return nil
}
func absPath(path string) string {
@@ -126,6 +167,8 @@ func getHeadscaleApp() (*headscale.Headscale, error) {
TLSCertPath: absPath(viper.GetString("tls_cert_path")),
TLSKeyPath: absPath(viper.GetString("tls_key_path")),
DNSConfig: GetDNSConfig(),
}
h, err := headscale.NewHeadscale(cfg)


@@ -58,7 +58,7 @@ func (*Suite) TestPostgresConfigLoading(c *check.C) {
c.Assert(viper.GetString("db_port"), check.Equals, "5432")
c.Assert(viper.GetString("tls_letsencrypt_hostname"), check.Equals, "")
c.Assert(viper.GetString("tls_letsencrypt_listen"), check.Equals, ":http")
c.Assert(viper.GetString("tls_letsencrypt_challenge_type"), check.Equals, "HTTP-01")
c.Assert(viper.GetStringSlice("dns_config.nameservers")[0], check.Equals, "1.1.1.1")
}
func (*Suite) TestSqliteConfigLoading(c *check.C) {
@@ -92,6 +92,37 @@ func (*Suite) TestSqliteConfigLoading(c *check.C) {
c.Assert(viper.GetString("tls_letsencrypt_hostname"), check.Equals, "")
c.Assert(viper.GetString("tls_letsencrypt_listen"), check.Equals, ":http")
c.Assert(viper.GetString("tls_letsencrypt_challenge_type"), check.Equals, "HTTP-01")
c.Assert(viper.GetStringSlice("dns_config.nameservers")[0], check.Equals, "1.1.1.1")
}
func (*Suite) TestDNSConfigLoading(c *check.C) {
tmpDir, err := ioutil.TempDir("", "headscale")
if err != nil {
c.Fatal(err)
}
defer os.RemoveAll(tmpDir)
path, err := os.Getwd()
if err != nil {
c.Fatal(err)
}
// Symlink the example config file
err = os.Symlink(filepath.Clean(path+"/../../config.json.sqlite.example"), filepath.Join(tmpDir, "config.json"))
if err != nil {
c.Fatal(err)
}
// Load example config, it should load without validation errors
err = cli.LoadConfig(tmpDir)
c.Assert(err, check.IsNil)
dnsConfig := cli.GetDNSConfig()
fmt.Println(dnsConfig)
c.Assert(dnsConfig.Nameservers[0].String(), check.Equals, "1.1.1.1")
c.Assert(dnsConfig.Resolvers[0].Addr, check.Equals, "1.1.1.1")
}
func writeConfig(c *check.C, tmpDir string, configYaml []byte) {


@@ -16,5 +16,10 @@
"tls_letsencrypt_challenge_type": "HTTP-01",
"tls_cert_path": "",
"tls_key_path": "",
"acl_policy_path": ""
"acl_policy_path": "",
"dns_config": {
"nameservers": [
"1.1.1.1"
]
}
}


@@ -12,5 +12,10 @@
"tls_letsencrypt_challenge_type": "HTTP-01",
"tls_cert_path": "",
"tls_key_path": "",
"acl_policy_path": ""
"acl_policy_path": "",
"dns_config": {
"nameservers": [
"1.1.1.1"
]
}
}


@@ -4,10 +4,13 @@ package headscale
import (
"bytes"
"context"
"fmt"
"io/ioutil"
"log"
"net/http"
"os"
"path"
"strings"
"testing"
"time"
@@ -20,23 +23,48 @@ import (
"inet.af/netaddr"
)
type IntegrationTestSuite struct {
suite.Suite
}
func TestIntegrationTestSuite(t *testing.T) {
suite.Run(t, new(IntegrationTestSuite))
}
var integrationTmpDir string
var ih Headscale
var pool dockertest.Pool
var network dockertest.Network
var headscale dockertest.Resource
var tailscaleCount int = 5
var tailscaleCount int = 25
var tailscales map[string]dockertest.Resource
type IntegrationTestSuite struct {
suite.Suite
stats *suite.SuiteInformation
}
func TestIntegrationTestSuite(t *testing.T) {
s := new(IntegrationTestSuite)
suite.Run(t, s)
// HandleStats, which allows us to check if we passed and save logs
// is called after TearDown, so we cannot tear down containers before
// we have potentially saved the logs.
for _, tailscale := range tailscales {
if err := pool.Purge(&tailscale); err != nil {
log.Printf("Could not purge resource: %s\n", err)
}
}
if !s.stats.Passed() {
err := saveLog(&headscale, "test_output")
if err != nil {
log.Printf("Could not save log: %s\n", err)
}
}
if err := pool.Purge(&headscale); err != nil {
log.Printf("Could not purge resource: %s\n", err)
}
if err := network.Close(); err != nil {
log.Printf("Could not close network: %s\n", err)
}
}
func executeCommand(resource *dockertest.Resource, cmd []string) (string, error) {
var stdout bytes.Buffer
var stderr bytes.Buffer
@@ -62,6 +90,48 @@ func executeCommand(resource *dockertest.Resource, cmd []string) (string, error)
return stdout.String(), nil
}
func saveLog(resource *dockertest.Resource, basePath string) error {
err := os.MkdirAll(basePath, os.ModePerm)
if err != nil {
return err
}
var stdout bytes.Buffer
var stderr bytes.Buffer
err = pool.Client.Logs(
docker.LogsOptions{
Context: context.TODO(),
Container: resource.Container.ID,
OutputStream: &stdout,
ErrorStream: &stderr,
Tail: "all",
RawTerminal: false,
Stdout: true,
Stderr: true,
Follow: false,
Timestamps: false,
},
)
if err != nil {
return err
}
fmt.Printf("Saving logs for %s to %s\n", resource.Container.Name, basePath)
err = ioutil.WriteFile(path.Join(basePath, resource.Container.Name+".stdout.log"), []byte(stdout.String()), 0644)
if err != nil {
return err
}
err = ioutil.WriteFile(path.Join(basePath, resource.Container.Name+".stderr.log"), []byte(stdout.String()), 0644)
if err != nil {
return err
}
return nil
}
func dockerRestartPolicy(config *docker.HostConfig) {
// set AutoRemove to true so that stopped container goes away by itself
config.AutoRemove = true
@@ -115,7 +185,6 @@ func (s *IntegrationTestSuite) SetupSuite() {
PortBindings: map[docker.Port][]docker.PortBinding{
"8080/tcp": []docker.PortBinding{{HostPort: "8080"}},
},
Env: []string{},
}
fmt.Println("Creating headscale container")
@@ -134,7 +203,6 @@ func (s *IntegrationTestSuite) SetupSuite() {
Name: hostname,
Networks: []*dockertest.Network{&network},
Cmd: []string{"tailscaled", "--tun=userspace-networking", "--socks5-server=localhost:1055"},
Env: []string{},
}
if pts, err := pool.BuildAndRunWithBuildOptions(tailscaleBuildOptions, tailscaleOptions, dockerRestartPolicy); err == nil {
@@ -145,7 +213,6 @@ func (s *IntegrationTestSuite) SetupSuite() {
fmt.Printf("Created %s container\n", hostname)
}
// TODO: Replace this logic with something that can be detected on Github Actions
fmt.Println("Waiting for headscale to be ready")
hostEndpoint := fmt.Sprintf("localhost:%s", headscale.GetPort("8080/tcp"))
@@ -197,23 +264,14 @@ func (s *IntegrationTestSuite) SetupSuite() {
// The nodes need a bit of time to get their updated maps from headscale
// TODO: See if we can have a more deterministic wait here.
time.Sleep(20 * time.Second)
time.Sleep(60 * time.Second)
}
func (s *IntegrationTestSuite) TearDownSuite() {
if err := pool.Purge(&headscale); err != nil {
log.Printf("Could not purge resource: %s\n", err)
}
}
for _, tailscale := range tailscales {
if err := pool.Purge(&tailscale); err != nil {
log.Printf("Could not purge resource: %s\n", err)
}
}
if err := network.Close(); err != nil {
log.Printf("Could not close network: %s\n", err)
}
func (s *IntegrationTestSuite) HandleStats(suiteName string, stats *suite.SuiteInformation) {
s.stats = stats
}
func (s *IntegrationTestSuite) TestListNodes() {
@@ -295,7 +353,15 @@ func (s *IntegrationTestSuite) TestPingAllPeers() {
s.T().Run(fmt.Sprintf("%s-%s", hostname, peername), func(t *testing.T) {
// We currently cant ping ourselves, so skip that.
if peername != hostname {
command := []string{"tailscale", "ping", "--timeout=1s", "--c=1", ip.String()}
// We are only interested in "direct ping" which means what we
// might need a couple of more attempts before reaching the node.
command := []string{
"tailscale", "ping",
"--timeout=1s",
"--c=20",
"--until-direct=true",
ip.String(),
}
fmt.Printf("Pinging from %s (%s) to %s (%s)\n", hostname, ips[hostname], peername, ip)
result, err := executeCommand(

View File

@@ -7,5 +7,5 @@
"db_type": "sqlite3",
"db_path": "/tmp/integration_test_db.sqlite3",
"acl_policy_path": "",
"log_level": "trace"
"log_level": "debug"
}


@@ -2,6 +2,7 @@ package headscale
import (
"encoding/json"
"errors"
"fmt"
"sort"
"strconv"
@@ -31,8 +32,9 @@ type Machine struct {
AuthKeyID uint
AuthKey *PreAuthKey
LastSeen *time.Time
Expiry *time.Time
LastSeen *time.Time
LastSuccessfulUpdate *time.Time
Expiry *time.Time
HostInfo datatypes.JSON
Endpoints datatypes.JSON
@@ -211,6 +213,15 @@ func (h *Headscale) GetMachineByID(id uint64) (*Machine, error) {
return &m, nil
}
// UpdateMachine takes a Machine struct pointer (typically already loaded from database
// and updates it with the latest data from the database.
func (h *Headscale) UpdateMachine(m *Machine) error {
if result := h.db.Find(m).First(&m); result.Error != nil {
return result.Error
}
return nil
}
// DeleteMachine softs deletes a Machine from the database
func (h *Headscale) DeleteMachine(m *Machine) error {
m.Registered = false
@@ -251,21 +262,110 @@ func (m *Machine) GetHostInfo() (*tailcfg.Hostinfo, error) {
func (h *Headscale) notifyChangesToPeers(m *Machine) {
peers, _ := h.getPeers(*m)
for _, p := range *peers {
pUp, ok := h.clientsPolling.Load(uint64(p.ID))
if ok {
log.Info().
Str("func", "notifyChangesToPeers").
Str("machine", m.Name).
Str("peer", p.Name).
Str("address", p.Addresses[0].String()).
Msgf("Notifying peer %s (%s)", p.Name, p.Addresses[0])
err := h.sendRequestOnUpdateChannel(p)
if err != nil {
log.Info().
Str("func", "notifyChangesToPeers").
Str("machine", m.Name).
Str("peer", m.Name).
Str("address", p.Addresses[0].String()).
Msgf("Notifying peer %s (%s)", p.Name, p.Addresses[0])
pUp.(chan []byte) <- []byte{}
} else {
log.Info().
Str("func", "notifyChangesToPeers").
Str("machine", m.Name).
Str("peer", m.Name).
Str("peer", p.Name).
Msgf("Peer %s does not appear to be polling", p.Name)
}
log.Trace().
Str("func", "notifyChangesToPeers").
Str("machine", m.Name).
Str("peer", p.Name).
Str("address", p.Addresses[0].String()).
Msgf("Notified peer %s (%s)", p.Name, p.Addresses[0])
}
}
func (h *Headscale) getOrOpenUpdateChannel(m *Machine) <-chan struct{} {
var updateChan chan struct{}
if storedChan, ok := h.clientsUpdateChannels.Load(m.ID); ok {
if unwrapped, ok := storedChan.(chan struct{}); ok {
updateChan = unwrapped
} else {
log.Error().
Str("handler", "openUpdateChannel").
Str("machine", m.Name).
Msg("Failed to convert update channel to struct{}")
}
} else {
log.Debug().
Str("handler", "openUpdateChannel").
Str("machine", m.Name).
Msg("Update channel not found, creating")
updateChan = make(chan struct{})
h.clientsUpdateChannels.Store(m.ID, updateChan)
}
return updateChan
}
func (h *Headscale) closeUpdateChannel(m *Machine) {
h.clientsUpdateChannelMutex.Lock()
defer h.clientsUpdateChannelMutex.Unlock()
if storedChan, ok := h.clientsUpdateChannels.Load(m.ID); ok {
if unwrapped, ok := storedChan.(chan struct{}); ok {
close(unwrapped)
}
}
h.clientsUpdateChannels.Delete(m.ID)
}
func (h *Headscale) sendRequestOnUpdateChannel(m *tailcfg.Node) error {
h.clientsUpdateChannelMutex.Lock()
defer h.clientsUpdateChannelMutex.Unlock()
pUp, ok := h.clientsUpdateChannels.Load(uint64(m.ID))
if ok {
log.Info().
Str("func", "requestUpdate").
Str("machine", m.Name).
Msgf("Notifying peer %s", m.Name)
if update, ok := pUp.(chan struct{}); ok {
log.Trace().
Str("func", "requestUpdate").
Str("machine", m.Name).
Msgf("Update channel is %#v", update)
update <- struct{}{}
log.Trace().
Str("func", "requestUpdate").
Str("machine", m.Name).
Msgf("Notified machine %s", m.Name)
}
} else {
log.Info().
Str("func", "requestUpdate").
Str("machine", m.Name).
Msgf("Machine %s does not appear to be polling", m.Name)
return errors.New("machine does not seem to be polling")
}
return nil
}
func (h *Headscale) isOutdated(m *Machine) bool {
err := h.UpdateMachine(m)
if err != nil {
return true
}
lastChange := h.getLastStateChange(m.Namespace.Name)
log.Trace().
Str("func", "keepAlive").
Str("machine", m.Name).
Time("last_successful_update", *m.LastSuccessfulUpdate).
Time("last_state_change", lastChange).
Msgf("Checking if %s is missing updates", m.Name)
return m.LastSuccessfulUpdate.Before(lastChange)
}

poll.go (404 lines changed)

@@ -1,38 +1,225 @@
package headscale
import (
"encoding/json"
"errors"
"io"
"net/http"
"time"
"github.com/gin-gonic/gin"
"github.com/rs/zerolog/log"
"gorm.io/datatypes"
"gorm.io/gorm"
"tailscale.com/tailcfg"
"tailscale.com/types/wgkey"
)
// PollNetMapHandler takes care of /machine/:id/map
//
// This is the busiest endpoint, as it keeps the HTTP long poll that updates
// the clients when something in the network changes.
//
// The clients POST stuff like HostInfo and their Endpoints here, but
// only after their first request (marked with the ReadOnly field).
//
// At this moment the updates are sent in a quite horrendous way, but they kinda work.
func (h *Headscale) PollNetMapHandler(c *gin.Context) {
log.Trace().
Str("handler", "PollNetMap").
Str("id", c.Param("id")).
Msg("PollNetMapHandler called")
body, _ := io.ReadAll(c.Request.Body)
mKeyStr := c.Param("id")
mKey, err := wgkey.ParseHex(mKeyStr)
if err != nil {
log.Error().
Str("handler", "PollNetMap").
Err(err).
Msg("Cannot parse client key")
c.String(http.StatusBadRequest, "")
return
}
req := tailcfg.MapRequest{}
err = decode(body, &req, &mKey, h.privateKey)
if err != nil {
log.Error().
Str("handler", "PollNetMap").
Err(err).
Msg("Cannot decode message")
c.String(http.StatusBadRequest, "")
return
}
var m Machine
if result := h.db.Preload("Namespace").First(&m, "machine_key = ?", mKey.HexString()); errors.Is(result.Error, gorm.ErrRecordNotFound) {
log.Warn().
Str("handler", "PollNetMap").
Msgf("Ignoring request, cannot find machine with key %s", mKey.HexString())
c.String(http.StatusUnauthorized, "")
return
}
log.Trace().
Str("handler", "PollNetMap").
Str("id", c.Param("id")).
Str("machine", m.Name).
Msg("Found machine in database")
hostinfo, _ := json.Marshal(req.Hostinfo)
m.Name = req.Hostinfo.Hostname
m.HostInfo = datatypes.JSON(hostinfo)
m.DiscoKey = wgkey.Key(req.DiscoKey).HexString()
now := time.Now().UTC()
// From Tailscale client:
//
// ReadOnly is whether the client just wants to fetch the MapResponse,
// without updating their Endpoints. The Endpoints field will be ignored and
// LastSeen will not be updated and peers will not be notified of changes.
//
// The intended use is for clients to discover the DERP map at start-up
// before their first real endpoint update.
if !req.ReadOnly {
endpoints, _ := json.Marshal(req.Endpoints)
m.Endpoints = datatypes.JSON(endpoints)
m.LastSeen = &now
}
h.db.Save(&m)
data, err := h.getMapResponse(mKey, req, m)
if err != nil {
log.Error().
Str("handler", "PollNetMap").
Str("id", c.Param("id")).
Str("machine", m.Name).
Err(err).
Msg("Failed to get Map response")
c.String(http.StatusInternalServerError, ":(")
return
}
// We update our peers if the client is not sending ReadOnly in the MapRequest
// so we don't distribute its initial request (it comes with
// empty endpoints to peers)
// Details on the protocol can be found in https://github.com/tailscale/tailscale/blob/main/tailcfg/tailcfg.go#L696
log.Debug().
Str("handler", "PollNetMap").
Str("id", c.Param("id")).
Str("machine", m.Name).
Bool("readOnly", req.ReadOnly).
Bool("omitPeers", req.OmitPeers).
Bool("stream", req.Stream).
Msg("Client map request processed")
if req.ReadOnly {
log.Info().
Str("handler", "PollNetMap").
Str("machine", m.Name).
Msg("Client is starting up. Probably interested in a DERP map")
c.Data(200, "application/json; charset=utf-8", *data)
return
}
// There has been an update to _any_ of the nodes that the other nodes would
// need to know about
h.setLastStateChangeToNow(m.Namespace.Name)
// The request is not ReadOnly, so we need to set up channels for updating
// peers via longpoll
// Only create update channel if it has not been created
log.Trace().
Str("handler", "PollNetMap").
Str("id", c.Param("id")).
Str("machine", m.Name).
Msg("Loading or creating update channel")
updateChan := h.getOrOpenUpdateChannel(&m)
pollDataChan := make(chan []byte)
// defer close(pollData)
keepAliveChan := make(chan []byte)
cancelKeepAlive := make(chan struct{})
defer close(cancelKeepAlive)
if req.OmitPeers && !req.Stream {
log.Info().
Str("handler", "PollNetMap").
Str("machine", m.Name).
Msg("Client sent endpoint update and is ok with a response without peer list")
c.Data(200, "application/json; charset=utf-8", *data)
// It sounds like we should update the nodes when we have received a endpoint update
// even tho the comments in the tailscale code dont explicitly say so.
go h.notifyChangesToPeers(&m)
return
} else if req.OmitPeers && req.Stream {
log.Warn().
Str("handler", "PollNetMap").
Str("machine", m.Name).
Msg("Ignoring request, don't know how to handle it")
c.String(http.StatusBadRequest, "")
return
}
log.Info().
Str("handler", "PollNetMap").
Str("machine", m.Name).
Msg("Client is ready to access the tailnet")
log.Info().
Str("handler", "PollNetMap").
Str("machine", m.Name).
Msg("Sending initial map")
go func() { pollDataChan <- *data }()
log.Info().
Str("handler", "PollNetMap").
Str("machine", m.Name).
Msg("Notifying peers")
go h.notifyChangesToPeers(&m)
h.PollNetMapStream(c, m, req, mKey, pollDataChan, keepAliveChan, updateChan, cancelKeepAlive)
log.Trace().
Str("handler", "PollNetMap").
Str("id", c.Param("id")).
Str("machine", m.Name).
Msg("Finished stream, closing PollNetMap session")
}
// PollNetMapStream takes care of /machine/:id/map
// stream logic, ensuring we communicate updates and data
// to the connected clients.
func (h *Headscale) PollNetMapStream(
c *gin.Context,
m Machine,
req tailcfg.MapRequest,
mKey wgkey.Key,
pollData chan []byte,
update chan []byte,
cancelKeepAlive chan []byte,
pollDataChan chan []byte,
keepAliveChan chan []byte,
updateChan <-chan struct{},
cancelKeepAlive chan struct{},
) {
go h.keepAlive(cancelKeepAlive, pollData, mKey, req, m)
go h.scheduledPollWorker(cancelKeepAlive, keepAliveChan, mKey, req, m)
c.Stream(func(w io.Writer) bool {
log.Trace().
Str("handler", "PollNetMapStream").
Str("machine", m.Name).
Msg("Waiting for data to stream...")
select {
case data := <-pollData:
log.Trace().
Str("handler", "PollNetMapStream").
Str("machine", m.Name).
Msgf("pollData is %#v, keepAliveChan is %#v, updateChan is %#v", pollDataChan, keepAliveChan, updateChan)
select {
case data := <-pollDataChan:
log.Trace().
Str("handler", "PollNetMapStream").
Str("machine", m.Name).
Str("channel", "pollData").
Int("bytes", len(data)).
Msg("Sending data received via pollData channel")
_, err := w.Write(data)
@@ -40,44 +227,148 @@ func (h *Headscale) PollNetMapStream(
log.Error().
Str("handler", "PollNetMapStream").
Str("machine", m.Name).
Str("channel", "pollData").
Err(err).
Msg("Cannot write data")
}
log.Trace().
Str("handler", "PollNetMapStream").
Str("machine", m.Name).
Str("channel", "pollData").
Int("bytes", len(data)).
Msg("Data from pollData channel written successfully")
// TODO: Abstract away all the database calls, this can cause race conditions
// when an outdated machine object is kept alive, e.g. db is update from
// command line, but then overwritten.
err = h.UpdateMachine(&m)
if err != nil {
log.Error().
Str("handler", "PollNetMapStream").
Str("machine", m.Name).
Str("channel", "pollData").
Err(err).
Msg("Cannot update machine from database")
}
now := time.Now().UTC()
m.LastSeen = &now
m.LastSuccessfulUpdate = &now
h.db.Save(&m)
log.Trace().
Str("handler", "PollNetMapStream").
Str("machine", m.Name).
Str("channel", "pollData").
Int("bytes", len(data)).
Msg("Machine updated successfully after sending pollData")
return true
case data := <-keepAliveChan:
log.Trace().
Str("handler", "PollNetMapStream").
Str("machine", m.Name).
Str("channel", "keepAlive").
Int("bytes", len(data)).
Msg("Sending keep alive message")
_, err := w.Write(data)
if err != nil {
log.Error().
Str("handler", "PollNetMapStream").
Str("machine", m.Name).
Str("channel", "keepAlive").
Err(err).
Msg("Cannot write keep alive message")
}
log.Trace().
Str("handler", "PollNetMapStream").
Str("machine", m.Name).
Str("channel", "keepAlive").
Int("bytes", len(data)).
Msg("Keep alive sent successfully")
// TODO: Abstract away all the database calls, this can cause race conditions
// when an outdated machine object is kept alive, e.g. db is update from
// command line, but then overwritten.
err = h.UpdateMachine(&m)
if err != nil {
log.Error().
Str("handler", "PollNetMapStream").
Str("machine", m.Name).
Str("channel", "keepAlive").
Err(err).
Msg("Cannot update machine from database")
}
now := time.Now().UTC()
m.LastSeen = &now
h.db.Save(&m)
log.Trace().
Str("handler", "PollNetMapStream").
Str("machine", m.Name).
Str("channel", "keepAlive").
Int("bytes", len(data)).
Msg("Machine updated successfully after sending pollData")
Msg("Machine updated successfully after sending keep alive")
return true
case <-update:
log.Debug().
case <-updateChan:
log.Trace().
Str("handler", "PollNetMapStream").
Str("machine", m.Name).
Str("channel", "update").
Msg("Received a request for update")
data, err := h.getMapResponse(mKey, req, m)
if err != nil {
log.Error().
if h.isOutdated(&m) {
log.Debug().
Str("handler", "PollNetMapStream").
Str("machine", m.Name).
Err(err).
Msg("Could not get the map update")
}
_, err = w.Write(*data)
if err != nil {
log.Error().
Time("last_successful_update", *m.LastSuccessfulUpdate).
Time("last_state_change", h.getLastStateChange(m.Namespace.Name)).
Msgf("There has been updates since the last successful update to %s", m.Name)
data, err := h.getMapResponse(mKey, req, m)
if err != nil {
log.Error().
Str("handler", "PollNetMapStream").
Str("machine", m.Name).
Str("channel", "update").
Err(err).
Msg("Could not get the map update")
}
_, err = w.Write(*data)
if err != nil {
log.Error().
Str("handler", "PollNetMapStream").
Str("machine", m.Name).
Str("channel", "update").
Err(err).
Msg("Could not write the map response")
}
log.Trace().
Str("handler", "PollNetMapStream").
Str("machine", m.Name).
Err(err).
Msg("Could not write the map response")
Str("channel", "update").
Msg("Updated Map has been sent")
// Keep track of the last successful update,
// we sometimes end in a state were the update
// is not picked up by a client and we use this
// to determine if we should "force" an update.
// TODO: Abstract away all the database calls, this can cause race conditions
// when an outdated machine object is kept alive, e.g. db is update from
// command line, but then overwritten.
err = h.UpdateMachine(&m)
if err != nil {
log.Error().
Str("handler", "PollNetMapStream").
Str("machine", m.Name).
Str("channel", "update").
Err(err).
Msg("Cannot update machine from database")
}
now := time.Now().UTC()
m.LastSuccessfulUpdate = &now
h.db.Save(&m)
} else {
log.Trace().
Str("handler", "PollNetMapStream").
Str("machine", m.Name).
Time("last_successful_update", *m.LastSuccessfulUpdate).
Time("last_state_change", h.getLastStateChange(m.Namespace.Name)).
Msgf("%s is up to date", m.Name)
}
return true
@@ -86,13 +377,78 @@ func (h *Headscale) PollNetMapStream(
Str("handler", "PollNetMapStream").
Str("machine", m.Name).
Msg("The client has closed the connection")
// TODO: Abstract away all the database calls, this can cause race conditions
// when an outdated machine object is kept alive, e.g. db is update from
// command line, but then overwritten.
err := h.UpdateMachine(&m)
if err != nil {
log.Error().
Str("handler", "PollNetMapStream").
Str("machine", m.Name).
Str("channel", "Done").
Err(err).
Msg("Cannot update machine from database")
}
now := time.Now().UTC()
m.LastSeen = &now
h.db.Save(&m)
cancelKeepAlive <- []byte{}
h.clientsPolling.Delete(m.ID)
close(update)
cancelKeepAlive <- struct{}{}
h.closeUpdateChannel(&m)
close(pollDataChan)
close(keepAliveChan)
return false
}
})
}
func (h *Headscale) scheduledPollWorker(
cancelChan <-chan struct{},
keepAliveChan chan<- []byte,
mKey wgkey.Key,
req tailcfg.MapRequest,
m Machine,
) {
keepAliveTicker := time.NewTicker(60 * time.Second)
updateCheckerTicker := time.NewTicker(30 * time.Second)
for {
select {
case <-cancelChan:
return
case <-keepAliveTicker.C:
data, err := h.getMapKeepAliveResponse(mKey, req, m)
if err != nil {
log.Error().
Str("func", "keepAlive").
Err(err).
Msg("Error generating the keep alive msg")
return
}
log.Debug().
Str("func", "keepAlive").
Str("machine", m.Name).
Msg("Sending keepalive")
keepAliveChan <- *data
case <-updateCheckerTicker.C:
// Send an update request regardless of outdated or not, if data is sent
// to the node is determined in the updateChan consumer block
n, _ := m.toNode()
err := h.sendRequestOnUpdateChannel(n)
if err != nil {
log.Error().
Str("func", "keepAlive").
Str("machine", m.Name).
Err(err).
Msgf("Failed to send update request to %s", m.Name)
}
}
}
}

routes.go (129 lines changed)

@@ -2,55 +2,140 @@ package headscale
import (
"encoding/json"
"errors"
"fmt"
"strconv"
"github.com/pterm/pterm"
"gorm.io/datatypes"
"inet.af/netaddr"
)
// GetNodeRoutes returns the subnet routes advertised by a node (identified by
// GetAdvertisedNodeRoutes returns the subnet routes advertised by a node (identified by
// namespace and node name)
func (h *Headscale) GetNodeRoutes(namespace string, nodeName string) (*[]netaddr.IPPrefix, error) {
func (h *Headscale) GetAdvertisedNodeRoutes(namespace string, nodeName string) (*[]netaddr.IPPrefix, error) {
m, err := h.GetMachine(namespace, nodeName)
if err != nil {
return nil, err
}
hi, err := m.GetHostInfo()
hostInfo, err := m.GetHostInfo()
if err != nil {
return nil, err
}
return &hi.RoutableIPs, nil
return &hostInfo.RoutableIPs, nil
}
// GetEnabledNodeRoutes returns the subnet routes enabled by a node (identified by
// namespace and node name)
func (h *Headscale) GetEnabledNodeRoutes(namespace string, nodeName string) ([]netaddr.IPPrefix, error) {
m, err := h.GetMachine(namespace, nodeName)
if err != nil {
return nil, err
}
data, err := m.EnabledRoutes.MarshalJSON()
if err != nil {
return nil, err
}
routesStr := []string{}
err = json.Unmarshal(data, &routesStr)
if err != nil {
return nil, err
}
routes := make([]netaddr.IPPrefix, len(routesStr))
for index, routeStr := range routesStr {
route, err := netaddr.ParseIPPrefix(routeStr)
if err != nil {
return nil, err
}
routes[index] = route
}
return routes, nil
}
func (h *Headscale) IsNodeRouteEnabled(namespace string, nodeName string, routeStr string) bool {
route, err := netaddr.ParseIPPrefix(routeStr)
if err != nil {
return false
}
enabledRoutes, err := h.GetEnabledNodeRoutes(namespace, nodeName)
if err != nil {
return false
}
for _, enabledRoute := range enabledRoutes {
if route == enabledRoute {
return true
}
}
return false
}
// EnableNodeRoute enables a subnet route advertised by a node (identified by
// namespace and node name)
func (h *Headscale) EnableNodeRoute(namespace string, nodeName string, routeStr string) (*netaddr.IPPrefix, error) {
func (h *Headscale) EnableNodeRoute(namespace string, nodeName string, routeStr string) error {
m, err := h.GetMachine(namespace, nodeName)
if err != nil {
return nil, err
}
hi, err := m.GetHostInfo()
if err != nil {
return nil, err
return err
}
route, err := netaddr.ParseIPPrefix(routeStr)
if err != nil {
return nil, err
return err
}
for _, rIP := range hi.RoutableIPs {
if rIP == route {
routes, _ := json.Marshal([]string{routeStr}) // TODO: only one for the time being, so overwriting the rest
m.EnabledRoutes = datatypes.JSON(routes)
h.db.Save(&m)
availableRoutes, err := h.GetAdvertisedNodeRoutes(namespace, nodeName)
if err != nil {
return err
}
err = h.RequestMapUpdates(m.NamespaceID)
if err != nil {
return nil, err
enabledRoutes, err := h.GetEnabledNodeRoutes(namespace, nodeName)
if err != nil {
return err
}
available := false
for _, availableRoute := range *availableRoutes {
// If the route is available, and not yet enabled, add it to the new routing table
if route == availableRoute {
available = true
if !h.IsNodeRouteEnabled(namespace, nodeName, routeStr) {
enabledRoutes = append(enabledRoutes, route)
}
return &rIP, nil
}
}
return nil, errors.New("could not find routable range")
if !available {
return fmt.Errorf("route (%s) is not available on node %s", nodeName, routeStr)
}
routes, err := json.Marshal(enabledRoutes)
if err != nil {
return err
}
m.EnabledRoutes = datatypes.JSON(routes)
h.db.Save(&m)
err = h.RequestMapUpdates(m.NamespaceID)
if err != nil {
return err
}
return nil
}
func (h *Headscale) RoutesToPtables(namespace string, nodeName string, availableRoutes []netaddr.IPPrefix) pterm.TableData {
d := pterm.TableData{{"Route", "Enabled"}}
for _, route := range availableRoutes {
enabled := h.IsNodeRouteEnabled(namespace, nodeName, route.String())
d = append(d, []string{route.String(), strconv.FormatBool(enabled)})
}
return d
}


@@ -16,7 +16,7 @@ func (s *Suite) TestGetRoutes(c *check.C) {
pak, err := h.CreatePreAuthKey(n.Name, false, false, nil)
c.Assert(err, check.IsNil)
_, err = h.GetMachine("test", "testmachine")
_, err = h.GetMachine("test", "test_get_route_machine")
c.Assert(err, check.NotNil)
route, err := netaddr.ParseIPPrefix("10.0.0.0/24")
@@ -33,7 +33,7 @@ func (s *Suite) TestGetRoutes(c *check.C) {
MachineKey: "foo",
NodeKey: "bar",
DiscoKey: "faa",
Name: "testmachine",
Name: "test_get_route_machine",
NamespaceID: n.ID,
Registered: true,
RegisterMethod: "authKey",
@@ -42,14 +42,87 @@ func (s *Suite) TestGetRoutes(c *check.C) {
}
h.db.Save(&m)
r, err := h.GetNodeRoutes("test", "testmachine")
r, err := h.GetAdvertisedNodeRoutes("test", "test_get_route_machine")
c.Assert(err, check.IsNil)
c.Assert(len(*r), check.Equals, 1)
_, err = h.EnableNodeRoute("test", "testmachine", "192.168.0.0/24")
err = h.EnableNodeRoute("test", "test_get_route_machine", "192.168.0.0/24")
c.Assert(err, check.NotNil)
_, err = h.EnableNodeRoute("test", "testmachine", "10.0.0.0/24")
err = h.EnableNodeRoute("test", "test_get_route_machine", "10.0.0.0/24")
c.Assert(err, check.IsNil)
}
func (s *Suite) TestGetEnableRoutes(c *check.C) {
n, err := h.CreateNamespace("test")
c.Assert(err, check.IsNil)
pak, err := h.CreatePreAuthKey(n.Name, false, false, nil)
c.Assert(err, check.IsNil)
_, err = h.GetMachine("test", "test_enable_route_machine")
c.Assert(err, check.NotNil)
route, err := netaddr.ParseIPPrefix(
"10.0.0.0/24",
)
c.Assert(err, check.IsNil)
route2, err := netaddr.ParseIPPrefix(
"150.0.10.0/25",
)
c.Assert(err, check.IsNil)
hi := tailcfg.Hostinfo{
RoutableIPs: []netaddr.IPPrefix{route, route2},
}
hostinfo, err := json.Marshal(hi)
c.Assert(err, check.IsNil)
m := Machine{
ID: 0,
MachineKey: "foo",
NodeKey: "bar",
DiscoKey: "faa",
Name: "test_enable_route_machine",
NamespaceID: n.ID,
Registered: true,
RegisterMethod: "authKey",
AuthKeyID: uint(pak.ID),
HostInfo: datatypes.JSON(hostinfo),
}
h.db.Save(&m)
availableRoutes, err := h.GetAdvertisedNodeRoutes("test", "test_enable_route_machine")
c.Assert(err, check.IsNil)
c.Assert(len(*availableRoutes), check.Equals, 2)
enabledRoutes, err := h.GetEnabledNodeRoutes("test", "test_enable_route_machine")
c.Assert(err, check.IsNil)
c.Assert(len(enabledRoutes), check.Equals, 0)
err = h.EnableNodeRoute("test", "test_enable_route_machine", "192.168.0.0/24")
c.Assert(err, check.NotNil)
err = h.EnableNodeRoute("test", "test_enable_route_machine", "10.0.0.0/24")
c.Assert(err, check.IsNil)
enabledRoutes1, err := h.GetEnabledNodeRoutes("test", "test_enable_route_machine")
c.Assert(err, check.IsNil)
c.Assert(len(enabledRoutes1), check.Equals, 1)
// Adding it twice will just let it pass through
err = h.EnableNodeRoute("test", "test_enable_route_machine", "10.0.0.0/24")
c.Assert(err, check.IsNil)
enabledRoutes2, err := h.GetEnabledNodeRoutes("test", "test_enable_route_machine")
c.Assert(err, check.IsNil)
c.Assert(len(enabledRoutes2), check.Equals, 1)
err = h.EnableNodeRoute("test", "test_enable_route_machine", "150.0.10.0/25")
c.Assert(err, check.IsNil)
enabledRoutes3, err := h.GetEnabledNodeRoutes("test", "test_enable_route_machine")
c.Assert(err, check.IsNil)
c.Assert(len(enabledRoutes3), check.Equals, 2)
}


@@ -1,4 +1,4 @@
#!/bin/bash
#!/usr/bin/env bash
set -e -o pipefail
commit="$1"