Compare commits

..

28 Commits

Author SHA1 Message Date
Juan Font
8a614dabc0 Headscale is from no-juan 2021-08-06 00:23:07 +02:00
Juan Font
c95cf15731 Fixed log message 2021-08-06 00:21:34 +02:00
Juan Font
e7ce902f9d Merge pull request #75 from kradalby/syncmap
Fix deadlock issue
2021-08-06 00:19:34 +02:00
Juan Font
d421c7b665 Merge pull request #74 from kradalby/deadlock-logging
Switch to a structured logger
2021-08-06 00:18:40 +02:00
Kristoffer Dalby
1abc68ccf4 Removes locks causing deadlock
This commit removes most of the locks in the PollingMap handler as there
was combinations that caused deadlocks. Instead of doing a plain map and
doing the locking ourselves, we use sync.Map which handles it for us.
2021-08-05 22:14:37 +01:00
Kristoffer Dalby
575b15e5fa Add more trace logging 2021-08-05 21:47:06 +01:00
Kristoffer Dalby
a8c8a358d0 Make log keys lowercase 2021-08-05 20:57:47 +01:00
Kristoffer Dalby
cd2ca137c0 Make log_level user configurable 2021-08-05 19:19:25 +01:00
Kristoffer Dalby
0660867a16 Correct url 2021-08-05 18:58:15 +01:00
Kristoffer Dalby
b1200140b8 Convert cli/utils.go 2021-08-05 18:26:49 +01:00
Kristoffer Dalby
d10b57b317 Convert namespaces.go 2021-08-05 18:23:02 +01:00
Kristoffer Dalby
42bf566fff Convert acls.go 2021-08-05 18:18:18 +01:00
Kristoffer Dalby
0bb2fabc6c Convert missing from api.go 2021-08-05 18:16:21 +01:00
Kristoffer Dalby
ee704f8ef3 Initial port to zerologger 2021-08-05 18:11:26 +01:00
Juan Font
4aad3b7933 Improved README.md on ip_prefix 2021-08-03 20:38:23 +02:00
Juan Font
6091373b53 Merge pull request #63 from juanfont/use-kv-for-updates
Added communication between Serve and CLI using KV table
2021-08-03 20:30:33 +02:00
Juan Font
3879120967 Merge pull request #72 from kradalby/ip-pool
Make IP Prefix configurable and available ip deterministic
2021-08-03 20:27:42 +02:00
Kristoffer Dalby
465669f650 Merge pull request #1 from kradalby/ip-pool-test
Fix empty ip issue and remove network/broadcast addresses
2021-08-03 10:12:09 +01:00
Kristoffer Dalby
ea615e3a26 Do not issue "network" or "broadcast" addresses (0 or 255) 2021-08-03 10:06:42 +01:00
Kristoffer Dalby
d3349aa4d1 Add test to ensure we can deal with empty ips from database 2021-08-03 09:26:28 +01:00
Kristoffer Dalby
73207decfd Check that IP is set before parsing
Machine is saved to db before it is assigned an ip, so we might have
empty ip fields coming back.
2021-08-03 07:42:11 +01:00
Kristoffer Dalby
eda6e560c3 debug logging 2021-08-02 22:51:50 +01:00
Kristoffer Dalby
95de823b72 Add test to ensure we can read back ips 2021-08-02 22:39:18 +01:00
Kristoffer Dalby
9f85efffd5 Update readme 2021-08-02 22:06:15 +01:00
Kristoffer Dalby
b5841c8a8b Rework getAvailableIp
This commit reworks getAvailableIp with a "simpler" version that will
look for the first available IP address in our IP Prefix.

There is a couple of ideas behind this:

* Make the host IPs reasonably predictable and in within similar
  subnets, which should simplify ACLs for subnets
* The code is not random, but deterministic so we can have tests
* The code is a bit more understandable (no bit shift magic)
2021-08-02 21:57:45 +01:00
Kristoffer Dalby
309f868a21 Make IP prefix configurable
This commit makes the IP prefix used to generate addresses configurable
to users. This can be useful if you would like to use a smaller range or
if your current setup is overlapping with the current range.

The current range is left as a default
2021-08-02 20:06:26 +01:00
Juan Font Alonso
461a893ee4 Added log message when sending updates 2021-07-25 20:47:51 +02:00
Juan Font Alonso
97f7c90092 Added communication between Serve and CLI using KV table (helps in #52) 2021-07-25 17:59:48 +02:00
19 changed files with 682 additions and 134 deletions

View File

@@ -24,6 +24,7 @@ Headscale implements this coordination server.
- [x] Node registration via pre-auth keys (including reusable keys, and ephemeral node support) - [x] Node registration via pre-auth keys (including reusable keys, and ephemeral node support)
- [X] JSON-formatted output - [X] JSON-formatted output
- [X] ACLs - [X] ACLs
- [X] Support for alternative IP ranges in the tailnets (default Tailscale's 100.64.0.0/10)
- [ ] Share nodes between ~~users~~ namespaces - [ ] Share nodes between ~~users~~ namespaces
- [ ] DNS - [ ] DNS
@@ -113,9 +114,15 @@ Headscale's configuration file is named `config.json` or `config.yaml`. Headscal
``` ```
"server_url": "http://192.168.1.12:8080", "server_url": "http://192.168.1.12:8080",
"listen_addr": "0.0.0.0:8080", "listen_addr": "0.0.0.0:8080",
"ip_prefix": "100.64.0.0/10"
``` ```
`server_url` is the external URL via which Headscale is reachable. `listen_addr` is the IP address and port the Headscale program should listen on. `server_url` is the external URL via which Headscale is reachable. `listen_addr` is the IP address and port the Headscale program should listen on. `ip_prefix` is the IP prefix (range) in which IP addresses for nodes will be allocated (default 100.64.0.0/10, e.g., 192.168.4.0/24, 10.0.0.0/8)
```
"log_level": "debug"
```
`log_level` can be used to set the Log level for Headscale, it defaults to `debug`, and the available levels are: `trace`, `debug`, `info`, `warn` and `error`.
``` ```
"private_key_path": "private.key", "private_key_path": "private.key",

View File

@@ -4,11 +4,12 @@ import (
"encoding/json" "encoding/json"
"fmt" "fmt"
"io" "io"
"log"
"os" "os"
"strconv" "strconv"
"strings" "strings"
"github.com/rs/zerolog/log"
"github.com/tailscale/hujson" "github.com/tailscale/hujson"
"inet.af/netaddr" "inet.af/netaddr"
"tailscale.com/tailcfg" "tailscale.com/tailcfg"
@@ -66,7 +67,8 @@ func (h *Headscale) generateACLRules() (*[]tailcfg.FilterRule, error) {
for j, u := range a.Users { for j, u := range a.Users {
srcs, err := h.generateACLPolicySrcIP(u) srcs, err := h.generateACLPolicySrcIP(u)
if err != nil { if err != nil {
log.Printf("Error parsing ACL %d, User %d", i, j) log.Error().
Msgf("Error parsing ACL %d, User %d", i, j)
return nil, err return nil, err
} }
srcIPs = append(srcIPs, *srcs...) srcIPs = append(srcIPs, *srcs...)
@@ -77,7 +79,8 @@ func (h *Headscale) generateACLRules() (*[]tailcfg.FilterRule, error) {
for j, d := range a.Ports { for j, d := range a.Ports {
dests, err := h.generateACLPolicyDestPorts(d) dests, err := h.generateACLPolicyDestPorts(d)
if err != nil { if err != nil {
log.Printf("Error parsing ACL %d, Port %d", i, j) log.Error().
Msgf("Error parsing ACL %d, Port %d", i, j)
return nil, err return nil, err
} }
destPorts = append(destPorts, *dests...) destPorts = append(destPorts, *dests...)

284
api.go
View File

@@ -6,10 +6,11 @@ import (
"errors" "errors"
"fmt" "fmt"
"io" "io"
"log"
"net/http" "net/http"
"time" "time"
"github.com/rs/zerolog/log"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/klauspost/compress/zstd" "github.com/klauspost/compress/zstd"
"gorm.io/datatypes" "gorm.io/datatypes"
@@ -63,21 +64,27 @@ func (h *Headscale) RegistrationHandler(c *gin.Context) {
mKeyStr := c.Param("id") mKeyStr := c.Param("id")
mKey, err := wgkey.ParseHex(mKeyStr) mKey, err := wgkey.ParseHex(mKeyStr)
if err != nil { if err != nil {
log.Printf("Cannot parse machine key: %s", err) log.Error().
Str("handler", "Registration").
Err(err).
Msg("Cannot parse machine key")
c.String(http.StatusInternalServerError, "Sad!") c.String(http.StatusInternalServerError, "Sad!")
return return
} }
req := tailcfg.RegisterRequest{} req := tailcfg.RegisterRequest{}
err = decode(body, &req, &mKey, h.privateKey) err = decode(body, &req, &mKey, h.privateKey)
if err != nil { if err != nil {
log.Printf("Cannot decode message: %s", err) log.Error().
Str("handler", "Registration").
Err(err).
Msg("Cannot decode message")
c.String(http.StatusInternalServerError, "Very sad!") c.String(http.StatusInternalServerError, "Very sad!")
return return
} }
var m Machine var m Machine
if result := h.db.Preload("Namespace").First(&m, "machine_key = ?", mKey.HexString()); errors.Is(result.Error, gorm.ErrRecordNotFound) { if result := h.db.Preload("Namespace").First(&m, "machine_key = ?", mKey.HexString()); errors.Is(result.Error, gorm.ErrRecordNotFound) {
log.Println("New Machine!") log.Info().Str("machine", req.Hostinfo.Hostname).Msg("New machine")
m = Machine{ m = Machine{
Expiry: &req.Expiry, Expiry: &req.Expiry,
MachineKey: mKey.HexString(), MachineKey: mKey.HexString(),
@@ -85,7 +92,10 @@ func (h *Headscale) RegistrationHandler(c *gin.Context) {
NodeKey: wgkey.Key(req.NodeKey).HexString(), NodeKey: wgkey.Key(req.NodeKey).HexString(),
} }
if err := h.db.Create(&m).Error; err != nil { if err := h.db.Create(&m).Error; err != nil {
log.Printf("Could not create row: %s", err) log.Error().
Str("handler", "Registration").
Err(err).
Msg("Could not create row")
return return
} }
} }
@@ -100,13 +110,20 @@ func (h *Headscale) RegistrationHandler(c *gin.Context) {
// We have the updated key! // We have the updated key!
if m.NodeKey == wgkey.Key(req.NodeKey).HexString() { if m.NodeKey == wgkey.Key(req.NodeKey).HexString() {
if m.Registered { if m.Registered {
log.Printf("[%s] Client is registered and we have the current NodeKey. All clear to /map", m.Name) log.Debug().
Str("handler", "Registration").
Str("machine", m.Name).
Msg("Client is registered and we have the current NodeKey. All clear to /map")
resp.AuthURL = "" resp.AuthURL = ""
resp.MachineAuthorized = true resp.MachineAuthorized = true
resp.User = *m.Namespace.toUser() resp.User = *m.Namespace.toUser()
respBody, err := encode(resp, &mKey, h.privateKey) respBody, err := encode(resp, &mKey, h.privateKey)
if err != nil { if err != nil {
log.Printf("Cannot encode message: %s", err) log.Error().
Str("handler", "Registration").
Err(err).
Msg("Cannot encode message")
c.String(http.StatusInternalServerError, "") c.String(http.StatusInternalServerError, "")
return return
} }
@@ -114,12 +131,18 @@ func (h *Headscale) RegistrationHandler(c *gin.Context) {
return return
} }
log.Printf("[%s] Not registered and not NodeKey rotation. Sending a authurl to register", m.Name) log.Debug().
Str("handler", "Registration").
Str("machine", m.Name).
Msg("Not registered and not NodeKey rotation. Sending a authurl to register")
resp.AuthURL = fmt.Sprintf("%s/register?key=%s", resp.AuthURL = fmt.Sprintf("%s/register?key=%s",
h.cfg.ServerURL, mKey.HexString()) h.cfg.ServerURL, mKey.HexString())
respBody, err := encode(resp, &mKey, h.privateKey) respBody, err := encode(resp, &mKey, h.privateKey)
if err != nil { if err != nil {
log.Printf("Cannot encode message: %s", err) log.Error().
Str("handler", "Registration").
Err(err).
Msg("Cannot encode message")
c.String(http.StatusInternalServerError, "") c.String(http.StatusInternalServerError, "")
return return
} }
@@ -129,7 +152,10 @@ func (h *Headscale) RegistrationHandler(c *gin.Context) {
// The NodeKey we have matches OldNodeKey, which means this is a refresh after an key expiration // The NodeKey we have matches OldNodeKey, which means this is a refresh after an key expiration
if m.NodeKey == wgkey.Key(req.OldNodeKey).HexString() { if m.NodeKey == wgkey.Key(req.OldNodeKey).HexString() {
log.Printf("[%s] We have the OldNodeKey in the database. This is a key refresh", m.Name) log.Debug().
Str("handler", "Registration").
Str("machine", m.Name).
Msg("We have the OldNodeKey in the database. This is a key refresh")
m.NodeKey = wgkey.Key(req.NodeKey).HexString() m.NodeKey = wgkey.Key(req.NodeKey).HexString()
h.db.Save(&m) h.db.Save(&m)
@@ -137,7 +163,10 @@ func (h *Headscale) RegistrationHandler(c *gin.Context) {
resp.User = *m.Namespace.toUser() resp.User = *m.Namespace.toUser()
respBody, err := encode(resp, &mKey, h.privateKey) respBody, err := encode(resp, &mKey, h.privateKey)
if err != nil { if err != nil {
log.Printf("Cannot encode message: %s", err) log.Error().
Str("handler", "Registration").
Err(err).
Msg("Cannot encode message")
c.String(http.StatusInternalServerError, "Extremely sad!") c.String(http.StatusInternalServerError, "Extremely sad!")
return return
} }
@@ -148,25 +177,38 @@ func (h *Headscale) RegistrationHandler(c *gin.Context) {
// We arrive here after a client is restarted without finalizing the authentication flow or // We arrive here after a client is restarted without finalizing the authentication flow or
// when headscale is stopped in the middle of the auth process. // when headscale is stopped in the middle of the auth process.
if m.Registered { if m.Registered {
log.Printf("[%s] The node is sending us a new NodeKey, but machine is registered. All clear for /map", m.Name) log.Debug().
Str("handler", "Registration").
Str("machine", m.Name).
Msg("The node is sending us a new NodeKey, but machine is registered. All clear for /map")
resp.AuthURL = "" resp.AuthURL = ""
resp.MachineAuthorized = true resp.MachineAuthorized = true
resp.User = *m.Namespace.toUser() resp.User = *m.Namespace.toUser()
respBody, err := encode(resp, &mKey, h.privateKey) respBody, err := encode(resp, &mKey, h.privateKey)
if err != nil { if err != nil {
log.Printf("Cannot encode message: %s", err) log.Error().
Str("handler", "Registration").
Err(err).
Msg("Cannot encode message")
c.String(http.StatusInternalServerError, "") c.String(http.StatusInternalServerError, "")
return return
} }
c.Data(200, "application/json; charset=utf-8", respBody) c.Data(200, "application/json; charset=utf-8", respBody)
return return
} }
log.Printf("[%s] The node is sending us a new NodeKey, sending auth url", m.Name)
log.Debug().
Str("handler", "Registration").
Str("machine", m.Name).
Msg("The node is sending us a new NodeKey, sending auth url")
resp.AuthURL = fmt.Sprintf("%s/register?key=%s", resp.AuthURL = fmt.Sprintf("%s/register?key=%s",
h.cfg.ServerURL, mKey.HexString()) h.cfg.ServerURL, mKey.HexString())
respBody, err := encode(resp, &mKey, h.privateKey) respBody, err := encode(resp, &mKey, h.privateKey)
if err != nil { if err != nil {
log.Printf("Cannot encode message: %s", err) log.Error().
Str("handler", "Registration").
Err(err).
Msg("Cannot encode message")
c.String(http.StatusInternalServerError, "") c.String(http.StatusInternalServerError, "")
return return
} }
@@ -183,28 +225,45 @@ func (h *Headscale) RegistrationHandler(c *gin.Context) {
// //
// At this moment the updates are sent in a quite horrendous way, but they kinda work. // At this moment the updates are sent in a quite horrendous way, but they kinda work.
func (h *Headscale) PollNetMapHandler(c *gin.Context) { func (h *Headscale) PollNetMapHandler(c *gin.Context) {
log.Trace().
Str("handler", "PollNetMap").
Str("id", c.Param("id")).
Msg("PollNetMapHandler called")
body, _ := io.ReadAll(c.Request.Body) body, _ := io.ReadAll(c.Request.Body)
mKeyStr := c.Param("id") mKeyStr := c.Param("id")
mKey, err := wgkey.ParseHex(mKeyStr) mKey, err := wgkey.ParseHex(mKeyStr)
if err != nil { if err != nil {
log.Printf("Cannot parse client key: %s", err) log.Error().
Str("handler", "PollNetMap").
Err(err).
Msg("Cannot parse client key")
c.String(http.StatusBadRequest, "") c.String(http.StatusBadRequest, "")
return return
} }
req := tailcfg.MapRequest{} req := tailcfg.MapRequest{}
err = decode(body, &req, &mKey, h.privateKey) err = decode(body, &req, &mKey, h.privateKey)
if err != nil { if err != nil {
log.Printf("Cannot decode message: %s", err) log.Error().
Str("handler", "PollNetMap").
Err(err).
Msg("Cannot decode message")
c.String(http.StatusBadRequest, "") c.String(http.StatusBadRequest, "")
return return
} }
var m Machine var m Machine
if result := h.db.Preload("Namespace").First(&m, "machine_key = ?", mKey.HexString()); errors.Is(result.Error, gorm.ErrRecordNotFound) { if result := h.db.Preload("Namespace").First(&m, "machine_key = ?", mKey.HexString()); errors.Is(result.Error, gorm.ErrRecordNotFound) {
log.Printf("Ignoring request, cannot find machine with key %s", mKey.HexString()) log.Warn().
Str("handler", "PollNetMap").
Msgf("Ignoring request, cannot find machine with key %s", mKey.HexString())
c.String(http.StatusUnauthorized, "") c.String(http.StatusUnauthorized, "")
return return
} }
log.Trace().
Str("handler", "PollNetMap").
Str("id", c.Param("id")).
Str("machine", m.Name).
Msg("Found machine in database")
hostinfo, _ := json.Marshal(req.Hostinfo) hostinfo, _ := json.Marshal(req.Hostinfo)
m.Name = req.Hostinfo.Hostname m.Name = req.Hostinfo.Hostname
@@ -227,17 +286,34 @@ func (h *Headscale) PollNetMapHandler(c *gin.Context) {
} }
h.db.Save(&m) h.db.Save(&m)
pollData := make(chan []byte, 1)
update := make(chan []byte, 1) update := make(chan []byte, 1)
cancelKeepAlive := make(chan []byte, 1)
pollData := make(chan []byte, 1)
defer close(pollData) defer close(pollData)
cancelKeepAlive := make(chan []byte, 1)
defer close(cancelKeepAlive) defer close(cancelKeepAlive)
h.pollMu.Lock()
h.clientsPolling[m.ID] = update log.Trace().
h.pollMu.Unlock() Str("handler", "PollNetMap").
Str("id", c.Param("id")).
Str("machine", m.Name).
Msg("Locking poll mutex")
h.clientsPolling.Store(m.ID, update)
log.Trace().
Str("handler", "PollNetMap").
Str("id", c.Param("id")).
Str("machine", m.Name).
Msg("Unlocking poll mutex")
data, err := h.getMapResponse(mKey, req, m) data, err := h.getMapResponse(mKey, req, m)
if err != nil { if err != nil {
log.Error().
Str("handler", "PollNetMap").
Str("id", c.Param("id")).
Str("machine", m.Name).
Err(err).
Msg("Failed to get Map response")
c.String(http.StatusInternalServerError, ":(") c.String(http.StatusInternalServerError, ":(")
return return
} }
@@ -247,50 +323,90 @@ func (h *Headscale) PollNetMapHandler(c *gin.Context) {
// empty endpoints to peers) // empty endpoints to peers)
// Details on the protocol can be found in https://github.com/tailscale/tailscale/blob/main/tailcfg/tailcfg.go#L696 // Details on the protocol can be found in https://github.com/tailscale/tailscale/blob/main/tailcfg/tailcfg.go#L696
log.Printf("[%s] ReadOnly=%t OmitPeers=%t Stream=%t", m.Name, req.ReadOnly, req.OmitPeers, req.Stream) log.Debug().
Str("handler", "PollNetMap").
Str("id", c.Param("id")).
Str("machine", m.Name).
Bool("readOnly", req.ReadOnly).
Bool("omitPeers", req.OmitPeers).
Bool("stream", req.Stream).
Msg("Client map request processed")
if req.ReadOnly { if req.ReadOnly {
log.Printf("[%s] Client is starting up. Asking for DERP map", m.Name) log.Info().
Str("handler", "PollNetMap").
Str("machine", m.Name).
Msg("Client is starting up. Asking for DERP map")
c.Data(200, "application/json; charset=utf-8", *data) c.Data(200, "application/json; charset=utf-8", *data)
return return
} }
if req.OmitPeers && !req.Stream { if req.OmitPeers && !req.Stream {
log.Printf("[%s] Client sent endpoint update and is ok with a response without peer list", m.Name) log.Info().
Str("handler", "PollNetMap").
Str("machine", m.Name).
Msg("Client sent endpoint update and is ok with a response without peer list")
c.Data(200, "application/json; charset=utf-8", *data) c.Data(200, "application/json; charset=utf-8", *data)
return return
} else if req.OmitPeers && req.Stream { } else if req.OmitPeers && req.Stream {
log.Printf("[%s] Warning, ignoring request, don't know how to handle it", m.Name) log.Warn().
Str("handler", "PollNetMap").
Str("machine", m.Name).
Msg("Ignoring request, don't know how to handle it")
c.String(http.StatusBadRequest, "") c.String(http.StatusBadRequest, "")
return return
} }
log.Printf("[%s] Client is ready to access the tailnet", m.Name) log.Info().
log.Printf("[%s] Sending initial map", m.Name) Str("handler", "PollNetMap").
Str("machine", m.Name).
Msg("Client is ready to access the tailnet")
log.Info().
Str("handler", "PollNetMap").
Str("machine", m.Name).
Msg("Sending initial map")
pollData <- *data pollData <- *data
log.Printf("[%s] Notifying peers", m.Name) log.Info().
Str("handler", "PollNetMap").
Str("machine", m.Name).
Msg("Notifying peers")
peers, _ := h.getPeers(m) peers, _ := h.getPeers(m)
h.pollMu.Lock()
for _, p := range *peers { for _, p := range *peers {
pUp, ok := h.clientsPolling[uint64(p.ID)] pUp, ok := h.clientsPolling.Load(uint64(p.ID))
if ok { if ok {
log.Printf("[%s] Notifying peer %s (%s)", m.Name, p.Name, p.Addresses[0]) log.Info().
pUp <- []byte{} Str("handler", "PollNetMap").
Str("machine", m.Name).
Str("peer", m.Name).
Str("address", p.Addresses[0].String()).
Msgf("Notifying peer %s (%s)", p.Name, p.Addresses[0])
pUp.(chan []byte) <- []byte{}
} else { } else {
log.Printf("[%s] Peer %s does not appear to be polling", m.Name, p.Name) log.Info().
Str("handler", "PollNetMap").
Str("machine", m.Name).
Str("peer", m.Name).
Msgf("Peer %s does not appear to be polling", p.Name)
} }
} }
h.pollMu.Unlock()
go h.keepAlive(cancelKeepAlive, pollData, mKey, req, m) go h.keepAlive(cancelKeepAlive, pollData, mKey, req, m)
c.Stream(func(w io.Writer) bool { c.Stream(func(w io.Writer) bool {
select { select {
case data := <-pollData: case data := <-pollData:
log.Printf("[%s] Sending data (%d bytes)", m.Name, len(data)) log.Trace().
Str("handler", "PollNetMap").
Str("machine", m.Name).
Int("bytes", len(data)).
Msg("Sending data")
_, err := w.Write(data) _, err := w.Write(data)
if err != nil { if err != nil {
log.Printf("[%s] Cannot write data: %s", m.Name, err) log.Error().
Str("handler", "PollNetMap").
Str("machine", m.Name).
Err(err).
Msg("Cannot write data")
} }
now := time.Now().UTC() now := time.Now().UTC()
m.LastSeen = &now m.LastSeen = &now
@@ -298,27 +414,39 @@ func (h *Headscale) PollNetMapHandler(c *gin.Context) {
return true return true
case <-update: case <-update:
log.Printf("[%s] Received a request for update", m.Name) log.Debug().
Str("handler", "PollNetMap").
Str("machine", m.Name).
Msg("Received a request for update")
data, err := h.getMapResponse(mKey, req, m) data, err := h.getMapResponse(mKey, req, m)
if err != nil { if err != nil {
log.Printf("[%s] Could not get the map update: %s", m.Name, err) log.Error().
Str("handler", "PollNetMap").
Str("machine", m.Name).
Err(err).
Msg("Could not get the map update")
} }
_, err = w.Write(*data) _, err = w.Write(*data)
if err != nil { if err != nil {
log.Printf("[%s] Could not write the map response: %s", m.Name, err) log.Error().
Str("handler", "PollNetMap").
Str("machine", m.Name).
Err(err).
Msg("Could not write the map response")
} }
return true return true
case <-c.Request.Context().Done(): case <-c.Request.Context().Done():
log.Printf("[%s] The client has closed the connection", m.Name) log.Info().
Str("handler", "PollNetMap").
Str("machine", m.Name).
Msg("The client has closed the connection")
now := time.Now().UTC() now := time.Now().UTC()
m.LastSeen = &now m.LastSeen = &now
h.db.Save(&m) h.db.Save(&m)
h.pollMu.Lock()
cancelKeepAlive <- []byte{} cancelKeepAlive <- []byte{}
delete(h.clientsPolling, m.ID) h.clientsPolling.Delete(m.ID)
close(update) close(update)
h.pollMu.Unlock()
return false return false
} }
@@ -335,10 +463,16 @@ func (h *Headscale) keepAlive(cancel chan []byte, pollData chan []byte, mKey wgk
h.pollMu.Lock() h.pollMu.Lock()
data, err := h.getMapKeepAliveResponse(mKey, req, m) data, err := h.getMapKeepAliveResponse(mKey, req, m)
if err != nil { if err != nil {
log.Printf("Error generating the keep alive msg: %s", err) log.Error().
Str("func", "keepAlive").
Err(err).
Msg("Error generating the keep alive msg")
return return
} }
log.Printf("[%s] Sending keepalive", m.Name) log.Debug().
Str("func", "keepAlive").
Str("machine", m.Name).
Msg("Sending keepalive")
pollData <- *data pollData <- *data
h.pollMu.Unlock() h.pollMu.Unlock()
time.Sleep(60 * time.Second) time.Sleep(60 * time.Second)
@@ -347,14 +481,24 @@ func (h *Headscale) keepAlive(cancel chan []byte, pollData chan []byte, mKey wgk
} }
func (h *Headscale) getMapResponse(mKey wgkey.Key, req tailcfg.MapRequest, m Machine) (*[]byte, error) { func (h *Headscale) getMapResponse(mKey wgkey.Key, req tailcfg.MapRequest, m Machine) (*[]byte, error) {
log.Trace().
Str("func", "getMapResponse").
Str("machine", req.Hostinfo.Hostname).
Msg("Creating Map response")
node, err := m.toNode() node, err := m.toNode()
if err != nil { if err != nil {
log.Printf("Cannot convert to node: %s", err) log.Error().
Str("func", "getMapResponse").
Err(err).
Msg("Cannot convert to node")
return nil, err return nil, err
} }
peers, err := h.getPeers(m) peers, err := h.getPeers(m)
if err != nil { if err != nil {
log.Printf("Cannot fetch peers: %s", err) log.Error().
Str("func", "getMapResponse").
Err(err).
Msg("Cannot fetch peers")
return nil, err return nil, err
} }
@@ -426,25 +570,49 @@ func (h *Headscale) getMapKeepAliveResponse(mKey wgkey.Key, req tailcfg.MapReque
} }
func (h *Headscale) handleAuthKey(c *gin.Context, db *gorm.DB, idKey wgkey.Key, req tailcfg.RegisterRequest, m Machine) { func (h *Headscale) handleAuthKey(c *gin.Context, db *gorm.DB, idKey wgkey.Key, req tailcfg.RegisterRequest, m Machine) {
log.Debug().
Str("func", "handleAuthKey").
Str("machine", req.Hostinfo.Hostname).
Msgf("Processing auth key for %s", req.Hostinfo.Hostname)
resp := tailcfg.RegisterResponse{} resp := tailcfg.RegisterResponse{}
pak, err := h.checkKeyValidity(req.Auth.AuthKey) pak, err := h.checkKeyValidity(req.Auth.AuthKey)
if err != nil { if err != nil {
resp.MachineAuthorized = false resp.MachineAuthorized = false
respBody, err := encode(resp, &idKey, h.privateKey) respBody, err := encode(resp, &idKey, h.privateKey)
if err != nil { if err != nil {
log.Printf("Cannot encode message: %s", err) log.Error().
Str("func", "handleAuthKey").
Str("machine", m.Name).
Err(err).
Msg("Cannot encode message")
c.String(http.StatusInternalServerError, "") c.String(http.StatusInternalServerError, "")
return return
} }
c.Data(200, "application/json; charset=utf-8", respBody) c.Data(200, "application/json; charset=utf-8", respBody)
log.Printf("[%s] Failed authentication via AuthKey", m.Name) log.Error().
Str("func", "handleAuthKey").
Str("machine", m.Name).
Msg("Failed authentication via AuthKey")
return return
} }
log.Debug().
Str("func", "handleAuthKey").
Str("machine", m.Name).
Msg("Authentication key was valid, proceeding to acquire an IP address")
ip, err := h.getAvailableIP() ip, err := h.getAvailableIP()
if err != nil { if err != nil {
log.Println(err) log.Error().
Str("func", "handleAuthKey").
Str("machine", m.Name).
Msg("Failed to find an available IP")
return return
} }
log.Info().
Str("func", "handleAuthKey").
Str("machine", m.Name).
Str("ip", ip.String()).
Msgf("Assining %s to %s", ip, m.Name)
m.AuthKeyID = uint(pak.ID) m.AuthKeyID = uint(pak.ID)
m.IPAddress = ip.String() m.IPAddress = ip.String()
@@ -458,10 +626,18 @@ func (h *Headscale) handleAuthKey(c *gin.Context, db *gorm.DB, idKey wgkey.Key,
resp.User = *pak.Namespace.toUser() resp.User = *pak.Namespace.toUser()
respBody, err := encode(resp, &idKey, h.privateKey) respBody, err := encode(resp, &idKey, h.privateKey)
if err != nil { if err != nil {
log.Printf("Cannot encode message: %s", err) log.Error().
Str("func", "handleAuthKey").
Str("machine", m.Name).
Err(err).
Msg("Cannot encode message")
c.String(http.StatusInternalServerError, "Extremely sad!") c.String(http.StatusInternalServerError, "Extremely sad!")
return return
} }
c.Data(200, "application/json; charset=utf-8", respBody) c.Data(200, "application/json; charset=utf-8", respBody)
log.Printf("[%s] Successfully authenticated via AuthKey", m.Name) log.Info().
Str("func", "handleAuthKey").
Str("machine", m.Name).
Str("ip", ip.String()).
Msg("Successfully authenticated via AuthKey")
} }

44
app.go
View File

@@ -3,16 +3,18 @@ package headscale
import ( import (
"errors" "errors"
"fmt" "fmt"
"log"
"net/http" "net/http"
"os" "os"
"strings" "strings"
"sync" "sync"
"time" "time"
"github.com/rs/zerolog/log"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"golang.org/x/crypto/acme/autocert" "golang.org/x/crypto/acme/autocert"
"gorm.io/gorm" "gorm.io/gorm"
"inet.af/netaddr"
"tailscale.com/tailcfg" "tailscale.com/tailcfg"
"tailscale.com/types/wgkey" "tailscale.com/types/wgkey"
) )
@@ -24,6 +26,7 @@ type Config struct {
PrivateKeyPath string PrivateKeyPath string
DerpMap *tailcfg.DERPMap DerpMap *tailcfg.DERPMap
EphemeralNodeInactivityTimeout time.Duration EphemeralNodeInactivityTimeout time.Duration
IPPrefix netaddr.IPPrefix
DBtype string DBtype string
DBpath string DBpath string
@@ -56,7 +59,7 @@ type Headscale struct {
aclRules *[]tailcfg.FilterRule aclRules *[]tailcfg.FilterRule
pollMu sync.Mutex pollMu sync.Mutex
clientsPolling map[uint64]chan []byte // this is by all means a hackity hack clientsPolling sync.Map
} }
// NewHeadscale returns the Headscale app // NewHeadscale returns the Headscale app
@@ -96,7 +99,6 @@ func NewHeadscale(cfg Config) (*Headscale, error) {
return nil, err return nil, err
} }
h.clientsPolling = make(map[uint64]chan []byte)
return &h, nil return &h, nil
} }
@@ -118,27 +120,41 @@ func (h *Headscale) ExpireEphemeralNodes(milliSeconds int64) {
func (h *Headscale) expireEphemeralNodesWorker() { func (h *Headscale) expireEphemeralNodesWorker() {
namespaces, err := h.ListNamespaces() namespaces, err := h.ListNamespaces()
if err != nil { if err != nil {
log.Printf("Error listing namespaces: %s", err) log.Error().Err(err).Msg("Error listing namespaces")
return return
} }
for _, ns := range *namespaces { for _, ns := range *namespaces {
machines, err := h.ListMachinesInNamespace(ns.Name) machines, err := h.ListMachinesInNamespace(ns.Name)
if err != nil { if err != nil {
log.Printf("Error listing machines in namespace %s: %s", ns.Name, err) log.Error().Err(err).Str("namespace", ns.Name).Msg("Error listing machines in namespace")
return return
} }
for _, m := range *machines { for _, m := range *machines {
if m.AuthKey != nil && m.LastSeen != nil && m.AuthKey.Ephemeral && time.Now().After(m.LastSeen.Add(h.cfg.EphemeralNodeInactivityTimeout)) { if m.AuthKey != nil && m.LastSeen != nil && m.AuthKey.Ephemeral && time.Now().After(m.LastSeen.Add(h.cfg.EphemeralNodeInactivityTimeout)) {
log.Printf("[%s] Ephemeral client removed from database\n", m.Name) log.Info().Str("machine", m.Name).Msg("Ephemeral client removed from database")
err = h.db.Unscoped().Delete(m).Error err = h.db.Unscoped().Delete(m).Error
if err != nil { if err != nil {
log.Printf("[%s] 🤮 Cannot delete ephemeral machine from the database: %s", m.Name, err) log.Error().Err(err).Str("machine", m.Name).Msg("🤮 Cannot delete ephemeral machine from the database")
} }
} }
} }
} }
} }
// WatchForKVUpdates checks the KV DB table for requests to perform tailnet upgrades
// This is a way to communitate the CLI with the headscale server
func (h *Headscale) watchForKVUpdates(milliSeconds int64) {
ticker := time.NewTicker(time.Duration(milliSeconds) * time.Millisecond)
for range ticker.C {
h.watchForKVUpdatesWorker()
}
}
func (h *Headscale) watchForKVUpdatesWorker() {
h.checkForNamespacesPendingUpdates()
// more functions will come here in the future
}
// Serve launches a GIN server with the Headscale API // Serve launches a GIN server with the Headscale API
func (h *Headscale) Serve() error { func (h *Headscale) Serve() error {
r := gin.Default() r := gin.Default()
@@ -147,9 +163,12 @@ func (h *Headscale) Serve() error {
r.POST("/machine/:id/map", h.PollNetMapHandler) r.POST("/machine/:id/map", h.PollNetMapHandler)
r.POST("/machine/:id", h.RegistrationHandler) r.POST("/machine/:id", h.RegistrationHandler)
var err error var err error
go h.watchForKVUpdates(5000)
if h.cfg.TLSLetsEncryptHostname != "" { if h.cfg.TLSLetsEncryptHostname != "" {
if !strings.HasPrefix(h.cfg.ServerURL, "https://") { if !strings.HasPrefix(h.cfg.ServerURL, "https://") {
log.Println("WARNING: listening with TLS but ServerURL does not start with https://") log.Warn().Msg("Listening with TLS but ServerURL does not start with https://")
} }
m := autocert.Manager{ m := autocert.Manager{
@@ -172,7 +191,10 @@ func (h *Headscale) Serve() error {
// port 80 for the certificate validation in addition to the headscale // port 80 for the certificate validation in addition to the headscale
// service, which can be configured to run on any other port. // service, which can be configured to run on any other port.
go func() { go func() {
log.Fatal(http.ListenAndServe(h.cfg.TLSLetsEncryptListen, m.HTTPHandler(http.HandlerFunc(h.redirect))))
log.Fatal().
Err(http.ListenAndServe(h.cfg.TLSLetsEncryptListen, m.HTTPHandler(http.HandlerFunc(h.redirect)))).
Msg("failed to set up a HTTP server")
}() }()
err = s.ListenAndServeTLS("", "") err = s.ListenAndServeTLS("", "")
} else { } else {
@@ -180,12 +202,12 @@ func (h *Headscale) Serve() error {
} }
} else if h.cfg.TLSCertPath == "" { } else if h.cfg.TLSCertPath == "" {
if !strings.HasPrefix(h.cfg.ServerURL, "http://") { if !strings.HasPrefix(h.cfg.ServerURL, "http://") {
log.Println("WARNING: listening without TLS but ServerURL does not start with http://") log.Warn().Msg("Listening without TLS but ServerURL does not start with http://")
} }
err = r.Run(h.cfg.Addr) err = r.Run(h.cfg.Addr)
} else { } else {
if !strings.HasPrefix(h.cfg.ServerURL, "https://") { if !strings.HasPrefix(h.cfg.ServerURL, "https://") {
log.Println("WARNING: listening with TLS but ServerURL does not start with https://") log.Warn().Msg("Listening with TLS but ServerURL does not start with https://")
} }
err = r.RunTLS(h.cfg.Addr, h.cfg.TLSCertPath, h.cfg.TLSKeyPath) err = r.RunTLS(h.cfg.Addr, h.cfg.TLSCertPath, h.cfg.TLSKeyPath)
} }

View File

@@ -6,6 +6,7 @@ import (
"testing" "testing"
"gopkg.in/check.v1" "gopkg.in/check.v1"
"inet.af/netaddr"
) )
func Test(t *testing.T) { func Test(t *testing.T) {
@@ -36,7 +37,9 @@ func (s *Suite) ResetDB(c *check.C) {
if err != nil { if err != nil {
c.Fatal(err) c.Fatal(err)
} }
cfg := Config{} cfg := Config{
IPPrefix: netaddr.MustParseIPPrefix("10.27.0.0/23"),
}
h = Headscale{ h = Headscale{
cfg: cfg, cfg: cfg,

View File

@@ -15,6 +15,7 @@ func (s *Suite) TestRegisterMachine(c *check.C) {
DiscoKey: "faa", DiscoKey: "faa",
Name: "testmachine", Name: "testmachine",
NamespaceID: n.ID, NamespaceID: n.ID,
IPAddress: "10.0.0.1",
} }
h.db.Save(&m) h.db.Save(&m)

View File

@@ -2,8 +2,9 @@ package cli
import ( import (
"fmt" "fmt"
"github.com/spf13/cobra"
"os" "os"
"github.com/spf13/cobra"
) )
func init() { func init() {
@@ -16,8 +17,7 @@ var rootCmd = &cobra.Command{
Long: ` Long: `
headscale is an open source implementation of the Tailscale control server headscale is an open source implementation of the Tailscale control server
Juan Font Alonso <juanfontalonso@gmail.com> - 2021 https://github.com/juanfont/headscale`,
https://gitlab.com/juanfont/headscale`,
} }
func Execute() { func Execute() {

View File

@@ -5,15 +5,16 @@ import (
"errors" "errors"
"fmt" "fmt"
"io" "io"
"log"
"os" "os"
"path/filepath" "path/filepath"
"strings" "strings"
"time" "time"
"github.com/juanfont/headscale" "github.com/juanfont/headscale"
"github.com/rs/zerolog/log"
"github.com/spf13/viper" "github.com/spf13/viper"
"gopkg.in/yaml.v2" "gopkg.in/yaml.v2"
"inet.af/netaddr"
"tailscale.com/tailcfg" "tailscale.com/tailcfg"
) )
@@ -36,6 +37,10 @@ func LoadConfig(path string) error {
viper.SetDefault("tls_letsencrypt_cache_dir", "/var/www/.cache") viper.SetDefault("tls_letsencrypt_cache_dir", "/var/www/.cache")
viper.SetDefault("tls_letsencrypt_challenge_type", "HTTP-01") viper.SetDefault("tls_letsencrypt_challenge_type", "HTTP-01")
viper.SetDefault("ip_prefix", "100.64.0.0/10")
viper.SetDefault("log_level", "debug")
err := viper.ReadInConfig() err := viper.ReadInConfig()
if err != nil { if err != nil {
return fmt.Errorf("Fatal error reading config file: %s \n", err) return fmt.Errorf("Fatal error reading config file: %s \n", err)
@@ -49,7 +54,8 @@ func LoadConfig(path string) error {
if (viper.GetString("tls_letsencrypt_hostname") != "") && (viper.GetString("tls_letsencrypt_challenge_type") == "TLS-ALPN-01") && (!strings.HasSuffix(viper.GetString("listen_addr"), ":443")) { if (viper.GetString("tls_letsencrypt_hostname") != "") && (viper.GetString("tls_letsencrypt_challenge_type") == "TLS-ALPN-01") && (!strings.HasSuffix(viper.GetString("listen_addr"), ":443")) {
// this is only a warning because there could be something sitting in front of headscale that redirects the traffic (e.g. an iptables rule) // this is only a warning because there could be something sitting in front of headscale that redirects the traffic (e.g. an iptables rule)
log.Println("Warning: when using tls_letsencrypt_hostname with TLS-ALPN-01 as challenge type, headscale must be reachable on port 443, i.e. listen_addr should probably end in :443") log.Warn().
Msg("Warning: when using tls_letsencrypt_hostname with TLS-ALPN-01 as challenge type, headscale must be reachable on port 443, i.e. listen_addr should probably end in :443")
} }
if (viper.GetString("tls_letsencrypt_challenge_type") != "HTTP-01") && (viper.GetString("tls_letsencrypt_challenge_type") != "TLS-ALPN-01") { if (viper.GetString("tls_letsencrypt_challenge_type") != "HTTP-01") && (viper.GetString("tls_letsencrypt_challenge_type") != "TLS-ALPN-01") {
@@ -79,9 +85,13 @@ func absPath(path string) string {
} }
func getHeadscaleApp() (*headscale.Headscale, error) { func getHeadscaleApp() (*headscale.Headscale, error) {
derpMap, err := loadDerpMap(absPath(viper.GetString("derp_map_path"))) derpPath := absPath(viper.GetString("derp_map_path"))
derpMap, err := loadDerpMap(derpPath)
if err != nil { if err != nil {
log.Printf("Could not load DERP servers map file: %s", err) log.Error().
Str("path", derpPath).
Err(err).
Msg("Could not load DERP servers map file")
} }
// Minimum inactivity time out is keepalive timeout (60s) plus a few seconds // Minimum inactivity time out is keepalive timeout (60s) plus a few seconds
@@ -97,6 +107,7 @@ func getHeadscaleApp() (*headscale.Headscale, error) {
Addr: viper.GetString("listen_addr"), Addr: viper.GetString("listen_addr"),
PrivateKeyPath: absPath(viper.GetString("private_key_path")), PrivateKeyPath: absPath(viper.GetString("private_key_path")),
DerpMap: derpMap, DerpMap: derpMap,
IPPrefix: netaddr.MustParseIPPrefix(viper.GetString("ip_prefix")),
EphemeralNodeInactivityTimeout: viper.GetDuration("ephemeral_node_inactivity_timeout"), EphemeralNodeInactivityTimeout: viper.GetDuration("ephemeral_node_inactivity_timeout"),
@@ -125,9 +136,13 @@ func getHeadscaleApp() (*headscale.Headscale, error) {
// We are doing this here, as in the future could be cool to have it also hot-reload // We are doing this here, as in the future could be cool to have it also hot-reload
if viper.GetString("acl_policy_path") != "" { if viper.GetString("acl_policy_path") != "" {
err = h.LoadACLPolicy(absPath(viper.GetString("acl_policy_path"))) aclPath := absPath(viper.GetString("acl_policy_path"))
err = h.LoadACLPolicy(aclPath)
if err != nil { if err != nil {
log.Printf("Could not load the ACL policy: %s", err) log.Error().
Str("path", aclPath).
Err(err).
Msg("Could not load the ACL policy")
} }
} }
@@ -157,24 +172,24 @@ func JsonOutput(result interface{}, errResult error, outputFormat string) {
if errResult != nil { if errResult != nil {
j, err = json.MarshalIndent(ErrorOutput{errResult.Error()}, "", "\t") j, err = json.MarshalIndent(ErrorOutput{errResult.Error()}, "", "\t")
if err != nil { if err != nil {
log.Fatalln(err) log.Fatal().Err(err)
} }
} else { } else {
j, err = json.MarshalIndent(result, "", "\t") j, err = json.MarshalIndent(result, "", "\t")
if err != nil { if err != nil {
log.Fatalln(err) log.Fatal().Err(err)
} }
} }
case "json-line": case "json-line":
if errResult != nil { if errResult != nil {
j, err = json.Marshal(ErrorOutput{errResult.Error()}) j, err = json.Marshal(ErrorOutput{errResult.Error()})
if err != nil { if err != nil {
log.Fatalln(err) log.Fatal().Err(err)
} }
} else { } else {
j, err = json.Marshal(result) j, err = json.Marshal(result)
if err != nil { if err != nil {
log.Fatalln(err) log.Fatal().Err(err)
} }
} }
} }

View File

@@ -1,15 +1,41 @@
package main package main
import ( import (
"log" "os"
"time"
"github.com/juanfont/headscale/cmd/headscale/cli" "github.com/juanfont/headscale/cmd/headscale/cli"
"github.com/rs/zerolog"
"github.com/rs/zerolog/log"
"github.com/spf13/viper"
) )
func main() { func main() {
zerolog.TimeFieldFormat = zerolog.TimeFormatUnix
log.Logger = log.Output(zerolog.ConsoleWriter{
Out: os.Stdout,
TimeFormat: time.RFC3339,
})
err := cli.LoadConfig("") err := cli.LoadConfig("")
if err != nil { if err != nil {
log.Fatalf(err.Error()) log.Fatal().Err(err)
}
logLevel := viper.GetString("log_level")
switch logLevel {
case "trace":
zerolog.SetGlobalLevel(zerolog.TraceLevel)
case "debug":
zerolog.SetGlobalLevel(zerolog.DebugLevel)
case "info":
zerolog.SetGlobalLevel(zerolog.InfoLevel)
case "warn":
zerolog.SetGlobalLevel(zerolog.WarnLevel)
case "error":
zerolog.SetGlobalLevel(zerolog.ErrorLevel)
default:
zerolog.SetGlobalLevel(zerolog.DebugLevel)
} }
cli.Execute() cli.Execute()

2
db.go
View File

@@ -79,6 +79,7 @@ func (h *Headscale) openDB() (*gorm.DB, error) {
return db, nil return db, nil
} }
// getValue returns the value for the given key in KV
func (h *Headscale) getValue(key string) (string, error) { func (h *Headscale) getValue(key string) (string, error) {
var row KV var row KV
if result := h.db.First(&row, "key = ?", key); errors.Is(result.Error, gorm.ErrRecordNotFound) { if result := h.db.First(&row, "key = ?", key); errors.Is(result.Error, gorm.ErrRecordNotFound) {
@@ -87,6 +88,7 @@ func (h *Headscale) getValue(key string) (string, error) {
return row.Value, nil return row.Value, nil
} }
// setValue sets value for the given key in KV
func (h *Headscale) setValue(key string, value string) error { func (h *Headscale) setValue(key string, value string) error {
kv := KV{ kv := KV{
Key: key, Key: key,

1
go.mod
View File

@@ -9,6 +9,7 @@ require (
github.com/klauspost/compress v1.13.1 github.com/klauspost/compress v1.13.1
github.com/lib/pq v1.10.2 // indirect github.com/lib/pq v1.10.2 // indirect
github.com/mattn/go-sqlite3 v1.14.7 // indirect github.com/mattn/go-sqlite3 v1.14.7 // indirect
github.com/rs/zerolog v1.23.0 // indirect
github.com/spf13/cobra v1.1.3 github.com/spf13/cobra v1.1.3
github.com/spf13/viper v1.8.1 github.com/spf13/viper v1.8.1
github.com/tailscale/hujson v0.0.0-20200924210142-dde312d0d6a2 github.com/tailscale/hujson v0.0.0-20200924210142-dde312d0d6a2

2
go.sum
View File

@@ -683,6 +683,8 @@ github.com/rogpeppe/go-internal v1.6.2/go.mod h1:xXDCJY+GAPziupqXw64V24skbSoqbTE
github.com/rs/xid v1.2.1/go.mod h1:+uKXf+4Djp6Md1KODXJxgGQPKngRmWyn10oCKFzNHOQ= github.com/rs/xid v1.2.1/go.mod h1:+uKXf+4Djp6Md1KODXJxgGQPKngRmWyn10oCKFzNHOQ=
github.com/rs/zerolog v1.13.0/go.mod h1:YbFCdg8HfsridGWAh22vktObvhZbQsZXe4/zB0OKkWU= github.com/rs/zerolog v1.13.0/go.mod h1:YbFCdg8HfsridGWAh22vktObvhZbQsZXe4/zB0OKkWU=
github.com/rs/zerolog v1.15.0/go.mod h1:xYTKnLHcpfU2225ny5qZjxnj9NvkumZYjJHlAThCjNc= github.com/rs/zerolog v1.15.0/go.mod h1:xYTKnLHcpfU2225ny5qZjxnj9NvkumZYjJHlAThCjNc=
github.com/rs/zerolog v1.23.0 h1:UskrK+saS9P9Y789yNNulYKdARjPZuS35B8gJF2x60g=
github.com/rs/zerolog v1.23.0/go.mod h1:6c7hFfxPOy7TacJc4Fcdi24/J0NKYGzjG8FWRI916Qo=
github.com/russross/blackfriday/v2 v2.0.1/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= github.com/russross/blackfriday/v2 v2.0.1/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM=
github.com/ryancurrah/gomodguard v1.1.0/go.mod h1:4O8tr7hBODaGE6VIhfJDHcwzh5GUccKSJBU0UMXJFVM= github.com/ryancurrah/gomodguard v1.1.0/go.mod h1:4O8tr7hBODaGE6VIhfJDHcwzh5GUccKSJBU0UMXJFVM=
github.com/ryanrolds/sqlclosecheck v0.3.0/go.mod h1:1gREqxyTGR3lVtpngyFo3hZAgk0KCtEdgEkHwDbigdA= github.com/ryanrolds/sqlclosecheck v0.3.0/go.mod h1:1gREqxyTGR3lVtpngyFo3hZAgk0KCtEdgEkHwDbigdA=

View File

@@ -65,7 +65,6 @@ tasks like creating namespaces, authkeys, etc.
headscale is an open source implementation of the Tailscale control server headscale is an open source implementation of the Tailscale control server
Juan Font Alonso <juanfontalonso@gmail.com> - 2021
https://gitlab.com/juanfont/headscale https://gitlab.com/juanfont/headscale
Usage: Usage:

View File

@@ -3,11 +3,12 @@ package headscale
import ( import (
"encoding/json" "encoding/json"
"fmt" "fmt"
"log"
"sort" "sort"
"strconv" "strconv"
"time" "time"
"github.com/rs/zerolog/log"
"gorm.io/datatypes" "gorm.io/datatypes"
"inet.af/netaddr" "inet.af/netaddr"
"tailscale.com/tailcfg" "tailscale.com/tailcfg"
@@ -157,7 +158,7 @@ func (h *Headscale) getPeers(m Machine) (*[]*tailcfg.Node, error) {
machines := []Machine{} machines := []Machine{}
if err := h.db.Where("namespace_id = ? AND machine_key <> ? AND registered", if err := h.db.Where("namespace_id = ? AND machine_key <> ? AND registered",
m.NamespaceID, m.MachineKey).Find(&machines).Error; err != nil { m.NamespaceID, m.MachineKey).Find(&machines).Error; err != nil {
log.Printf("Error accessing db: %s", err) log.Error().Err(err).Msg("Error accessing db")
return nil, err return nil, err
} }
@@ -200,19 +201,22 @@ func (h *Headscale) GetMachineByID(id uint64) (*Machine, error) {
// DeleteMachine softs deletes a Machine from the database // DeleteMachine softs deletes a Machine from the database
func (h *Headscale) DeleteMachine(m *Machine) error { func (h *Headscale) DeleteMachine(m *Machine) error {
m.Registered = false m.Registered = false
namespaceID := m.NamespaceID
h.db.Save(&m) // we mark it as unregistered, just in case h.db.Save(&m) // we mark it as unregistered, just in case
if err := h.db.Delete(&m).Error; err != nil { if err := h.db.Delete(&m).Error; err != nil {
return err return err
} }
return nil
return h.RequestMapUpdates(namespaceID)
} }
// HardDeleteMachine hard deletes a Machine from the database // HardDeleteMachine hard deletes a Machine from the database
func (h *Headscale) HardDeleteMachine(m *Machine) error { func (h *Headscale) HardDeleteMachine(m *Machine) error {
namespaceID := m.NamespaceID
if err := h.db.Unscoped().Delete(&m).Error; err != nil { if err := h.db.Unscoped().Delete(&m).Error; err != nil {
return err return err
} }
return nil return h.RequestMapUpdates(namespaceID)
} }
// GetHostInfo returns a Hostinfo struct for the machine // GetHostInfo returns a Hostinfo struct for the machine

View File

@@ -1,6 +1,8 @@
package headscale package headscale
import ( import (
"encoding/json"
"gopkg.in/check.v1" "gopkg.in/check.v1"
) )
@@ -81,6 +83,15 @@ func (s *Suite) TestDeleteMachine(c *check.C) {
h.db.Save(&m) h.db.Save(&m)
err = h.DeleteMachine(&m) err = h.DeleteMachine(&m)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
v, err := h.getValue("namespaces_pending_updates")
c.Assert(err, check.IsNil)
names := []string{}
err = json.Unmarshal([]byte(v), &names)
c.Assert(err, check.IsNil)
c.Assert(names, check.DeepEquals, []string{n.Name})
h.checkForNamespacesPendingUpdates()
v, _ = h.getValue("namespaces_pending_updates")
c.Assert(v, check.Equals, "")
_, err = h.GetMachine(n.Name, "testmachine") _, err = h.GetMachine(n.Name, "testmachine")
c.Assert(err, check.NotNil) c.Assert(err, check.NotNil)
} }

View File

@@ -1,10 +1,12 @@
package headscale package headscale
import ( import (
"encoding/json"
"errors" "errors"
"log" "fmt"
"time" "time"
"github.com/rs/zerolog/log"
"gorm.io/gorm" "gorm.io/gorm"
"tailscale.com/tailcfg" "tailscale.com/tailcfg"
) )
@@ -31,7 +33,10 @@ func (h *Headscale) CreateNamespace(name string) (*Namespace, error) {
} }
n.Name = name n.Name = name
if err := h.db.Create(&n).Error; err != nil { if err := h.db.Create(&n).Error; err != nil {
log.Printf("Could not create row: %s", err) log.Error().
Str("func", "CreateNamespace").
Err(err).
Msg("Could not create row")
return nil, err return nil, err
} }
return &n, nil return &n, nil
@@ -103,6 +108,104 @@ func (h *Headscale) SetMachineNamespace(m *Machine, namespaceName string) error
return nil return nil
} }
// RequestMapUpdates signals the KV worker to update the maps for this namespace
func (h *Headscale) RequestMapUpdates(namespaceID uint) error {
namespace := Namespace{}
if err := h.db.First(&namespace, namespaceID).Error; err != nil {
return err
}
v, err := h.getValue("namespaces_pending_updates")
if err != nil || v == "" {
err = h.setValue("namespaces_pending_updates", fmt.Sprintf(`["%s"]`, namespace.Name))
if err != nil {
return err
}
return nil
}
names := []string{}
err = json.Unmarshal([]byte(v), &names)
if err != nil {
err = h.setValue("namespaces_pending_updates", fmt.Sprintf(`["%s"]`, namespace.Name))
if err != nil {
return err
}
return nil
}
names = append(names, namespace.Name)
data, err := json.Marshal(names)
if err != nil {
log.Error().
Str("func", "RequestMapUpdates").
Err(err).
Msg("Could not marshal namespaces_pending_updates")
return err
}
return h.setValue("namespaces_pending_updates", string(data))
}
func (h *Headscale) checkForNamespacesPendingUpdates() {
v, err := h.getValue("namespaces_pending_updates")
if err != nil {
return
}
if v == "" {
return
}
names := []string{}
err = json.Unmarshal([]byte(v), &names)
if err != nil {
return
}
for _, name := range names {
log.Trace().
Str("func", "RequestMapUpdates").
Str("machine", name).
Msg("Sending updates to nodes in namespace")
machines, err := h.ListMachinesInNamespace(name)
if err != nil {
continue
}
for _, m := range *machines {
peers, _ := h.getPeers(m)
for _, p := range *peers {
pUp, ok := h.clientsPolling.Load(uint64(p.ID))
if ok {
log.Info().
Str("func", "checkForNamespacesPendingUpdates").
Str("machine", m.Name).
Str("peer", m.Name).
Str("address", p.Addresses[0].String()).
Msgf("Notifying peer %s (%s)", p.Name, p.Addresses[0])
pUp.(chan []byte) <- []byte{}
} else {
log.Info().
Str("func", "checkForNamespacesPendingUpdates").
Str("machine", m.Name).
Str("peer", m.Name).
Msgf("Peer %s does not appear to be polling", p.Name)
}
}
}
}
newV, err := h.getValue("namespaces_pending_updates")
if err != nil {
return
}
if v == newV { // only clear when no changes, so we notified everybody
err = h.setValue("namespaces_pending_updates", "")
if err != nil {
log.Error().
Str("func", "checkForNamespacesPendingUpdates").
Err(err).
Msg("Could not save to KV")
return
}
}
}
func (n *Namespace) toUser() *tailcfg.User { func (n *Namespace) toUser() *tailcfg.User {
u := tailcfg.User{ u := tailcfg.User{
ID: tailcfg.UserID(n.ID), ID: tailcfg.UserID(n.ID),

View File

@@ -52,8 +52,8 @@ func (h *Headscale) EnableNodeRoute(namespace string, nodeName string, routeStr
peers, _ := h.getPeers(*m) peers, _ := h.getPeers(*m)
h.pollMu.Lock() h.pollMu.Lock()
for _, p := range *peers { for _, p := range *peers {
if pUp, ok := h.clientsPolling[uint64(p.ID)]; ok { if pUp, ok := h.clientsPolling.Load(uint64(p.ID)); ok {
pUp <- []byte{} pUp.(chan []byte) <- []byte{}
} }
} }
h.pollMu.Unlock() h.pollMu.Unlock()

View File

@@ -7,18 +7,12 @@ package headscale
import ( import (
"crypto/rand" "crypto/rand"
"encoding/binary"
"encoding/json" "encoding/json"
"errors"
"fmt" "fmt"
"io" "io"
"net"
"time"
mathrand "math/rand"
"golang.org/x/crypto/nacl/box" "golang.org/x/crypto/nacl/box"
"gorm.io/gorm" "inet.af/netaddr"
"tailscale.com/types/wgkey" "tailscale.com/types/wgkey"
) )
@@ -77,47 +71,71 @@ func encodeMsg(b []byte, pubKey *wgkey.Key, privKey *wgkey.Private) ([]byte, err
return msg, nil return msg, nil
} }
func (h *Headscale) getAvailableIP() (*net.IP, error) { func (h *Headscale) getAvailableIP() (*netaddr.IP, error) {
i := 0 ipPrefix := h.cfg.IPPrefix
usedIps, err := h.getUsedIPs()
if err != nil {
return nil, err
}
// Get the first IP in our prefix
ip := ipPrefix.IP()
for { for {
ip, err := getRandomIP() if !ipPrefix.Contains(ip) {
if err != nil { return nil, fmt.Errorf("could not find any suitable IP in %s", ipPrefix)
return nil, err
} }
m := Machine{}
if result := h.db.First(&m, "ip_address = ?", ip.String()); errors.Is(result.Error, gorm.ErrRecordNotFound) { // Some OS (including Linux) does not like when IPs ends with 0 or 255, which
return ip, nil // is typically called network or broadcast. Lets avoid them and continue
// to look when we get one of those traditionally reserved IPs.
ipRaw := ip.As4()
if ipRaw[3] == 0 || ipRaw[3] == 255 {
ip = ip.Next()
continue
} }
i++
if i == 100 { // really random number if ip.IsZero() &&
break ip.IsLoopback() {
ip = ip.Next()
continue
} }
if !containsIPs(usedIps, ip) {
return &ip, nil
}
ip = ip.Next()
} }
return nil, errors.New("Could not find an available IP address in 100.64.0.0/10")
} }
func getRandomIP() (*net.IP, error) { func (h *Headscale) getUsedIPs() ([]netaddr.IP, error) {
mathrand.Seed(time.Now().Unix()) var addresses []string
ipo, ipnet, err := net.ParseCIDR("100.64.0.0/10") h.db.Model(&Machine{}).Pluck("ip_address", &addresses)
if err == nil {
ip := ipo.To4() ips := make([]netaddr.IP, len(addresses))
// fmt.Println("In Randomize IPAddr: IP ", ip, " IPNET: ", ipnet) for index, addr := range addresses {
// fmt.Println("Final address is ", ip) if addr != "" {
// fmt.Println("Broadcast address is ", ipb) ip, err := netaddr.ParseIP(addr)
// fmt.Println("Network address is ", ipn) if err != nil {
r := mathrand.Uint32() return nil, fmt.Errorf("failed to parse ip from database, %w", err)
ipRaw := make([]byte, 4) }
binary.LittleEndian.PutUint32(ipRaw, r)
// ipRaw[3] = 254 ips[index] = ip
// fmt.Println("ipRaw is ", ipRaw)
for i, v := range ipRaw {
// fmt.Println("IP Before: ", ip[i], " v is ", v, " Mask is: ", ipnet.Mask[i])
ip[i] = ip[i] + (v &^ ipnet.Mask[i])
// fmt.Println("IP After: ", ip[i])
} }
// fmt.Println("FINAL IP: ", ip.String())
return &ip, nil
} }
return nil, err return ips, nil
}
func containsIPs(ips []netaddr.IP, ip netaddr.IP) bool {
for _, v := range ips {
if v == ip {
return true
}
}
return false
} }

155
utils_test.go Normal file
View File

@@ -0,0 +1,155 @@
package headscale
import (
"gopkg.in/check.v1"
"inet.af/netaddr"
)
func (s *Suite) TestGetAvailableIp(c *check.C) {
ip, err := h.getAvailableIP()
c.Assert(err, check.IsNil)
expected := netaddr.MustParseIP("10.27.0.1")
c.Assert(ip.String(), check.Equals, expected.String())
}
func (s *Suite) TestGetUsedIps(c *check.C) {
ip, err := h.getAvailableIP()
c.Assert(err, check.IsNil)
n, err := h.CreateNamespace("test_ip")
c.Assert(err, check.IsNil)
pak, err := h.CreatePreAuthKey(n.Name, false, false, nil)
c.Assert(err, check.IsNil)
_, err = h.GetMachine("test", "testmachine")
c.Assert(err, check.NotNil)
m := Machine{
ID: 0,
MachineKey: "foo",
NodeKey: "bar",
DiscoKey: "faa",
Name: "testmachine",
NamespaceID: n.ID,
Registered: true,
RegisterMethod: "authKey",
AuthKeyID: uint(pak.ID),
IPAddress: ip.String(),
}
h.db.Save(&m)
ips, err := h.getUsedIPs()
c.Assert(err, check.IsNil)
expected := netaddr.MustParseIP("10.27.0.1")
c.Assert(ips[0], check.Equals, expected)
m1, err := h.GetMachineByID(0)
c.Assert(err, check.IsNil)
c.Assert(m1.IPAddress, check.Equals, expected.String())
}
func (s *Suite) TestGetMultiIp(c *check.C) {
n, err := h.CreateNamespace("test-ip-multi")
c.Assert(err, check.IsNil)
for i := 1; i <= 350; i++ {
ip, err := h.getAvailableIP()
c.Assert(err, check.IsNil)
pak, err := h.CreatePreAuthKey(n.Name, false, false, nil)
c.Assert(err, check.IsNil)
_, err = h.GetMachine("test", "testmachine")
c.Assert(err, check.NotNil)
m := Machine{
ID: uint64(i),
MachineKey: "foo",
NodeKey: "bar",
DiscoKey: "faa",
Name: "testmachine",
NamespaceID: n.ID,
Registered: true,
RegisterMethod: "authKey",
AuthKeyID: uint(pak.ID),
IPAddress: ip.String(),
}
h.db.Save(&m)
}
ips, err := h.getUsedIPs()
c.Assert(err, check.IsNil)
c.Assert(len(ips), check.Equals, 350)
c.Assert(ips[0], check.Equals, netaddr.MustParseIP("10.27.0.1"))
c.Assert(ips[9], check.Equals, netaddr.MustParseIP("10.27.0.10"))
c.Assert(ips[300], check.Equals, netaddr.MustParseIP("10.27.1.47"))
// Check that we can read back the IPs
m1, err := h.GetMachineByID(1)
c.Assert(err, check.IsNil)
c.Assert(m1.IPAddress, check.Equals, netaddr.MustParseIP("10.27.0.1").String())
m50, err := h.GetMachineByID(50)
c.Assert(err, check.IsNil)
c.Assert(m50.IPAddress, check.Equals, netaddr.MustParseIP("10.27.0.50").String())
expectedNextIP := netaddr.MustParseIP("10.27.1.97")
nextIP, err := h.getAvailableIP()
c.Assert(err, check.IsNil)
c.Assert(nextIP.String(), check.Equals, expectedNextIP.String())
// If we call get Available again, we should receive
// the same IP, as it has not been reserved.
nextIP2, err := h.getAvailableIP()
c.Assert(err, check.IsNil)
c.Assert(nextIP2.String(), check.Equals, expectedNextIP.String())
}
func (s *Suite) TestGetAvailableIpMachineWithoutIP(c *check.C) {
ip, err := h.getAvailableIP()
c.Assert(err, check.IsNil)
expected := netaddr.MustParseIP("10.27.0.1")
c.Assert(ip.String(), check.Equals, expected.String())
n, err := h.CreateNamespace("test_ip")
c.Assert(err, check.IsNil)
pak, err := h.CreatePreAuthKey(n.Name, false, false, nil)
c.Assert(err, check.IsNil)
_, err = h.GetMachine("test", "testmachine")
c.Assert(err, check.NotNil)
m := Machine{
ID: 0,
MachineKey: "foo",
NodeKey: "bar",
DiscoKey: "faa",
Name: "testmachine",
NamespaceID: n.ID,
Registered: true,
RegisterMethod: "authKey",
AuthKeyID: uint(pak.ID),
}
h.db.Save(&m)
ip2, err := h.getAvailableIP()
c.Assert(err, check.IsNil)
c.Assert(ip2.String(), check.Equals, expected.String())
}