mirror of
https://github.com/juanfont/headscale.git
synced 2025-08-17 02:49:35 +00:00
Compare commits
28 Commits
Author | SHA1 | Date | |
---|---|---|---|
![]() |
8a614dabc0 | ||
![]() |
c95cf15731 | ||
![]() |
e7ce902f9d | ||
![]() |
d421c7b665 | ||
![]() |
1abc68ccf4 | ||
![]() |
575b15e5fa | ||
![]() |
a8c8a358d0 | ||
![]() |
cd2ca137c0 | ||
![]() |
0660867a16 | ||
![]() |
b1200140b8 | ||
![]() |
d10b57b317 | ||
![]() |
42bf566fff | ||
![]() |
0bb2fabc6c | ||
![]() |
ee704f8ef3 | ||
![]() |
4aad3b7933 | ||
![]() |
6091373b53 | ||
![]() |
3879120967 | ||
![]() |
465669f650 | ||
![]() |
ea615e3a26 | ||
![]() |
d3349aa4d1 | ||
![]() |
73207decfd | ||
![]() |
eda6e560c3 | ||
![]() |
95de823b72 | ||
![]() |
9f85efffd5 | ||
![]() |
b5841c8a8b | ||
![]() |
309f868a21 | ||
![]() |
461a893ee4 | ||
![]() |
97f7c90092 |
@@ -24,6 +24,7 @@ Headscale implements this coordination server.
|
||||
- [x] Node registration via pre-auth keys (including reusable keys, and ephemeral node support)
|
||||
- [X] JSON-formatted output
|
||||
- [X] ACLs
|
||||
- [X] Support for alternative IP ranges in the tailnets (default Tailscale's 100.64.0.0/10)
|
||||
- [ ] Share nodes between ~~users~~ namespaces
|
||||
- [ ] DNS
|
||||
|
||||
@@ -113,9 +114,15 @@ Headscale's configuration file is named `config.json` or `config.yaml`. Headscal
|
||||
```
|
||||
"server_url": "http://192.168.1.12:8080",
|
||||
"listen_addr": "0.0.0.0:8080",
|
||||
"ip_prefix": "100.64.0.0/10"
|
||||
```
|
||||
|
||||
`server_url` is the external URL via which Headscale is reachable. `listen_addr` is the IP address and port the Headscale program should listen on.
|
||||
`server_url` is the external URL via which Headscale is reachable. `listen_addr` is the IP address and port the Headscale program should listen on. `ip_prefix` is the IP prefix (range) in which IP addresses for nodes will be allocated (default 100.64.0.0/10, e.g., 192.168.4.0/24, 10.0.0.0/8)
|
||||
|
||||
```
|
||||
"log_level": "debug"
|
||||
```
|
||||
`log_level` can be used to set the Log level for Headscale, it defaults to `debug`, and the available levels are: `trace`, `debug`, `info`, `warn` and `error`.
|
||||
|
||||
```
|
||||
"private_key_path": "private.key",
|
||||
|
9
acls.go
9
acls.go
@@ -4,11 +4,12 @@ import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/rs/zerolog/log"
|
||||
|
||||
"github.com/tailscale/hujson"
|
||||
"inet.af/netaddr"
|
||||
"tailscale.com/tailcfg"
|
||||
@@ -66,7 +67,8 @@ func (h *Headscale) generateACLRules() (*[]tailcfg.FilterRule, error) {
|
||||
for j, u := range a.Users {
|
||||
srcs, err := h.generateACLPolicySrcIP(u)
|
||||
if err != nil {
|
||||
log.Printf("Error parsing ACL %d, User %d", i, j)
|
||||
log.Error().
|
||||
Msgf("Error parsing ACL %d, User %d", i, j)
|
||||
return nil, err
|
||||
}
|
||||
srcIPs = append(srcIPs, *srcs...)
|
||||
@@ -77,7 +79,8 @@ func (h *Headscale) generateACLRules() (*[]tailcfg.FilterRule, error) {
|
||||
for j, d := range a.Ports {
|
||||
dests, err := h.generateACLPolicyDestPorts(d)
|
||||
if err != nil {
|
||||
log.Printf("Error parsing ACL %d, Port %d", i, j)
|
||||
log.Error().
|
||||
Msgf("Error parsing ACL %d, Port %d", i, j)
|
||||
return nil, err
|
||||
}
|
||||
destPorts = append(destPorts, *dests...)
|
||||
|
284
api.go
284
api.go
@@ -6,10 +6,11 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/rs/zerolog/log"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/klauspost/compress/zstd"
|
||||
"gorm.io/datatypes"
|
||||
@@ -63,21 +64,27 @@ func (h *Headscale) RegistrationHandler(c *gin.Context) {
|
||||
mKeyStr := c.Param("id")
|
||||
mKey, err := wgkey.ParseHex(mKeyStr)
|
||||
if err != nil {
|
||||
log.Printf("Cannot parse machine key: %s", err)
|
||||
log.Error().
|
||||
Str("handler", "Registration").
|
||||
Err(err).
|
||||
Msg("Cannot parse machine key")
|
||||
c.String(http.StatusInternalServerError, "Sad!")
|
||||
return
|
||||
}
|
||||
req := tailcfg.RegisterRequest{}
|
||||
err = decode(body, &req, &mKey, h.privateKey)
|
||||
if err != nil {
|
||||
log.Printf("Cannot decode message: %s", err)
|
||||
log.Error().
|
||||
Str("handler", "Registration").
|
||||
Err(err).
|
||||
Msg("Cannot decode message")
|
||||
c.String(http.StatusInternalServerError, "Very sad!")
|
||||
return
|
||||
}
|
||||
|
||||
var m Machine
|
||||
if result := h.db.Preload("Namespace").First(&m, "machine_key = ?", mKey.HexString()); errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||
log.Println("New Machine!")
|
||||
log.Info().Str("machine", req.Hostinfo.Hostname).Msg("New machine")
|
||||
m = Machine{
|
||||
Expiry: &req.Expiry,
|
||||
MachineKey: mKey.HexString(),
|
||||
@@ -85,7 +92,10 @@ func (h *Headscale) RegistrationHandler(c *gin.Context) {
|
||||
NodeKey: wgkey.Key(req.NodeKey).HexString(),
|
||||
}
|
||||
if err := h.db.Create(&m).Error; err != nil {
|
||||
log.Printf("Could not create row: %s", err)
|
||||
log.Error().
|
||||
Str("handler", "Registration").
|
||||
Err(err).
|
||||
Msg("Could not create row")
|
||||
return
|
||||
}
|
||||
}
|
||||
@@ -100,13 +110,20 @@ func (h *Headscale) RegistrationHandler(c *gin.Context) {
|
||||
// We have the updated key!
|
||||
if m.NodeKey == wgkey.Key(req.NodeKey).HexString() {
|
||||
if m.Registered {
|
||||
log.Printf("[%s] Client is registered and we have the current NodeKey. All clear to /map", m.Name)
|
||||
log.Debug().
|
||||
Str("handler", "Registration").
|
||||
Str("machine", m.Name).
|
||||
Msg("Client is registered and we have the current NodeKey. All clear to /map")
|
||||
|
||||
resp.AuthURL = ""
|
||||
resp.MachineAuthorized = true
|
||||
resp.User = *m.Namespace.toUser()
|
||||
respBody, err := encode(resp, &mKey, h.privateKey)
|
||||
if err != nil {
|
||||
log.Printf("Cannot encode message: %s", err)
|
||||
log.Error().
|
||||
Str("handler", "Registration").
|
||||
Err(err).
|
||||
Msg("Cannot encode message")
|
||||
c.String(http.StatusInternalServerError, "")
|
||||
return
|
||||
}
|
||||
@@ -114,12 +131,18 @@ func (h *Headscale) RegistrationHandler(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
log.Printf("[%s] Not registered and not NodeKey rotation. Sending a authurl to register", m.Name)
|
||||
log.Debug().
|
||||
Str("handler", "Registration").
|
||||
Str("machine", m.Name).
|
||||
Msg("Not registered and not NodeKey rotation. Sending a authurl to register")
|
||||
resp.AuthURL = fmt.Sprintf("%s/register?key=%s",
|
||||
h.cfg.ServerURL, mKey.HexString())
|
||||
respBody, err := encode(resp, &mKey, h.privateKey)
|
||||
if err != nil {
|
||||
log.Printf("Cannot encode message: %s", err)
|
||||
log.Error().
|
||||
Str("handler", "Registration").
|
||||
Err(err).
|
||||
Msg("Cannot encode message")
|
||||
c.String(http.StatusInternalServerError, "")
|
||||
return
|
||||
}
|
||||
@@ -129,7 +152,10 @@ func (h *Headscale) RegistrationHandler(c *gin.Context) {
|
||||
|
||||
// The NodeKey we have matches OldNodeKey, which means this is a refresh after an key expiration
|
||||
if m.NodeKey == wgkey.Key(req.OldNodeKey).HexString() {
|
||||
log.Printf("[%s] We have the OldNodeKey in the database. This is a key refresh", m.Name)
|
||||
log.Debug().
|
||||
Str("handler", "Registration").
|
||||
Str("machine", m.Name).
|
||||
Msg("We have the OldNodeKey in the database. This is a key refresh")
|
||||
m.NodeKey = wgkey.Key(req.NodeKey).HexString()
|
||||
h.db.Save(&m)
|
||||
|
||||
@@ -137,7 +163,10 @@ func (h *Headscale) RegistrationHandler(c *gin.Context) {
|
||||
resp.User = *m.Namespace.toUser()
|
||||
respBody, err := encode(resp, &mKey, h.privateKey)
|
||||
if err != nil {
|
||||
log.Printf("Cannot encode message: %s", err)
|
||||
log.Error().
|
||||
Str("handler", "Registration").
|
||||
Err(err).
|
||||
Msg("Cannot encode message")
|
||||
c.String(http.StatusInternalServerError, "Extremely sad!")
|
||||
return
|
||||
}
|
||||
@@ -148,25 +177,38 @@ func (h *Headscale) RegistrationHandler(c *gin.Context) {
|
||||
// We arrive here after a client is restarted without finalizing the authentication flow or
|
||||
// when headscale is stopped in the middle of the auth process.
|
||||
if m.Registered {
|
||||
log.Printf("[%s] The node is sending us a new NodeKey, but machine is registered. All clear for /map", m.Name)
|
||||
log.Debug().
|
||||
Str("handler", "Registration").
|
||||
Str("machine", m.Name).
|
||||
Msg("The node is sending us a new NodeKey, but machine is registered. All clear for /map")
|
||||
resp.AuthURL = ""
|
||||
resp.MachineAuthorized = true
|
||||
resp.User = *m.Namespace.toUser()
|
||||
respBody, err := encode(resp, &mKey, h.privateKey)
|
||||
if err != nil {
|
||||
log.Printf("Cannot encode message: %s", err)
|
||||
log.Error().
|
||||
Str("handler", "Registration").
|
||||
Err(err).
|
||||
Msg("Cannot encode message")
|
||||
c.String(http.StatusInternalServerError, "")
|
||||
return
|
||||
}
|
||||
c.Data(200, "application/json; charset=utf-8", respBody)
|
||||
return
|
||||
}
|
||||
log.Printf("[%s] The node is sending us a new NodeKey, sending auth url", m.Name)
|
||||
|
||||
log.Debug().
|
||||
Str("handler", "Registration").
|
||||
Str("machine", m.Name).
|
||||
Msg("The node is sending us a new NodeKey, sending auth url")
|
||||
resp.AuthURL = fmt.Sprintf("%s/register?key=%s",
|
||||
h.cfg.ServerURL, mKey.HexString())
|
||||
respBody, err := encode(resp, &mKey, h.privateKey)
|
||||
if err != nil {
|
||||
log.Printf("Cannot encode message: %s", err)
|
||||
log.Error().
|
||||
Str("handler", "Registration").
|
||||
Err(err).
|
||||
Msg("Cannot encode message")
|
||||
c.String(http.StatusInternalServerError, "")
|
||||
return
|
||||
}
|
||||
@@ -183,28 +225,45 @@ func (h *Headscale) RegistrationHandler(c *gin.Context) {
|
||||
//
|
||||
// At this moment the updates are sent in a quite horrendous way, but they kinda work.
|
||||
func (h *Headscale) PollNetMapHandler(c *gin.Context) {
|
||||
log.Trace().
|
||||
Str("handler", "PollNetMap").
|
||||
Str("id", c.Param("id")).
|
||||
Msg("PollNetMapHandler called")
|
||||
body, _ := io.ReadAll(c.Request.Body)
|
||||
mKeyStr := c.Param("id")
|
||||
mKey, err := wgkey.ParseHex(mKeyStr)
|
||||
if err != nil {
|
||||
log.Printf("Cannot parse client key: %s", err)
|
||||
log.Error().
|
||||
Str("handler", "PollNetMap").
|
||||
Err(err).
|
||||
Msg("Cannot parse client key")
|
||||
c.String(http.StatusBadRequest, "")
|
||||
return
|
||||
}
|
||||
req := tailcfg.MapRequest{}
|
||||
err = decode(body, &req, &mKey, h.privateKey)
|
||||
if err != nil {
|
||||
log.Printf("Cannot decode message: %s", err)
|
||||
log.Error().
|
||||
Str("handler", "PollNetMap").
|
||||
Err(err).
|
||||
Msg("Cannot decode message")
|
||||
c.String(http.StatusBadRequest, "")
|
||||
return
|
||||
}
|
||||
|
||||
var m Machine
|
||||
if result := h.db.Preload("Namespace").First(&m, "machine_key = ?", mKey.HexString()); errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||
log.Printf("Ignoring request, cannot find machine with key %s", mKey.HexString())
|
||||
log.Warn().
|
||||
Str("handler", "PollNetMap").
|
||||
Msgf("Ignoring request, cannot find machine with key %s", mKey.HexString())
|
||||
c.String(http.StatusUnauthorized, "")
|
||||
return
|
||||
}
|
||||
log.Trace().
|
||||
Str("handler", "PollNetMap").
|
||||
Str("id", c.Param("id")).
|
||||
Str("machine", m.Name).
|
||||
Msg("Found machine in database")
|
||||
|
||||
hostinfo, _ := json.Marshal(req.Hostinfo)
|
||||
m.Name = req.Hostinfo.Hostname
|
||||
@@ -227,17 +286,34 @@ func (h *Headscale) PollNetMapHandler(c *gin.Context) {
|
||||
}
|
||||
h.db.Save(&m)
|
||||
|
||||
pollData := make(chan []byte, 1)
|
||||
update := make(chan []byte, 1)
|
||||
cancelKeepAlive := make(chan []byte, 1)
|
||||
|
||||
pollData := make(chan []byte, 1)
|
||||
defer close(pollData)
|
||||
|
||||
cancelKeepAlive := make(chan []byte, 1)
|
||||
defer close(cancelKeepAlive)
|
||||
h.pollMu.Lock()
|
||||
h.clientsPolling[m.ID] = update
|
||||
h.pollMu.Unlock()
|
||||
|
||||
log.Trace().
|
||||
Str("handler", "PollNetMap").
|
||||
Str("id", c.Param("id")).
|
||||
Str("machine", m.Name).
|
||||
Msg("Locking poll mutex")
|
||||
h.clientsPolling.Store(m.ID, update)
|
||||
log.Trace().
|
||||
Str("handler", "PollNetMap").
|
||||
Str("id", c.Param("id")).
|
||||
Str("machine", m.Name).
|
||||
Msg("Unlocking poll mutex")
|
||||
|
||||
data, err := h.getMapResponse(mKey, req, m)
|
||||
if err != nil {
|
||||
log.Error().
|
||||
Str("handler", "PollNetMap").
|
||||
Str("id", c.Param("id")).
|
||||
Str("machine", m.Name).
|
||||
Err(err).
|
||||
Msg("Failed to get Map response")
|
||||
c.String(http.StatusInternalServerError, ":(")
|
||||
return
|
||||
}
|
||||
@@ -247,50 +323,90 @@ func (h *Headscale) PollNetMapHandler(c *gin.Context) {
|
||||
// empty endpoints to peers)
|
||||
|
||||
// Details on the protocol can be found in https://github.com/tailscale/tailscale/blob/main/tailcfg/tailcfg.go#L696
|
||||
log.Printf("[%s] ReadOnly=%t OmitPeers=%t Stream=%t", m.Name, req.ReadOnly, req.OmitPeers, req.Stream)
|
||||
log.Debug().
|
||||
Str("handler", "PollNetMap").
|
||||
Str("id", c.Param("id")).
|
||||
Str("machine", m.Name).
|
||||
Bool("readOnly", req.ReadOnly).
|
||||
Bool("omitPeers", req.OmitPeers).
|
||||
Bool("stream", req.Stream).
|
||||
Msg("Client map request processed")
|
||||
|
||||
if req.ReadOnly {
|
||||
log.Printf("[%s] Client is starting up. Asking for DERP map", m.Name)
|
||||
log.Info().
|
||||
Str("handler", "PollNetMap").
|
||||
Str("machine", m.Name).
|
||||
Msg("Client is starting up. Asking for DERP map")
|
||||
c.Data(200, "application/json; charset=utf-8", *data)
|
||||
return
|
||||
}
|
||||
if req.OmitPeers && !req.Stream {
|
||||
log.Printf("[%s] Client sent endpoint update and is ok with a response without peer list", m.Name)
|
||||
log.Info().
|
||||
Str("handler", "PollNetMap").
|
||||
Str("machine", m.Name).
|
||||
Msg("Client sent endpoint update and is ok with a response without peer list")
|
||||
c.Data(200, "application/json; charset=utf-8", *data)
|
||||
return
|
||||
} else if req.OmitPeers && req.Stream {
|
||||
log.Printf("[%s] Warning, ignoring request, don't know how to handle it", m.Name)
|
||||
log.Warn().
|
||||
Str("handler", "PollNetMap").
|
||||
Str("machine", m.Name).
|
||||
Msg("Ignoring request, don't know how to handle it")
|
||||
c.String(http.StatusBadRequest, "")
|
||||
return
|
||||
}
|
||||
|
||||
log.Printf("[%s] Client is ready to access the tailnet", m.Name)
|
||||
log.Printf("[%s] Sending initial map", m.Name)
|
||||
log.Info().
|
||||
Str("handler", "PollNetMap").
|
||||
Str("machine", m.Name).
|
||||
Msg("Client is ready to access the tailnet")
|
||||
log.Info().
|
||||
Str("handler", "PollNetMap").
|
||||
Str("machine", m.Name).
|
||||
Msg("Sending initial map")
|
||||
pollData <- *data
|
||||
|
||||
log.Printf("[%s] Notifying peers", m.Name)
|
||||
log.Info().
|
||||
Str("handler", "PollNetMap").
|
||||
Str("machine", m.Name).
|
||||
Msg("Notifying peers")
|
||||
peers, _ := h.getPeers(m)
|
||||
h.pollMu.Lock()
|
||||
for _, p := range *peers {
|
||||
pUp, ok := h.clientsPolling[uint64(p.ID)]
|
||||
pUp, ok := h.clientsPolling.Load(uint64(p.ID))
|
||||
if ok {
|
||||
log.Printf("[%s] Notifying peer %s (%s)", m.Name, p.Name, p.Addresses[0])
|
||||
pUp <- []byte{}
|
||||
log.Info().
|
||||
Str("handler", "PollNetMap").
|
||||
Str("machine", m.Name).
|
||||
Str("peer", m.Name).
|
||||
Str("address", p.Addresses[0].String()).
|
||||
Msgf("Notifying peer %s (%s)", p.Name, p.Addresses[0])
|
||||
pUp.(chan []byte) <- []byte{}
|
||||
} else {
|
||||
log.Printf("[%s] Peer %s does not appear to be polling", m.Name, p.Name)
|
||||
log.Info().
|
||||
Str("handler", "PollNetMap").
|
||||
Str("machine", m.Name).
|
||||
Str("peer", m.Name).
|
||||
Msgf("Peer %s does not appear to be polling", p.Name)
|
||||
}
|
||||
}
|
||||
h.pollMu.Unlock()
|
||||
|
||||
go h.keepAlive(cancelKeepAlive, pollData, mKey, req, m)
|
||||
|
||||
c.Stream(func(w io.Writer) bool {
|
||||
select {
|
||||
case data := <-pollData:
|
||||
log.Printf("[%s] Sending data (%d bytes)", m.Name, len(data))
|
||||
log.Trace().
|
||||
Str("handler", "PollNetMap").
|
||||
Str("machine", m.Name).
|
||||
Int("bytes", len(data)).
|
||||
Msg("Sending data")
|
||||
_, err := w.Write(data)
|
||||
if err != nil {
|
||||
log.Printf("[%s] Cannot write data: %s", m.Name, err)
|
||||
log.Error().
|
||||
Str("handler", "PollNetMap").
|
||||
Str("machine", m.Name).
|
||||
Err(err).
|
||||
Msg("Cannot write data")
|
||||
}
|
||||
now := time.Now().UTC()
|
||||
m.LastSeen = &now
|
||||
@@ -298,27 +414,39 @@ func (h *Headscale) PollNetMapHandler(c *gin.Context) {
|
||||
return true
|
||||
|
||||
case <-update:
|
||||
log.Printf("[%s] Received a request for update", m.Name)
|
||||
log.Debug().
|
||||
Str("handler", "PollNetMap").
|
||||
Str("machine", m.Name).
|
||||
Msg("Received a request for update")
|
||||
data, err := h.getMapResponse(mKey, req, m)
|
||||
if err != nil {
|
||||
log.Printf("[%s] Could not get the map update: %s", m.Name, err)
|
||||
log.Error().
|
||||
Str("handler", "PollNetMap").
|
||||
Str("machine", m.Name).
|
||||
Err(err).
|
||||
Msg("Could not get the map update")
|
||||
}
|
||||
_, err = w.Write(*data)
|
||||
if err != nil {
|
||||
log.Printf("[%s] Could not write the map response: %s", m.Name, err)
|
||||
log.Error().
|
||||
Str("handler", "PollNetMap").
|
||||
Str("machine", m.Name).
|
||||
Err(err).
|
||||
Msg("Could not write the map response")
|
||||
}
|
||||
return true
|
||||
|
||||
case <-c.Request.Context().Done():
|
||||
log.Printf("[%s] The client has closed the connection", m.Name)
|
||||
log.Info().
|
||||
Str("handler", "PollNetMap").
|
||||
Str("machine", m.Name).
|
||||
Msg("The client has closed the connection")
|
||||
now := time.Now().UTC()
|
||||
m.LastSeen = &now
|
||||
h.db.Save(&m)
|
||||
h.pollMu.Lock()
|
||||
cancelKeepAlive <- []byte{}
|
||||
delete(h.clientsPolling, m.ID)
|
||||
h.clientsPolling.Delete(m.ID)
|
||||
close(update)
|
||||
h.pollMu.Unlock()
|
||||
return false
|
||||
|
||||
}
|
||||
@@ -335,10 +463,16 @@ func (h *Headscale) keepAlive(cancel chan []byte, pollData chan []byte, mKey wgk
|
||||
h.pollMu.Lock()
|
||||
data, err := h.getMapKeepAliveResponse(mKey, req, m)
|
||||
if err != nil {
|
||||
log.Printf("Error generating the keep alive msg: %s", err)
|
||||
log.Error().
|
||||
Str("func", "keepAlive").
|
||||
Err(err).
|
||||
Msg("Error generating the keep alive msg")
|
||||
return
|
||||
}
|
||||
log.Printf("[%s] Sending keepalive", m.Name)
|
||||
log.Debug().
|
||||
Str("func", "keepAlive").
|
||||
Str("machine", m.Name).
|
||||
Msg("Sending keepalive")
|
||||
pollData <- *data
|
||||
h.pollMu.Unlock()
|
||||
time.Sleep(60 * time.Second)
|
||||
@@ -347,14 +481,24 @@ func (h *Headscale) keepAlive(cancel chan []byte, pollData chan []byte, mKey wgk
|
||||
}
|
||||
|
||||
func (h *Headscale) getMapResponse(mKey wgkey.Key, req tailcfg.MapRequest, m Machine) (*[]byte, error) {
|
||||
log.Trace().
|
||||
Str("func", "getMapResponse").
|
||||
Str("machine", req.Hostinfo.Hostname).
|
||||
Msg("Creating Map response")
|
||||
node, err := m.toNode()
|
||||
if err != nil {
|
||||
log.Printf("Cannot convert to node: %s", err)
|
||||
log.Error().
|
||||
Str("func", "getMapResponse").
|
||||
Err(err).
|
||||
Msg("Cannot convert to node")
|
||||
return nil, err
|
||||
}
|
||||
peers, err := h.getPeers(m)
|
||||
if err != nil {
|
||||
log.Printf("Cannot fetch peers: %s", err)
|
||||
log.Error().
|
||||
Str("func", "getMapResponse").
|
||||
Err(err).
|
||||
Msg("Cannot fetch peers")
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -426,25 +570,49 @@ func (h *Headscale) getMapKeepAliveResponse(mKey wgkey.Key, req tailcfg.MapReque
|
||||
}
|
||||
|
||||
func (h *Headscale) handleAuthKey(c *gin.Context, db *gorm.DB, idKey wgkey.Key, req tailcfg.RegisterRequest, m Machine) {
|
||||
log.Debug().
|
||||
Str("func", "handleAuthKey").
|
||||
Str("machine", req.Hostinfo.Hostname).
|
||||
Msgf("Processing auth key for %s", req.Hostinfo.Hostname)
|
||||
resp := tailcfg.RegisterResponse{}
|
||||
pak, err := h.checkKeyValidity(req.Auth.AuthKey)
|
||||
if err != nil {
|
||||
resp.MachineAuthorized = false
|
||||
respBody, err := encode(resp, &idKey, h.privateKey)
|
||||
if err != nil {
|
||||
log.Printf("Cannot encode message: %s", err)
|
||||
log.Error().
|
||||
Str("func", "handleAuthKey").
|
||||
Str("machine", m.Name).
|
||||
Err(err).
|
||||
Msg("Cannot encode message")
|
||||
c.String(http.StatusInternalServerError, "")
|
||||
return
|
||||
}
|
||||
c.Data(200, "application/json; charset=utf-8", respBody)
|
||||
log.Printf("[%s] Failed authentication via AuthKey", m.Name)
|
||||
log.Error().
|
||||
Str("func", "handleAuthKey").
|
||||
Str("machine", m.Name).
|
||||
Msg("Failed authentication via AuthKey")
|
||||
return
|
||||
}
|
||||
|
||||
log.Debug().
|
||||
Str("func", "handleAuthKey").
|
||||
Str("machine", m.Name).
|
||||
Msg("Authentication key was valid, proceeding to acquire an IP address")
|
||||
ip, err := h.getAvailableIP()
|
||||
if err != nil {
|
||||
log.Println(err)
|
||||
log.Error().
|
||||
Str("func", "handleAuthKey").
|
||||
Str("machine", m.Name).
|
||||
Msg("Failed to find an available IP")
|
||||
return
|
||||
}
|
||||
log.Info().
|
||||
Str("func", "handleAuthKey").
|
||||
Str("machine", m.Name).
|
||||
Str("ip", ip.String()).
|
||||
Msgf("Assining %s to %s", ip, m.Name)
|
||||
|
||||
m.AuthKeyID = uint(pak.ID)
|
||||
m.IPAddress = ip.String()
|
||||
@@ -458,10 +626,18 @@ func (h *Headscale) handleAuthKey(c *gin.Context, db *gorm.DB, idKey wgkey.Key,
|
||||
resp.User = *pak.Namespace.toUser()
|
||||
respBody, err := encode(resp, &idKey, h.privateKey)
|
||||
if err != nil {
|
||||
log.Printf("Cannot encode message: %s", err)
|
||||
log.Error().
|
||||
Str("func", "handleAuthKey").
|
||||
Str("machine", m.Name).
|
||||
Err(err).
|
||||
Msg("Cannot encode message")
|
||||
c.String(http.StatusInternalServerError, "Extremely sad!")
|
||||
return
|
||||
}
|
||||
c.Data(200, "application/json; charset=utf-8", respBody)
|
||||
log.Printf("[%s] Successfully authenticated via AuthKey", m.Name)
|
||||
log.Info().
|
||||
Str("func", "handleAuthKey").
|
||||
Str("machine", m.Name).
|
||||
Str("ip", ip.String()).
|
||||
Msg("Successfully authenticated via AuthKey")
|
||||
}
|
||||
|
44
app.go
44
app.go
@@ -3,16 +3,18 @@ package headscale
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"log"
|
||||
"net/http"
|
||||
"os"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/rs/zerolog/log"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"golang.org/x/crypto/acme/autocert"
|
||||
"gorm.io/gorm"
|
||||
"inet.af/netaddr"
|
||||
"tailscale.com/tailcfg"
|
||||
"tailscale.com/types/wgkey"
|
||||
)
|
||||
@@ -24,6 +26,7 @@ type Config struct {
|
||||
PrivateKeyPath string
|
||||
DerpMap *tailcfg.DERPMap
|
||||
EphemeralNodeInactivityTimeout time.Duration
|
||||
IPPrefix netaddr.IPPrefix
|
||||
|
||||
DBtype string
|
||||
DBpath string
|
||||
@@ -56,7 +59,7 @@ type Headscale struct {
|
||||
aclRules *[]tailcfg.FilterRule
|
||||
|
||||
pollMu sync.Mutex
|
||||
clientsPolling map[uint64]chan []byte // this is by all means a hackity hack
|
||||
clientsPolling sync.Map
|
||||
}
|
||||
|
||||
// NewHeadscale returns the Headscale app
|
||||
@@ -96,7 +99,6 @@ func NewHeadscale(cfg Config) (*Headscale, error) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
h.clientsPolling = make(map[uint64]chan []byte)
|
||||
return &h, nil
|
||||
}
|
||||
|
||||
@@ -118,27 +120,41 @@ func (h *Headscale) ExpireEphemeralNodes(milliSeconds int64) {
|
||||
func (h *Headscale) expireEphemeralNodesWorker() {
|
||||
namespaces, err := h.ListNamespaces()
|
||||
if err != nil {
|
||||
log.Printf("Error listing namespaces: %s", err)
|
||||
log.Error().Err(err).Msg("Error listing namespaces")
|
||||
return
|
||||
}
|
||||
for _, ns := range *namespaces {
|
||||
machines, err := h.ListMachinesInNamespace(ns.Name)
|
||||
if err != nil {
|
||||
log.Printf("Error listing machines in namespace %s: %s", ns.Name, err)
|
||||
log.Error().Err(err).Str("namespace", ns.Name).Msg("Error listing machines in namespace")
|
||||
return
|
||||
}
|
||||
for _, m := range *machines {
|
||||
if m.AuthKey != nil && m.LastSeen != nil && m.AuthKey.Ephemeral && time.Now().After(m.LastSeen.Add(h.cfg.EphemeralNodeInactivityTimeout)) {
|
||||
log.Printf("[%s] Ephemeral client removed from database\n", m.Name)
|
||||
log.Info().Str("machine", m.Name).Msg("Ephemeral client removed from database")
|
||||
err = h.db.Unscoped().Delete(m).Error
|
||||
if err != nil {
|
||||
log.Printf("[%s] 🤮 Cannot delete ephemeral machine from the database: %s", m.Name, err)
|
||||
log.Error().Err(err).Str("machine", m.Name).Msg("🤮 Cannot delete ephemeral machine from the database")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// WatchForKVUpdates checks the KV DB table for requests to perform tailnet upgrades
|
||||
// This is a way to communitate the CLI with the headscale server
|
||||
func (h *Headscale) watchForKVUpdates(milliSeconds int64) {
|
||||
ticker := time.NewTicker(time.Duration(milliSeconds) * time.Millisecond)
|
||||
for range ticker.C {
|
||||
h.watchForKVUpdatesWorker()
|
||||
}
|
||||
}
|
||||
|
||||
func (h *Headscale) watchForKVUpdatesWorker() {
|
||||
h.checkForNamespacesPendingUpdates()
|
||||
// more functions will come here in the future
|
||||
}
|
||||
|
||||
// Serve launches a GIN server with the Headscale API
|
||||
func (h *Headscale) Serve() error {
|
||||
r := gin.Default()
|
||||
@@ -147,9 +163,12 @@ func (h *Headscale) Serve() error {
|
||||
r.POST("/machine/:id/map", h.PollNetMapHandler)
|
||||
r.POST("/machine/:id", h.RegistrationHandler)
|
||||
var err error
|
||||
|
||||
go h.watchForKVUpdates(5000)
|
||||
|
||||
if h.cfg.TLSLetsEncryptHostname != "" {
|
||||
if !strings.HasPrefix(h.cfg.ServerURL, "https://") {
|
||||
log.Println("WARNING: listening with TLS but ServerURL does not start with https://")
|
||||
log.Warn().Msg("Listening with TLS but ServerURL does not start with https://")
|
||||
}
|
||||
|
||||
m := autocert.Manager{
|
||||
@@ -172,7 +191,10 @@ func (h *Headscale) Serve() error {
|
||||
// port 80 for the certificate validation in addition to the headscale
|
||||
// service, which can be configured to run on any other port.
|
||||
go func() {
|
||||
log.Fatal(http.ListenAndServe(h.cfg.TLSLetsEncryptListen, m.HTTPHandler(http.HandlerFunc(h.redirect))))
|
||||
|
||||
log.Fatal().
|
||||
Err(http.ListenAndServe(h.cfg.TLSLetsEncryptListen, m.HTTPHandler(http.HandlerFunc(h.redirect)))).
|
||||
Msg("failed to set up a HTTP server")
|
||||
}()
|
||||
err = s.ListenAndServeTLS("", "")
|
||||
} else {
|
||||
@@ -180,12 +202,12 @@ func (h *Headscale) Serve() error {
|
||||
}
|
||||
} else if h.cfg.TLSCertPath == "" {
|
||||
if !strings.HasPrefix(h.cfg.ServerURL, "http://") {
|
||||
log.Println("WARNING: listening without TLS but ServerURL does not start with http://")
|
||||
log.Warn().Msg("Listening without TLS but ServerURL does not start with http://")
|
||||
}
|
||||
err = r.Run(h.cfg.Addr)
|
||||
} else {
|
||||
if !strings.HasPrefix(h.cfg.ServerURL, "https://") {
|
||||
log.Println("WARNING: listening with TLS but ServerURL does not start with https://")
|
||||
log.Warn().Msg("Listening with TLS but ServerURL does not start with https://")
|
||||
}
|
||||
err = r.RunTLS(h.cfg.Addr, h.cfg.TLSCertPath, h.cfg.TLSKeyPath)
|
||||
}
|
||||
|
@@ -6,6 +6,7 @@ import (
|
||||
"testing"
|
||||
|
||||
"gopkg.in/check.v1"
|
||||
"inet.af/netaddr"
|
||||
)
|
||||
|
||||
func Test(t *testing.T) {
|
||||
@@ -36,7 +37,9 @@ func (s *Suite) ResetDB(c *check.C) {
|
||||
if err != nil {
|
||||
c.Fatal(err)
|
||||
}
|
||||
cfg := Config{}
|
||||
cfg := Config{
|
||||
IPPrefix: netaddr.MustParseIPPrefix("10.27.0.0/23"),
|
||||
}
|
||||
|
||||
h = Headscale{
|
||||
cfg: cfg,
|
||||
|
@@ -15,6 +15,7 @@ func (s *Suite) TestRegisterMachine(c *check.C) {
|
||||
DiscoKey: "faa",
|
||||
Name: "testmachine",
|
||||
NamespaceID: n.ID,
|
||||
IPAddress: "10.0.0.1",
|
||||
}
|
||||
h.db.Save(&m)
|
||||
|
||||
|
@@ -2,8 +2,9 @@ package cli
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"github.com/spf13/cobra"
|
||||
"os"
|
||||
|
||||
"github.com/spf13/cobra"
|
||||
)
|
||||
|
||||
func init() {
|
||||
@@ -16,8 +17,7 @@ var rootCmd = &cobra.Command{
|
||||
Long: `
|
||||
headscale is an open source implementation of the Tailscale control server
|
||||
|
||||
Juan Font Alonso <juanfontalonso@gmail.com> - 2021
|
||||
https://gitlab.com/juanfont/headscale`,
|
||||
https://github.com/juanfont/headscale`,
|
||||
}
|
||||
|
||||
func Execute() {
|
||||
|
@@ -5,15 +5,16 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/juanfont/headscale"
|
||||
"github.com/rs/zerolog/log"
|
||||
"github.com/spf13/viper"
|
||||
"gopkg.in/yaml.v2"
|
||||
"inet.af/netaddr"
|
||||
"tailscale.com/tailcfg"
|
||||
)
|
||||
|
||||
@@ -36,6 +37,10 @@ func LoadConfig(path string) error {
|
||||
viper.SetDefault("tls_letsencrypt_cache_dir", "/var/www/.cache")
|
||||
viper.SetDefault("tls_letsencrypt_challenge_type", "HTTP-01")
|
||||
|
||||
viper.SetDefault("ip_prefix", "100.64.0.0/10")
|
||||
|
||||
viper.SetDefault("log_level", "debug")
|
||||
|
||||
err := viper.ReadInConfig()
|
||||
if err != nil {
|
||||
return fmt.Errorf("Fatal error reading config file: %s \n", err)
|
||||
@@ -49,7 +54,8 @@ func LoadConfig(path string) error {
|
||||
|
||||
if (viper.GetString("tls_letsencrypt_hostname") != "") && (viper.GetString("tls_letsencrypt_challenge_type") == "TLS-ALPN-01") && (!strings.HasSuffix(viper.GetString("listen_addr"), ":443")) {
|
||||
// this is only a warning because there could be something sitting in front of headscale that redirects the traffic (e.g. an iptables rule)
|
||||
log.Println("Warning: when using tls_letsencrypt_hostname with TLS-ALPN-01 as challenge type, headscale must be reachable on port 443, i.e. listen_addr should probably end in :443")
|
||||
log.Warn().
|
||||
Msg("Warning: when using tls_letsencrypt_hostname with TLS-ALPN-01 as challenge type, headscale must be reachable on port 443, i.e. listen_addr should probably end in :443")
|
||||
}
|
||||
|
||||
if (viper.GetString("tls_letsencrypt_challenge_type") != "HTTP-01") && (viper.GetString("tls_letsencrypt_challenge_type") != "TLS-ALPN-01") {
|
||||
@@ -79,9 +85,13 @@ func absPath(path string) string {
|
||||
}
|
||||
|
||||
func getHeadscaleApp() (*headscale.Headscale, error) {
|
||||
derpMap, err := loadDerpMap(absPath(viper.GetString("derp_map_path")))
|
||||
derpPath := absPath(viper.GetString("derp_map_path"))
|
||||
derpMap, err := loadDerpMap(derpPath)
|
||||
if err != nil {
|
||||
log.Printf("Could not load DERP servers map file: %s", err)
|
||||
log.Error().
|
||||
Str("path", derpPath).
|
||||
Err(err).
|
||||
Msg("Could not load DERP servers map file")
|
||||
}
|
||||
|
||||
// Minimum inactivity time out is keepalive timeout (60s) plus a few seconds
|
||||
@@ -97,6 +107,7 @@ func getHeadscaleApp() (*headscale.Headscale, error) {
|
||||
Addr: viper.GetString("listen_addr"),
|
||||
PrivateKeyPath: absPath(viper.GetString("private_key_path")),
|
||||
DerpMap: derpMap,
|
||||
IPPrefix: netaddr.MustParseIPPrefix(viper.GetString("ip_prefix")),
|
||||
|
||||
EphemeralNodeInactivityTimeout: viper.GetDuration("ephemeral_node_inactivity_timeout"),
|
||||
|
||||
@@ -125,9 +136,13 @@ func getHeadscaleApp() (*headscale.Headscale, error) {
|
||||
// We are doing this here, as in the future could be cool to have it also hot-reload
|
||||
|
||||
if viper.GetString("acl_policy_path") != "" {
|
||||
err = h.LoadACLPolicy(absPath(viper.GetString("acl_policy_path")))
|
||||
aclPath := absPath(viper.GetString("acl_policy_path"))
|
||||
err = h.LoadACLPolicy(aclPath)
|
||||
if err != nil {
|
||||
log.Printf("Could not load the ACL policy: %s", err)
|
||||
log.Error().
|
||||
Str("path", aclPath).
|
||||
Err(err).
|
||||
Msg("Could not load the ACL policy")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -157,24 +172,24 @@ func JsonOutput(result interface{}, errResult error, outputFormat string) {
|
||||
if errResult != nil {
|
||||
j, err = json.MarshalIndent(ErrorOutput{errResult.Error()}, "", "\t")
|
||||
if err != nil {
|
||||
log.Fatalln(err)
|
||||
log.Fatal().Err(err)
|
||||
}
|
||||
} else {
|
||||
j, err = json.MarshalIndent(result, "", "\t")
|
||||
if err != nil {
|
||||
log.Fatalln(err)
|
||||
log.Fatal().Err(err)
|
||||
}
|
||||
}
|
||||
case "json-line":
|
||||
if errResult != nil {
|
||||
j, err = json.Marshal(ErrorOutput{errResult.Error()})
|
||||
if err != nil {
|
||||
log.Fatalln(err)
|
||||
log.Fatal().Err(err)
|
||||
}
|
||||
} else {
|
||||
j, err = json.Marshal(result)
|
||||
if err != nil {
|
||||
log.Fatalln(err)
|
||||
log.Fatal().Err(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@@ -1,15 +1,41 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"log"
|
||||
"os"
|
||||
"time"
|
||||
|
||||
"github.com/juanfont/headscale/cmd/headscale/cli"
|
||||
"github.com/rs/zerolog"
|
||||
"github.com/rs/zerolog/log"
|
||||
"github.com/spf13/viper"
|
||||
)
|
||||
|
||||
func main() {
|
||||
zerolog.TimeFieldFormat = zerolog.TimeFormatUnix
|
||||
log.Logger = log.Output(zerolog.ConsoleWriter{
|
||||
Out: os.Stdout,
|
||||
TimeFormat: time.RFC3339,
|
||||
})
|
||||
|
||||
err := cli.LoadConfig("")
|
||||
if err != nil {
|
||||
log.Fatalf(err.Error())
|
||||
log.Fatal().Err(err)
|
||||
}
|
||||
|
||||
logLevel := viper.GetString("log_level")
|
||||
switch logLevel {
|
||||
case "trace":
|
||||
zerolog.SetGlobalLevel(zerolog.TraceLevel)
|
||||
case "debug":
|
||||
zerolog.SetGlobalLevel(zerolog.DebugLevel)
|
||||
case "info":
|
||||
zerolog.SetGlobalLevel(zerolog.InfoLevel)
|
||||
case "warn":
|
||||
zerolog.SetGlobalLevel(zerolog.WarnLevel)
|
||||
case "error":
|
||||
zerolog.SetGlobalLevel(zerolog.ErrorLevel)
|
||||
default:
|
||||
zerolog.SetGlobalLevel(zerolog.DebugLevel)
|
||||
}
|
||||
|
||||
cli.Execute()
|
||||
|
2
db.go
2
db.go
@@ -79,6 +79,7 @@ func (h *Headscale) openDB() (*gorm.DB, error) {
|
||||
return db, nil
|
||||
}
|
||||
|
||||
// getValue returns the value for the given key in KV
|
||||
func (h *Headscale) getValue(key string) (string, error) {
|
||||
var row KV
|
||||
if result := h.db.First(&row, "key = ?", key); errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||
@@ -87,6 +88,7 @@ func (h *Headscale) getValue(key string) (string, error) {
|
||||
return row.Value, nil
|
||||
}
|
||||
|
||||
// setValue sets value for the given key in KV
|
||||
func (h *Headscale) setValue(key string, value string) error {
|
||||
kv := KV{
|
||||
Key: key,
|
||||
|
1
go.mod
1
go.mod
@@ -9,6 +9,7 @@ require (
|
||||
github.com/klauspost/compress v1.13.1
|
||||
github.com/lib/pq v1.10.2 // indirect
|
||||
github.com/mattn/go-sqlite3 v1.14.7 // indirect
|
||||
github.com/rs/zerolog v1.23.0 // indirect
|
||||
github.com/spf13/cobra v1.1.3
|
||||
github.com/spf13/viper v1.8.1
|
||||
github.com/tailscale/hujson v0.0.0-20200924210142-dde312d0d6a2
|
||||
|
2
go.sum
2
go.sum
@@ -683,6 +683,8 @@ github.com/rogpeppe/go-internal v1.6.2/go.mod h1:xXDCJY+GAPziupqXw64V24skbSoqbTE
|
||||
github.com/rs/xid v1.2.1/go.mod h1:+uKXf+4Djp6Md1KODXJxgGQPKngRmWyn10oCKFzNHOQ=
|
||||
github.com/rs/zerolog v1.13.0/go.mod h1:YbFCdg8HfsridGWAh22vktObvhZbQsZXe4/zB0OKkWU=
|
||||
github.com/rs/zerolog v1.15.0/go.mod h1:xYTKnLHcpfU2225ny5qZjxnj9NvkumZYjJHlAThCjNc=
|
||||
github.com/rs/zerolog v1.23.0 h1:UskrK+saS9P9Y789yNNulYKdARjPZuS35B8gJF2x60g=
|
||||
github.com/rs/zerolog v1.23.0/go.mod h1:6c7hFfxPOy7TacJc4Fcdi24/J0NKYGzjG8FWRI916Qo=
|
||||
github.com/russross/blackfriday/v2 v2.0.1/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM=
|
||||
github.com/ryancurrah/gomodguard v1.1.0/go.mod h1:4O8tr7hBODaGE6VIhfJDHcwzh5GUccKSJBU0UMXJFVM=
|
||||
github.com/ryanrolds/sqlclosecheck v0.3.0/go.mod h1:1gREqxyTGR3lVtpngyFo3hZAgk0KCtEdgEkHwDbigdA=
|
||||
|
@@ -65,7 +65,6 @@ tasks like creating namespaces, authkeys, etc.
|
||||
|
||||
headscale is an open source implementation of the Tailscale control server
|
||||
|
||||
Juan Font Alonso <juanfontalonso@gmail.com> - 2021
|
||||
https://gitlab.com/juanfont/headscale
|
||||
|
||||
Usage:
|
||||
|
12
machine.go
12
machine.go
@@ -3,11 +3,12 @@ package headscale
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log"
|
||||
"sort"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/rs/zerolog/log"
|
||||
|
||||
"gorm.io/datatypes"
|
||||
"inet.af/netaddr"
|
||||
"tailscale.com/tailcfg"
|
||||
@@ -157,7 +158,7 @@ func (h *Headscale) getPeers(m Machine) (*[]*tailcfg.Node, error) {
|
||||
machines := []Machine{}
|
||||
if err := h.db.Where("namespace_id = ? AND machine_key <> ? AND registered",
|
||||
m.NamespaceID, m.MachineKey).Find(&machines).Error; err != nil {
|
||||
log.Printf("Error accessing db: %s", err)
|
||||
log.Error().Err(err).Msg("Error accessing db")
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -200,19 +201,22 @@ func (h *Headscale) GetMachineByID(id uint64) (*Machine, error) {
|
||||
// DeleteMachine softs deletes a Machine from the database
|
||||
func (h *Headscale) DeleteMachine(m *Machine) error {
|
||||
m.Registered = false
|
||||
namespaceID := m.NamespaceID
|
||||
h.db.Save(&m) // we mark it as unregistered, just in case
|
||||
if err := h.db.Delete(&m).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
|
||||
return h.RequestMapUpdates(namespaceID)
|
||||
}
|
||||
|
||||
// HardDeleteMachine hard deletes a Machine from the database
|
||||
func (h *Headscale) HardDeleteMachine(m *Machine) error {
|
||||
namespaceID := m.NamespaceID
|
||||
if err := h.db.Unscoped().Delete(&m).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
return h.RequestMapUpdates(namespaceID)
|
||||
}
|
||||
|
||||
// GetHostInfo returns a Hostinfo struct for the machine
|
||||
|
@@ -1,6 +1,8 @@
|
||||
package headscale
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
|
||||
"gopkg.in/check.v1"
|
||||
)
|
||||
|
||||
@@ -81,6 +83,15 @@ func (s *Suite) TestDeleteMachine(c *check.C) {
|
||||
h.db.Save(&m)
|
||||
err = h.DeleteMachine(&m)
|
||||
c.Assert(err, check.IsNil)
|
||||
v, err := h.getValue("namespaces_pending_updates")
|
||||
c.Assert(err, check.IsNil)
|
||||
names := []string{}
|
||||
err = json.Unmarshal([]byte(v), &names)
|
||||
c.Assert(err, check.IsNil)
|
||||
c.Assert(names, check.DeepEquals, []string{n.Name})
|
||||
h.checkForNamespacesPendingUpdates()
|
||||
v, _ = h.getValue("namespaces_pending_updates")
|
||||
c.Assert(v, check.Equals, "")
|
||||
_, err = h.GetMachine(n.Name, "testmachine")
|
||||
c.Assert(err, check.NotNil)
|
||||
}
|
||||
|
107
namespaces.go
107
namespaces.go
@@ -1,10 +1,12 @@
|
||||
package headscale
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"log"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/rs/zerolog/log"
|
||||
"gorm.io/gorm"
|
||||
"tailscale.com/tailcfg"
|
||||
)
|
||||
@@ -31,7 +33,10 @@ func (h *Headscale) CreateNamespace(name string) (*Namespace, error) {
|
||||
}
|
||||
n.Name = name
|
||||
if err := h.db.Create(&n).Error; err != nil {
|
||||
log.Printf("Could not create row: %s", err)
|
||||
log.Error().
|
||||
Str("func", "CreateNamespace").
|
||||
Err(err).
|
||||
Msg("Could not create row")
|
||||
return nil, err
|
||||
}
|
||||
return &n, nil
|
||||
@@ -103,6 +108,104 @@ func (h *Headscale) SetMachineNamespace(m *Machine, namespaceName string) error
|
||||
return nil
|
||||
}
|
||||
|
||||
// RequestMapUpdates signals the KV worker to update the maps for this namespace
|
||||
func (h *Headscale) RequestMapUpdates(namespaceID uint) error {
|
||||
namespace := Namespace{}
|
||||
if err := h.db.First(&namespace, namespaceID).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
v, err := h.getValue("namespaces_pending_updates")
|
||||
if err != nil || v == "" {
|
||||
err = h.setValue("namespaces_pending_updates", fmt.Sprintf(`["%s"]`, namespace.Name))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
names := []string{}
|
||||
err = json.Unmarshal([]byte(v), &names)
|
||||
if err != nil {
|
||||
err = h.setValue("namespaces_pending_updates", fmt.Sprintf(`["%s"]`, namespace.Name))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
names = append(names, namespace.Name)
|
||||
data, err := json.Marshal(names)
|
||||
if err != nil {
|
||||
log.Error().
|
||||
Str("func", "RequestMapUpdates").
|
||||
Err(err).
|
||||
Msg("Could not marshal namespaces_pending_updates")
|
||||
return err
|
||||
}
|
||||
return h.setValue("namespaces_pending_updates", string(data))
|
||||
}
|
||||
|
||||
func (h *Headscale) checkForNamespacesPendingUpdates() {
|
||||
v, err := h.getValue("namespaces_pending_updates")
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
if v == "" {
|
||||
return
|
||||
}
|
||||
|
||||
names := []string{}
|
||||
err = json.Unmarshal([]byte(v), &names)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
for _, name := range names {
|
||||
log.Trace().
|
||||
Str("func", "RequestMapUpdates").
|
||||
Str("machine", name).
|
||||
Msg("Sending updates to nodes in namespace")
|
||||
machines, err := h.ListMachinesInNamespace(name)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
for _, m := range *machines {
|
||||
peers, _ := h.getPeers(m)
|
||||
for _, p := range *peers {
|
||||
pUp, ok := h.clientsPolling.Load(uint64(p.ID))
|
||||
if ok {
|
||||
log.Info().
|
||||
Str("func", "checkForNamespacesPendingUpdates").
|
||||
Str("machine", m.Name).
|
||||
Str("peer", m.Name).
|
||||
Str("address", p.Addresses[0].String()).
|
||||
Msgf("Notifying peer %s (%s)", p.Name, p.Addresses[0])
|
||||
pUp.(chan []byte) <- []byte{}
|
||||
} else {
|
||||
log.Info().
|
||||
Str("func", "checkForNamespacesPendingUpdates").
|
||||
Str("machine", m.Name).
|
||||
Str("peer", m.Name).
|
||||
Msgf("Peer %s does not appear to be polling", p.Name)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
newV, err := h.getValue("namespaces_pending_updates")
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
if v == newV { // only clear when no changes, so we notified everybody
|
||||
err = h.setValue("namespaces_pending_updates", "")
|
||||
if err != nil {
|
||||
log.Error().
|
||||
Str("func", "checkForNamespacesPendingUpdates").
|
||||
Err(err).
|
||||
Msg("Could not save to KV")
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (n *Namespace) toUser() *tailcfg.User {
|
||||
u := tailcfg.User{
|
||||
ID: tailcfg.UserID(n.ID),
|
||||
|
@@ -52,8 +52,8 @@ func (h *Headscale) EnableNodeRoute(namespace string, nodeName string, routeStr
|
||||
peers, _ := h.getPeers(*m)
|
||||
h.pollMu.Lock()
|
||||
for _, p := range *peers {
|
||||
if pUp, ok := h.clientsPolling[uint64(p.ID)]; ok {
|
||||
pUp <- []byte{}
|
||||
if pUp, ok := h.clientsPolling.Load(uint64(p.ID)); ok {
|
||||
pUp.(chan []byte) <- []byte{}
|
||||
}
|
||||
}
|
||||
h.pollMu.Unlock()
|
||||
|
98
utils.go
98
utils.go
@@ -7,18 +7,12 @@ package headscale
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"encoding/binary"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"time"
|
||||
|
||||
mathrand "math/rand"
|
||||
|
||||
"golang.org/x/crypto/nacl/box"
|
||||
"gorm.io/gorm"
|
||||
"inet.af/netaddr"
|
||||
"tailscale.com/types/wgkey"
|
||||
)
|
||||
|
||||
@@ -77,47 +71,71 @@ func encodeMsg(b []byte, pubKey *wgkey.Key, privKey *wgkey.Private) ([]byte, err
|
||||
return msg, nil
|
||||
}
|
||||
|
||||
func (h *Headscale) getAvailableIP() (*net.IP, error) {
|
||||
i := 0
|
||||
func (h *Headscale) getAvailableIP() (*netaddr.IP, error) {
|
||||
ipPrefix := h.cfg.IPPrefix
|
||||
|
||||
usedIps, err := h.getUsedIPs()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Get the first IP in our prefix
|
||||
ip := ipPrefix.IP()
|
||||
|
||||
for {
|
||||
ip, err := getRandomIP()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
if !ipPrefix.Contains(ip) {
|
||||
return nil, fmt.Errorf("could not find any suitable IP in %s", ipPrefix)
|
||||
}
|
||||
m := Machine{}
|
||||
if result := h.db.First(&m, "ip_address = ?", ip.String()); errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||
return ip, nil
|
||||
|
||||
// Some OS (including Linux) does not like when IPs ends with 0 or 255, which
|
||||
// is typically called network or broadcast. Lets avoid them and continue
|
||||
// to look when we get one of those traditionally reserved IPs.
|
||||
ipRaw := ip.As4()
|
||||
if ipRaw[3] == 0 || ipRaw[3] == 255 {
|
||||
ip = ip.Next()
|
||||
continue
|
||||
}
|
||||
i++
|
||||
if i == 100 { // really random number
|
||||
break
|
||||
|
||||
if ip.IsZero() &&
|
||||
ip.IsLoopback() {
|
||||
|
||||
ip = ip.Next()
|
||||
continue
|
||||
}
|
||||
|
||||
if !containsIPs(usedIps, ip) {
|
||||
return &ip, nil
|
||||
}
|
||||
|
||||
ip = ip.Next()
|
||||
}
|
||||
return nil, errors.New("Could not find an available IP address in 100.64.0.0/10")
|
||||
}
|
||||
|
||||
func getRandomIP() (*net.IP, error) {
|
||||
mathrand.Seed(time.Now().Unix())
|
||||
ipo, ipnet, err := net.ParseCIDR("100.64.0.0/10")
|
||||
if err == nil {
|
||||
ip := ipo.To4()
|
||||
// fmt.Println("In Randomize IPAddr: IP ", ip, " IPNET: ", ipnet)
|
||||
// fmt.Println("Final address is ", ip)
|
||||
// fmt.Println("Broadcast address is ", ipb)
|
||||
// fmt.Println("Network address is ", ipn)
|
||||
r := mathrand.Uint32()
|
||||
ipRaw := make([]byte, 4)
|
||||
binary.LittleEndian.PutUint32(ipRaw, r)
|
||||
// ipRaw[3] = 254
|
||||
// fmt.Println("ipRaw is ", ipRaw)
|
||||
for i, v := range ipRaw {
|
||||
// fmt.Println("IP Before: ", ip[i], " v is ", v, " Mask is: ", ipnet.Mask[i])
|
||||
ip[i] = ip[i] + (v &^ ipnet.Mask[i])
|
||||
// fmt.Println("IP After: ", ip[i])
|
||||
func (h *Headscale) getUsedIPs() ([]netaddr.IP, error) {
|
||||
var addresses []string
|
||||
h.db.Model(&Machine{}).Pluck("ip_address", &addresses)
|
||||
|
||||
ips := make([]netaddr.IP, len(addresses))
|
||||
for index, addr := range addresses {
|
||||
if addr != "" {
|
||||
ip, err := netaddr.ParseIP(addr)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to parse ip from database, %w", err)
|
||||
}
|
||||
|
||||
ips[index] = ip
|
||||
}
|
||||
// fmt.Println("FINAL IP: ", ip.String())
|
||||
return &ip, nil
|
||||
}
|
||||
|
||||
return nil, err
|
||||
return ips, nil
|
||||
}
|
||||
|
||||
func containsIPs(ips []netaddr.IP, ip netaddr.IP) bool {
|
||||
for _, v := range ips {
|
||||
if v == ip {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
155
utils_test.go
Normal file
155
utils_test.go
Normal file
@@ -0,0 +1,155 @@
|
||||
package headscale
|
||||
|
||||
import (
|
||||
"gopkg.in/check.v1"
|
||||
"inet.af/netaddr"
|
||||
)
|
||||
|
||||
func (s *Suite) TestGetAvailableIp(c *check.C) {
|
||||
ip, err := h.getAvailableIP()
|
||||
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
expected := netaddr.MustParseIP("10.27.0.1")
|
||||
|
||||
c.Assert(ip.String(), check.Equals, expected.String())
|
||||
}
|
||||
|
||||
func (s *Suite) TestGetUsedIps(c *check.C) {
|
||||
ip, err := h.getAvailableIP()
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
n, err := h.CreateNamespace("test_ip")
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
pak, err := h.CreatePreAuthKey(n.Name, false, false, nil)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
_, err = h.GetMachine("test", "testmachine")
|
||||
c.Assert(err, check.NotNil)
|
||||
|
||||
m := Machine{
|
||||
ID: 0,
|
||||
MachineKey: "foo",
|
||||
NodeKey: "bar",
|
||||
DiscoKey: "faa",
|
||||
Name: "testmachine",
|
||||
NamespaceID: n.ID,
|
||||
Registered: true,
|
||||
RegisterMethod: "authKey",
|
||||
AuthKeyID: uint(pak.ID),
|
||||
IPAddress: ip.String(),
|
||||
}
|
||||
h.db.Save(&m)
|
||||
|
||||
ips, err := h.getUsedIPs()
|
||||
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
expected := netaddr.MustParseIP("10.27.0.1")
|
||||
|
||||
c.Assert(ips[0], check.Equals, expected)
|
||||
|
||||
m1, err := h.GetMachineByID(0)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
c.Assert(m1.IPAddress, check.Equals, expected.String())
|
||||
}
|
||||
|
||||
func (s *Suite) TestGetMultiIp(c *check.C) {
|
||||
n, err := h.CreateNamespace("test-ip-multi")
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
for i := 1; i <= 350; i++ {
|
||||
ip, err := h.getAvailableIP()
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
pak, err := h.CreatePreAuthKey(n.Name, false, false, nil)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
_, err = h.GetMachine("test", "testmachine")
|
||||
c.Assert(err, check.NotNil)
|
||||
|
||||
m := Machine{
|
||||
ID: uint64(i),
|
||||
MachineKey: "foo",
|
||||
NodeKey: "bar",
|
||||
DiscoKey: "faa",
|
||||
Name: "testmachine",
|
||||
NamespaceID: n.ID,
|
||||
Registered: true,
|
||||
RegisterMethod: "authKey",
|
||||
AuthKeyID: uint(pak.ID),
|
||||
IPAddress: ip.String(),
|
||||
}
|
||||
h.db.Save(&m)
|
||||
}
|
||||
|
||||
ips, err := h.getUsedIPs()
|
||||
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
c.Assert(len(ips), check.Equals, 350)
|
||||
|
||||
c.Assert(ips[0], check.Equals, netaddr.MustParseIP("10.27.0.1"))
|
||||
c.Assert(ips[9], check.Equals, netaddr.MustParseIP("10.27.0.10"))
|
||||
c.Assert(ips[300], check.Equals, netaddr.MustParseIP("10.27.1.47"))
|
||||
|
||||
// Check that we can read back the IPs
|
||||
m1, err := h.GetMachineByID(1)
|
||||
c.Assert(err, check.IsNil)
|
||||
c.Assert(m1.IPAddress, check.Equals, netaddr.MustParseIP("10.27.0.1").String())
|
||||
|
||||
m50, err := h.GetMachineByID(50)
|
||||
c.Assert(err, check.IsNil)
|
||||
c.Assert(m50.IPAddress, check.Equals, netaddr.MustParseIP("10.27.0.50").String())
|
||||
|
||||
expectedNextIP := netaddr.MustParseIP("10.27.1.97")
|
||||
nextIP, err := h.getAvailableIP()
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
c.Assert(nextIP.String(), check.Equals, expectedNextIP.String())
|
||||
|
||||
// If we call get Available again, we should receive
|
||||
// the same IP, as it has not been reserved.
|
||||
nextIP2, err := h.getAvailableIP()
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
c.Assert(nextIP2.String(), check.Equals, expectedNextIP.String())
|
||||
}
|
||||
|
||||
func (s *Suite) TestGetAvailableIpMachineWithoutIP(c *check.C) {
|
||||
ip, err := h.getAvailableIP()
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
expected := netaddr.MustParseIP("10.27.0.1")
|
||||
|
||||
c.Assert(ip.String(), check.Equals, expected.String())
|
||||
|
||||
n, err := h.CreateNamespace("test_ip")
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
pak, err := h.CreatePreAuthKey(n.Name, false, false, nil)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
_, err = h.GetMachine("test", "testmachine")
|
||||
c.Assert(err, check.NotNil)
|
||||
|
||||
m := Machine{
|
||||
ID: 0,
|
||||
MachineKey: "foo",
|
||||
NodeKey: "bar",
|
||||
DiscoKey: "faa",
|
||||
Name: "testmachine",
|
||||
NamespaceID: n.ID,
|
||||
Registered: true,
|
||||
RegisterMethod: "authKey",
|
||||
AuthKeyID: uint(pak.ID),
|
||||
}
|
||||
h.db.Save(&m)
|
||||
|
||||
ip2, err := h.getAvailableIP()
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
c.Assert(ip2.String(), check.Equals, expected.String())
|
||||
}
|
Reference in New Issue
Block a user