mirror of
https://github.com/juanfont/headscale.git
synced 2024-11-30 13:35:23 +00:00
Merge branch 'main' into topic/docker-release
This commit is contained in:
commit
1e93347a26
41
.github/workflows/lint.yml
vendored
Normal file
41
.github/workflows/lint.yml
vendored
Normal file
@ -0,0 +1,41 @@
|
|||||||
|
name: CI
|
||||||
|
|
||||||
|
on: [push, pull_request]
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
# The "build" workflow
|
||||||
|
lint:
|
||||||
|
# The type of runner that the job will run on
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
|
||||||
|
# Steps represent a sequence of tasks that will be executed as part of the job
|
||||||
|
steps:
|
||||||
|
# Checks-out your repository under $GITHUB_WORKSPACE, so your job can access it
|
||||||
|
- uses: actions/checkout@v2
|
||||||
|
|
||||||
|
# Install and run golangci-lint as a separate step, it's much faster this
|
||||||
|
# way because this action has caching. It'll get run again in `make lint`
|
||||||
|
# below, but it's still much faster in the end than installing
|
||||||
|
# golangci-lint manually in the `Run lint` step.
|
||||||
|
- uses: golangci/golangci-lint-action@v2
|
||||||
|
with:
|
||||||
|
args: --timeout 4m
|
||||||
|
|
||||||
|
# Setup Go
|
||||||
|
- name: Setup Go
|
||||||
|
uses: actions/setup-go@v2
|
||||||
|
with:
|
||||||
|
go-version: "1.16.3" # The Go version to download (if necessary) and use.
|
||||||
|
|
||||||
|
# Install all the dependencies
|
||||||
|
- name: Install dependencies
|
||||||
|
run: |
|
||||||
|
go version
|
||||||
|
go install golang.org/x/lint/golint@latest
|
||||||
|
sudo apt update
|
||||||
|
sudo apt install -y make
|
||||||
|
|
||||||
|
- name: Run lint
|
||||||
|
with:
|
||||||
|
args: --timeout 4m
|
||||||
|
run: make lint
|
23
.github/workflows/test-integration.yml
vendored
Normal file
23
.github/workflows/test-integration.yml
vendored
Normal file
@ -0,0 +1,23 @@
|
|||||||
|
name: CI
|
||||||
|
|
||||||
|
on: [pull_request]
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
# The "build" workflow
|
||||||
|
integration-test:
|
||||||
|
# The type of runner that the job will run on
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
|
||||||
|
# Steps represent a sequence of tasks that will be executed as part of the job
|
||||||
|
steps:
|
||||||
|
# Checks-out your repository under $GITHUB_WORKSPACE, so your job can access it
|
||||||
|
- uses: actions/checkout@v2
|
||||||
|
|
||||||
|
# Setup Go
|
||||||
|
- name: Setup Go
|
||||||
|
uses: actions/setup-go@v2
|
||||||
|
with:
|
||||||
|
go-version: "1.16.3"
|
||||||
|
|
||||||
|
- name: Run Integration tests
|
||||||
|
run: go test -tags integration -timeout 30m
|
14
.github/workflows/test.yml
vendored
14
.github/workflows/test.yml
vendored
@ -13,33 +13,21 @@ jobs:
|
|||||||
# Checks-out your repository under $GITHUB_WORKSPACE, so your job can access it
|
# Checks-out your repository under $GITHUB_WORKSPACE, so your job can access it
|
||||||
- uses: actions/checkout@v2
|
- uses: actions/checkout@v2
|
||||||
|
|
||||||
# Install and run golangci-lint as a separate step, it's much faster this
|
|
||||||
# way because this action has caching. It'll get run again in `make lint`
|
|
||||||
# below, but it's still much faster in the end than installing
|
|
||||||
# golangci-lint manually in the `Run lint` step.
|
|
||||||
- uses: golangci/golangci-lint-action@v2
|
|
||||||
with:
|
|
||||||
args: --timeout 2m
|
|
||||||
|
|
||||||
# Setup Go
|
# Setup Go
|
||||||
- name: Setup Go
|
- name: Setup Go
|
||||||
uses: actions/setup-go@v2
|
uses: actions/setup-go@v2
|
||||||
with:
|
with:
|
||||||
go-version: '1.16.3' # The Go version to download (if necessary) and use.
|
go-version: "1.16.3" # The Go version to download (if necessary) and use.
|
||||||
|
|
||||||
# Install all the dependencies
|
# Install all the dependencies
|
||||||
- name: Install dependencies
|
- name: Install dependencies
|
||||||
run: |
|
run: |
|
||||||
go version
|
go version
|
||||||
go install golang.org/x/lint/golint@latest
|
|
||||||
sudo apt update
|
sudo apt update
|
||||||
sudo apt install -y make
|
sudo apt install -y make
|
||||||
|
|
||||||
- name: Run tests
|
- name: Run tests
|
||||||
run: make test
|
run: make test
|
||||||
|
|
||||||
- name: Run lint
|
|
||||||
run: make lint
|
|
||||||
|
|
||||||
- name: Run build
|
- name: Run build
|
||||||
run: make
|
run: make
|
2
.gitignore
vendored
2
.gitignore
vendored
@ -22,3 +22,5 @@ config.json
|
|||||||
|
|
||||||
# Exclude Jetbrains Editors
|
# Exclude Jetbrains Editors
|
||||||
.idea
|
.idea
|
||||||
|
|
||||||
|
test_output/
|
||||||
|
@ -10,7 +10,7 @@ COPY . /go/src/headscale
|
|||||||
RUN go install -a -ldflags="-extldflags=-static" -tags netgo,sqlite_omit_load_extension ./cmd/headscale
|
RUN go install -a -ldflags="-extldflags=-static" -tags netgo,sqlite_omit_load_extension ./cmd/headscale
|
||||||
RUN test -e /go/bin/headscale
|
RUN test -e /go/bin/headscale
|
||||||
|
|
||||||
FROM ubuntu:latest
|
FROM ubuntu:20.04
|
||||||
|
|
||||||
COPY --from=build /go/bin/headscale /usr/local/bin/headscale
|
COPY --from=build /go/bin/headscale /usr/local/bin/headscale
|
||||||
ENV TZ UTC
|
ENV TZ UTC
|
||||||
|
2
Makefile
2
Makefile
@ -2,7 +2,7 @@
|
|||||||
version = $(shell ./scripts/version-at-commit.sh)
|
version = $(shell ./scripts/version-at-commit.sh)
|
||||||
|
|
||||||
build:
|
build:
|
||||||
go build -ldflags "-s -w -X main.version=$(version)" cmd/headscale/headscale.go
|
go build -ldflags "-s -w -X github.com/juanfont/headscale/cmd/headscale/cli.version=$(version)" cmd/headscale/headscale.go
|
||||||
|
|
||||||
dev: lint test build
|
dev: lint test build
|
||||||
|
|
||||||
|
@ -25,14 +25,13 @@ Headscale implements this coordination server.
|
|||||||
- [X] JSON-formatted output
|
- [X] JSON-formatted output
|
||||||
- [X] ACLs
|
- [X] ACLs
|
||||||
- [X] Support for alternative IP ranges in the tailnets (default Tailscale's 100.64.0.0/10)
|
- [X] Support for alternative IP ranges in the tailnets (default Tailscale's 100.64.0.0/10)
|
||||||
- [ ] Share nodes between ~~users~~ namespaces
|
- [X] DNS (passing DNS servers to nodes)
|
||||||
- [ ] DNS
|
- [X] Share nodes between ~~users~~ namespaces
|
||||||
|
- [ ] MagicDNS / Smart DNS
|
||||||
|
|
||||||
|
|
||||||
## Roadmap 🤷
|
## Roadmap 🤷
|
||||||
|
|
||||||
We are now focusing on adding integration tests with the official clients.
|
|
||||||
|
|
||||||
Suggestions/PRs welcomed!
|
Suggestions/PRs welcomed!
|
||||||
|
|
||||||
|
|
||||||
|
208
api.go
208
api.go
@ -13,9 +13,7 @@ import (
|
|||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"github.com/klauspost/compress/zstd"
|
"github.com/klauspost/compress/zstd"
|
||||||
"gorm.io/datatypes"
|
|
||||||
"gorm.io/gorm"
|
"gorm.io/gorm"
|
||||||
"inet.af/netaddr"
|
|
||||||
"tailscale.com/tailcfg"
|
"tailscale.com/tailcfg"
|
||||||
"tailscale.com/types/wgkey"
|
"tailscale.com/types/wgkey"
|
||||||
)
|
)
|
||||||
@ -35,8 +33,6 @@ func (h *Headscale) RegisterWebAPI(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// spew.Dump(c.Params)
|
|
||||||
|
|
||||||
c.Data(http.StatusOK, "text/html; charset=utf-8", []byte(fmt.Sprintf(`
|
c.Data(http.StatusOK, "text/html; charset=utf-8", []byte(fmt.Sprintf(`
|
||||||
<html>
|
<html>
|
||||||
<body>
|
<body>
|
||||||
@ -82,6 +78,7 @@ func (h *Headscale) RegistrationHandler(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
now := time.Now().UTC()
|
||||||
var m Machine
|
var m Machine
|
||||||
if result := h.db.Preload("Namespace").First(&m, "machine_key = ?", mKey.HexString()); errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
if result := h.db.Preload("Namespace").First(&m, "machine_key = ?", mKey.HexString()); errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||||
log.Info().Str("machine", req.Hostinfo.Hostname).Msg("New machine")
|
log.Info().Str("machine", req.Hostinfo.Hostname).Msg("New machine")
|
||||||
@ -90,6 +87,7 @@ func (h *Headscale) RegistrationHandler(c *gin.Context) {
|
|||||||
MachineKey: mKey.HexString(),
|
MachineKey: mKey.HexString(),
|
||||||
Name: req.Hostinfo.Hostname,
|
Name: req.Hostinfo.Hostname,
|
||||||
NodeKey: wgkey.Key(req.NodeKey).HexString(),
|
NodeKey: wgkey.Key(req.NodeKey).HexString(),
|
||||||
|
LastSuccessfulUpdate: &now,
|
||||||
}
|
}
|
||||||
if err := h.db.Create(&m).Error; err != nil {
|
if err := h.db.Create(&m).Error; err != nil {
|
||||||
log.Error().
|
log.Error().
|
||||||
@ -215,202 +213,12 @@ func (h *Headscale) RegistrationHandler(c *gin.Context) {
|
|||||||
c.Data(200, "application/json; charset=utf-8", respBody)
|
c.Data(200, "application/json; charset=utf-8", respBody)
|
||||||
}
|
}
|
||||||
|
|
||||||
// PollNetMapHandler takes care of /machine/:id/map
|
|
||||||
//
|
|
||||||
// This is the busiest endpoint, as it keeps the HTTP long poll that updates
|
|
||||||
// the clients when something in the network changes.
|
|
||||||
//
|
|
||||||
// The clients POST stuff like HostInfo and their Endpoints here, but
|
|
||||||
// only after their first request (marked with the ReadOnly field).
|
|
||||||
//
|
|
||||||
// At this moment the updates are sent in a quite horrendous way, but they kinda work.
|
|
||||||
func (h *Headscale) PollNetMapHandler(c *gin.Context) {
|
|
||||||
log.Trace().
|
|
||||||
Str("handler", "PollNetMap").
|
|
||||||
Str("id", c.Param("id")).
|
|
||||||
Msg("PollNetMapHandler called")
|
|
||||||
body, _ := io.ReadAll(c.Request.Body)
|
|
||||||
mKeyStr := c.Param("id")
|
|
||||||
mKey, err := wgkey.ParseHex(mKeyStr)
|
|
||||||
if err != nil {
|
|
||||||
log.Error().
|
|
||||||
Str("handler", "PollNetMap").
|
|
||||||
Err(err).
|
|
||||||
Msg("Cannot parse client key")
|
|
||||||
c.String(http.StatusBadRequest, "")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
req := tailcfg.MapRequest{}
|
|
||||||
err = decode(body, &req, &mKey, h.privateKey)
|
|
||||||
if err != nil {
|
|
||||||
log.Error().
|
|
||||||
Str("handler", "PollNetMap").
|
|
||||||
Err(err).
|
|
||||||
Msg("Cannot decode message")
|
|
||||||
c.String(http.StatusBadRequest, "")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
var m Machine
|
|
||||||
if result := h.db.Preload("Namespace").First(&m, "machine_key = ?", mKey.HexString()); errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
|
||||||
log.Warn().
|
|
||||||
Str("handler", "PollNetMap").
|
|
||||||
Msgf("Ignoring request, cannot find machine with key %s", mKey.HexString())
|
|
||||||
c.String(http.StatusUnauthorized, "")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
log.Trace().
|
|
||||||
Str("handler", "PollNetMap").
|
|
||||||
Str("id", c.Param("id")).
|
|
||||||
Str("machine", m.Name).
|
|
||||||
Msg("Found machine in database")
|
|
||||||
|
|
||||||
hostinfo, _ := json.Marshal(req.Hostinfo)
|
|
||||||
m.Name = req.Hostinfo.Hostname
|
|
||||||
m.HostInfo = datatypes.JSON(hostinfo)
|
|
||||||
m.DiscoKey = wgkey.Key(req.DiscoKey).HexString()
|
|
||||||
now := time.Now().UTC()
|
|
||||||
|
|
||||||
// From Tailscale client:
|
|
||||||
//
|
|
||||||
// ReadOnly is whether the client just wants to fetch the MapResponse,
|
|
||||||
// without updating their Endpoints. The Endpoints field will be ignored and
|
|
||||||
// LastSeen will not be updated and peers will not be notified of changes.
|
|
||||||
//
|
|
||||||
// The intended use is for clients to discover the DERP map at start-up
|
|
||||||
// before their first real endpoint update.
|
|
||||||
if !req.ReadOnly {
|
|
||||||
endpoints, _ := json.Marshal(req.Endpoints)
|
|
||||||
m.Endpoints = datatypes.JSON(endpoints)
|
|
||||||
m.LastSeen = &now
|
|
||||||
}
|
|
||||||
h.db.Save(&m)
|
|
||||||
|
|
||||||
data, err := h.getMapResponse(mKey, req, m)
|
|
||||||
if err != nil {
|
|
||||||
log.Error().
|
|
||||||
Str("handler", "PollNetMap").
|
|
||||||
Str("id", c.Param("id")).
|
|
||||||
Str("machine", m.Name).
|
|
||||||
Err(err).
|
|
||||||
Msg("Failed to get Map response")
|
|
||||||
c.String(http.StatusInternalServerError, ":(")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// We update our peers if the client is not sending ReadOnly in the MapRequest
|
|
||||||
// so we don't distribute its initial request (it comes with
|
|
||||||
// empty endpoints to peers)
|
|
||||||
|
|
||||||
// Details on the protocol can be found in https://github.com/tailscale/tailscale/blob/main/tailcfg/tailcfg.go#L696
|
|
||||||
log.Debug().
|
|
||||||
Str("handler", "PollNetMap").
|
|
||||||
Str("id", c.Param("id")).
|
|
||||||
Str("machine", m.Name).
|
|
||||||
Bool("readOnly", req.ReadOnly).
|
|
||||||
Bool("omitPeers", req.OmitPeers).
|
|
||||||
Bool("stream", req.Stream).
|
|
||||||
Msg("Client map request processed")
|
|
||||||
|
|
||||||
if req.ReadOnly {
|
|
||||||
log.Info().
|
|
||||||
Str("handler", "PollNetMap").
|
|
||||||
Str("machine", m.Name).
|
|
||||||
Msg("Client is starting up. Asking for DERP map")
|
|
||||||
c.Data(200, "application/json; charset=utf-8", *data)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if req.OmitPeers && !req.Stream {
|
|
||||||
log.Info().
|
|
||||||
Str("handler", "PollNetMap").
|
|
||||||
Str("machine", m.Name).
|
|
||||||
Msg("Client sent endpoint update and is ok with a response without peer list")
|
|
||||||
c.Data(200, "application/json; charset=utf-8", *data)
|
|
||||||
return
|
|
||||||
} else if req.OmitPeers && req.Stream {
|
|
||||||
log.Warn().
|
|
||||||
Str("handler", "PollNetMap").
|
|
||||||
Str("machine", m.Name).
|
|
||||||
Msg("Ignoring request, don't know how to handle it")
|
|
||||||
c.String(http.StatusBadRequest, "")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// Only create update channel if it has not been created
|
|
||||||
var update chan []byte
|
|
||||||
log.Trace().
|
|
||||||
Str("handler", "PollNetMap").
|
|
||||||
Str("id", c.Param("id")).
|
|
||||||
Str("machine", m.Name).
|
|
||||||
Msg("Creating or loading update channel")
|
|
||||||
if result, ok := h.clientsPolling.LoadOrStore(m.ID, make(chan []byte, 1)); ok {
|
|
||||||
update = result.(chan []byte)
|
|
||||||
}
|
|
||||||
|
|
||||||
pollData := make(chan []byte, 1)
|
|
||||||
defer close(pollData)
|
|
||||||
|
|
||||||
cancelKeepAlive := make(chan []byte, 1)
|
|
||||||
defer close(cancelKeepAlive)
|
|
||||||
|
|
||||||
log.Info().
|
|
||||||
Str("handler", "PollNetMap").
|
|
||||||
Str("machine", m.Name).
|
|
||||||
Msg("Client is ready to access the tailnet")
|
|
||||||
log.Info().
|
|
||||||
Str("handler", "PollNetMap").
|
|
||||||
Str("machine", m.Name).
|
|
||||||
Msg("Sending initial map")
|
|
||||||
pollData <- *data
|
|
||||||
|
|
||||||
log.Info().
|
|
||||||
Str("handler", "PollNetMap").
|
|
||||||
Str("machine", m.Name).
|
|
||||||
Msg("Notifying peers")
|
|
||||||
// TODO: Why does this block?
|
|
||||||
go h.notifyChangesToPeers(&m)
|
|
||||||
|
|
||||||
h.PollNetMapStream(c, m, req, mKey, pollData, update, cancelKeepAlive)
|
|
||||||
log.Trace().
|
|
||||||
Str("handler", "PollNetMap").
|
|
||||||
Str("id", c.Param("id")).
|
|
||||||
Str("machine", m.Name).
|
|
||||||
Msg("Finished stream, closing PollNetMap session")
|
|
||||||
}
|
|
||||||
|
|
||||||
func (h *Headscale) keepAlive(cancel chan []byte, pollData chan []byte, mKey wgkey.Key, req tailcfg.MapRequest, m Machine) {
|
|
||||||
for {
|
|
||||||
select {
|
|
||||||
case <-cancel:
|
|
||||||
return
|
|
||||||
|
|
||||||
default:
|
|
||||||
data, err := h.getMapKeepAliveResponse(mKey, req, m)
|
|
||||||
if err != nil {
|
|
||||||
log.Error().
|
|
||||||
Str("func", "keepAlive").
|
|
||||||
Err(err).
|
|
||||||
Msg("Error generating the keep alive msg")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
log.Debug().
|
|
||||||
Str("func", "keepAlive").
|
|
||||||
Str("machine", m.Name).
|
|
||||||
Msg("Sending keepalive")
|
|
||||||
pollData <- *data
|
|
||||||
|
|
||||||
time.Sleep(60 * time.Second)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (h *Headscale) getMapResponse(mKey wgkey.Key, req tailcfg.MapRequest, m Machine) (*[]byte, error) {
|
func (h *Headscale) getMapResponse(mKey wgkey.Key, req tailcfg.MapRequest, m Machine) (*[]byte, error) {
|
||||||
log.Trace().
|
log.Trace().
|
||||||
Str("func", "getMapResponse").
|
Str("func", "getMapResponse").
|
||||||
Str("machine", req.Hostinfo.Hostname).
|
Str("machine", req.Hostinfo.Hostname).
|
||||||
Msg("Creating Map response")
|
Msg("Creating Map response")
|
||||||
node, err := m.toNode()
|
node, err := m.toNode(true)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error().
|
log.Error().
|
||||||
Str("func", "getMapResponse").
|
Str("func", "getMapResponse").
|
||||||
@ -437,7 +245,12 @@ func (h *Headscale) getMapResponse(mKey wgkey.Key, req tailcfg.MapRequest, m Mac
|
|||||||
KeepAlive: false,
|
KeepAlive: false,
|
||||||
Node: node,
|
Node: node,
|
||||||
Peers: *peers,
|
Peers: *peers,
|
||||||
DNS: []netaddr.IP{},
|
//TODO(kradalby): As per tailscale docs, if DNSConfig is nil,
|
||||||
|
// it means its not updated, maybe we can have some logic
|
||||||
|
// to check and only pass updates when its updates.
|
||||||
|
// This is probably more relevant if we try to implement
|
||||||
|
// "MagicDNS"
|
||||||
|
DNSConfig: h.cfg.DNSConfig,
|
||||||
SearchPaths: []string{},
|
SearchPaths: []string{},
|
||||||
Domain: "headscale.net",
|
Domain: "headscale.net",
|
||||||
PacketFilter: *h.aclRules,
|
PacketFilter: *h.aclRules,
|
||||||
@ -465,7 +278,6 @@ func (h *Headscale) getMapResponse(mKey wgkey.Key, req tailcfg.MapRequest, m Mac
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
// spew.Dump(resp)
|
|
||||||
// declare the incoming size on the first 4 bytes
|
// declare the incoming size on the first 4 bytes
|
||||||
data := make([]byte, 4)
|
data := make([]byte, 4)
|
||||||
binary.LittleEndian.PutUint32(data, uint32(len(respBody)))
|
binary.LittleEndian.PutUint32(data, uint32(len(respBody)))
|
||||||
@ -542,7 +354,7 @@ func (h *Headscale) handleAuthKey(c *gin.Context, db *gorm.DB, idKey wgkey.Key,
|
|||||||
Str("func", "handleAuthKey").
|
Str("func", "handleAuthKey").
|
||||||
Str("machine", m.Name).
|
Str("machine", m.Name).
|
||||||
Str("ip", ip.String()).
|
Str("ip", ip.String()).
|
||||||
Msgf("Assining %s to %s", ip, m.Name)
|
Msgf("Assigning %s to %s", ip, m.Name)
|
||||||
|
|
||||||
m.AuthKeyID = uint(pak.ID)
|
m.AuthKeyID = uint(pak.ID)
|
||||||
m.IPAddress = ip.String()
|
m.IPAddress = ip.String()
|
||||||
|
39
app.go
39
app.go
@ -43,6 +43,8 @@ type Config struct {
|
|||||||
|
|
||||||
TLSCertPath string
|
TLSCertPath string
|
||||||
TLSKeyPath string
|
TLSKeyPath string
|
||||||
|
|
||||||
|
DNSConfig *tailcfg.DNSConfig
|
||||||
}
|
}
|
||||||
|
|
||||||
// Headscale represents the base app of the service
|
// Headscale represents the base app of the service
|
||||||
@ -58,7 +60,10 @@ type Headscale struct {
|
|||||||
aclPolicy *ACLPolicy
|
aclPolicy *ACLPolicy
|
||||||
aclRules *[]tailcfg.FilterRule
|
aclRules *[]tailcfg.FilterRule
|
||||||
|
|
||||||
clientsPolling sync.Map
|
clientsUpdateChannels sync.Map
|
||||||
|
clientsUpdateChannelMutex sync.Mutex
|
||||||
|
|
||||||
|
lastStateChange sync.Map
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewHeadscale returns the Headscale app
|
// NewHeadscale returns the Headscale app
|
||||||
@ -165,9 +170,18 @@ func (h *Headscale) Serve() error {
|
|||||||
r.POST("/machine/:id", h.RegistrationHandler)
|
r.POST("/machine/:id", h.RegistrationHandler)
|
||||||
var err error
|
var err error
|
||||||
|
|
||||||
|
timeout := 30 * time.Second
|
||||||
|
|
||||||
go h.watchForKVUpdates(5000)
|
go h.watchForKVUpdates(5000)
|
||||||
go h.expireEphemeralNodes(5000)
|
go h.expireEphemeralNodes(5000)
|
||||||
|
|
||||||
|
s := &http.Server{
|
||||||
|
Addr: h.cfg.Addr,
|
||||||
|
Handler: r,
|
||||||
|
ReadTimeout: timeout,
|
||||||
|
WriteTimeout: timeout,
|
||||||
|
}
|
||||||
|
|
||||||
if h.cfg.TLSLetsEncryptHostname != "" {
|
if h.cfg.TLSLetsEncryptHostname != "" {
|
||||||
if !strings.HasPrefix(h.cfg.ServerURL, "https://") {
|
if !strings.HasPrefix(h.cfg.ServerURL, "https://") {
|
||||||
log.Warn().Msg("Listening with TLS but ServerURL does not start with https://")
|
log.Warn().Msg("Listening with TLS but ServerURL does not start with https://")
|
||||||
@ -182,6 +196,8 @@ func (h *Headscale) Serve() error {
|
|||||||
Addr: h.cfg.Addr,
|
Addr: h.cfg.Addr,
|
||||||
TLSConfig: m.TLSConfig(),
|
TLSConfig: m.TLSConfig(),
|
||||||
Handler: r,
|
Handler: r,
|
||||||
|
ReadTimeout: timeout,
|
||||||
|
WriteTimeout: timeout,
|
||||||
}
|
}
|
||||||
if h.cfg.TLSLetsEncryptChallengeType == "TLS-ALPN-01" {
|
if h.cfg.TLSLetsEncryptChallengeType == "TLS-ALPN-01" {
|
||||||
// Configuration via autocert with TLS-ALPN-01 (https://tools.ietf.org/html/rfc8737)
|
// Configuration via autocert with TLS-ALPN-01 (https://tools.ietf.org/html/rfc8737)
|
||||||
@ -206,12 +222,29 @@ func (h *Headscale) Serve() error {
|
|||||||
if !strings.HasPrefix(h.cfg.ServerURL, "http://") {
|
if !strings.HasPrefix(h.cfg.ServerURL, "http://") {
|
||||||
log.Warn().Msg("Listening without TLS but ServerURL does not start with http://")
|
log.Warn().Msg("Listening without TLS but ServerURL does not start with http://")
|
||||||
}
|
}
|
||||||
err = r.Run(h.cfg.Addr)
|
err = s.ListenAndServe()
|
||||||
} else {
|
} else {
|
||||||
if !strings.HasPrefix(h.cfg.ServerURL, "https://") {
|
if !strings.HasPrefix(h.cfg.ServerURL, "https://") {
|
||||||
log.Warn().Msg("Listening with TLS but ServerURL does not start with https://")
|
log.Warn().Msg("Listening with TLS but ServerURL does not start with https://")
|
||||||
}
|
}
|
||||||
err = r.RunTLS(h.cfg.Addr, h.cfg.TLSCertPath, h.cfg.TLSKeyPath)
|
err = s.ListenAndServeTLS(h.cfg.TLSCertPath, h.cfg.TLSKeyPath)
|
||||||
}
|
}
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (h *Headscale) setLastStateChangeToNow(namespace string) {
|
||||||
|
now := time.Now().UTC()
|
||||||
|
h.lastStateChange.Store(namespace, now)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *Headscale) getLastStateChange(namespace string) time.Time {
|
||||||
|
if wrapped, ok := h.lastStateChange.Load(namespace); ok {
|
||||||
|
lastChange, _ := wrapped.(time.Time)
|
||||||
|
return lastChange
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
now := time.Now().UTC()
|
||||||
|
h.lastStateChange.Store(namespace, now)
|
||||||
|
return now
|
||||||
|
}
|
||||||
|
@ -25,6 +25,7 @@ func init() {
|
|||||||
nodeCmd.AddCommand(listNodesCmd)
|
nodeCmd.AddCommand(listNodesCmd)
|
||||||
nodeCmd.AddCommand(registerNodeCmd)
|
nodeCmd.AddCommand(registerNodeCmd)
|
||||||
nodeCmd.AddCommand(deleteNodeCmd)
|
nodeCmd.AddCommand(deleteNodeCmd)
|
||||||
|
nodeCmd.AddCommand(shareMachineCmd)
|
||||||
}
|
}
|
||||||
|
|
||||||
var nodeCmd = &cobra.Command{
|
var nodeCmd = &cobra.Command{
|
||||||
@ -79,9 +80,26 @@ var listNodesCmd = &cobra.Command{
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatalf("Error initializing: %s", err)
|
log.Fatalf("Error initializing: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
namespace, err := h.GetNamespace(n)
|
||||||
|
if err != nil {
|
||||||
|
log.Fatalf("Error fetching namespace: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
machines, err := h.ListMachinesInNamespace(n)
|
machines, err := h.ListMachinesInNamespace(n)
|
||||||
|
if err != nil {
|
||||||
|
log.Fatalf("Error fetching machines: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
sharedMachines, err := h.ListSharedMachinesInNamespace(n)
|
||||||
|
if err != nil {
|
||||||
|
log.Fatalf("Error fetching shared machines: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
allMachines := append(*machines, *sharedMachines...)
|
||||||
|
|
||||||
if strings.HasPrefix(o, "json") {
|
if strings.HasPrefix(o, "json") {
|
||||||
JsonOutput(machines, err, o)
|
JsonOutput(allMachines, err, o)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -89,7 +107,7 @@ var listNodesCmd = &cobra.Command{
|
|||||||
log.Fatalf("Error getting nodes: %s", err)
|
log.Fatalf("Error getting nodes: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
d, err := nodesToPtables(*machines)
|
d, err := nodesToPtables(*namespace, allMachines)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatalf("Error converting to table: %s", err)
|
log.Fatalf("Error converting to table: %s", err)
|
||||||
}
|
}
|
||||||
@ -145,31 +163,94 @@ var deleteNodeCmd = &cobra.Command{
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
func nodesToPtables(m []headscale.Machine) (pterm.TableData, error) {
|
var shareMachineCmd = &cobra.Command{
|
||||||
d := pterm.TableData{{"ID", "Name", "NodeKey", "IP address", "Ephemeral", "Last seen", "Online"}}
|
Use: "share ID namespace",
|
||||||
|
Short: "Shares a node from the current namespace to the specified one",
|
||||||
|
Args: func(cmd *cobra.Command, args []string) error {
|
||||||
|
if len(args) < 2 {
|
||||||
|
return fmt.Errorf("missing parameters")
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
Run: func(cmd *cobra.Command, args []string) {
|
||||||
|
namespace, err := cmd.Flags().GetString("namespace")
|
||||||
|
if err != nil {
|
||||||
|
log.Fatalf("Error getting namespace: %s", err)
|
||||||
|
}
|
||||||
|
output, _ := cmd.Flags().GetString("output")
|
||||||
|
|
||||||
for _, m := range m {
|
h, err := getHeadscaleApp()
|
||||||
|
if err != nil {
|
||||||
|
log.Fatalf("Error initializing: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = h.GetNamespace(namespace)
|
||||||
|
if err != nil {
|
||||||
|
log.Fatalf("Error fetching origin namespace: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
destinationNamespace, err := h.GetNamespace(args[1])
|
||||||
|
if err != nil {
|
||||||
|
log.Fatalf("Error fetching destination namespace: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
id, err := strconv.Atoi(args[0])
|
||||||
|
if err != nil {
|
||||||
|
log.Fatalf("Error converting ID to integer: %s", err)
|
||||||
|
}
|
||||||
|
machine, err := h.GetMachineByID(uint64(id))
|
||||||
|
if err != nil {
|
||||||
|
log.Fatalf("Error getting node: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
err = h.AddSharedMachineToNamespace(machine, destinationNamespace)
|
||||||
|
if strings.HasPrefix(output, "json") {
|
||||||
|
JsonOutput(map[string]string{"Result": "Node shared"}, err, output)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
fmt.Printf("Error sharing node: %s\n", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
fmt.Println("Node shared!")
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
func nodesToPtables(currentNamespace headscale.Namespace, machines []headscale.Machine) (pterm.TableData, error) {
|
||||||
|
d := pterm.TableData{{"ID", "Name", "NodeKey", "Namespace", "IP address", "Ephemeral", "Last seen", "Online"}}
|
||||||
|
|
||||||
|
for _, machine := range machines {
|
||||||
var ephemeral bool
|
var ephemeral bool
|
||||||
if m.AuthKey != nil && m.AuthKey.Ephemeral {
|
if machine.AuthKey != nil && machine.AuthKey.Ephemeral {
|
||||||
ephemeral = true
|
ephemeral = true
|
||||||
}
|
}
|
||||||
var lastSeen time.Time
|
var lastSeen time.Time
|
||||||
if m.LastSeen != nil {
|
var lastSeenTime string
|
||||||
lastSeen = *m.LastSeen
|
if machine.LastSeen != nil {
|
||||||
|
lastSeen = *machine.LastSeen
|
||||||
|
lastSeenTime = lastSeen.Format("2006-01-02 15:04:05")
|
||||||
}
|
}
|
||||||
nKey, err := wgkey.ParseHex(m.NodeKey)
|
nKey, err := wgkey.ParseHex(machine.NodeKey)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
nodeKey := tailcfg.NodeKey(nKey)
|
nodeKey := tailcfg.NodeKey(nKey)
|
||||||
|
|
||||||
var online string
|
var online string
|
||||||
if m.LastSeen.After(time.Now().Add(-5 * time.Minute)) { // TODO: Find a better way to reliably show if online
|
if lastSeen.After(time.Now().Add(-5 * time.Minute)) { // TODO: Find a better way to reliably show if online
|
||||||
online = pterm.LightGreen("true")
|
online = pterm.LightGreen("true")
|
||||||
} else {
|
} else {
|
||||||
online = pterm.LightRed("false")
|
online = pterm.LightRed("false")
|
||||||
}
|
}
|
||||||
d = append(d, []string{strconv.FormatUint(m.ID, 10), m.Name, nodeKey.ShortString(), m.IPAddress, strconv.FormatBool(ephemeral), lastSeen.Format("2006-01-02 15:04:05"), online})
|
|
||||||
|
var namespace string
|
||||||
|
if currentNamespace.ID == machine.NamespaceID {
|
||||||
|
namespace = pterm.LightMagenta(machine.Namespace.Name)
|
||||||
|
} else {
|
||||||
|
namespace = pterm.LightYellow(machine.Namespace.Name)
|
||||||
|
}
|
||||||
|
d = append(d, []string{strconv.FormatUint(machine.ID, 10), machine.Name, nodeKey.ShortString(), namespace, machine.IPAddress, strconv.FormatBool(ephemeral), lastSeenTime, online})
|
||||||
}
|
}
|
||||||
return d, nil
|
return d, nil
|
||||||
}
|
}
|
||||||
|
@ -5,6 +5,7 @@ import (
|
|||||||
"log"
|
"log"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
|
"github.com/pterm/pterm"
|
||||||
"github.com/spf13/cobra"
|
"github.com/spf13/cobra"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -15,6 +16,9 @@ func init() {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatalf(err.Error())
|
log.Fatalf(err.Error())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
enableRouteCmd.Flags().BoolP("all", "a", false, "Enable all routes advertised by the node")
|
||||||
|
|
||||||
routesCmd.AddCommand(listRoutesCmd)
|
routesCmd.AddCommand(listRoutesCmd)
|
||||||
routesCmd.AddCommand(enableRouteCmd)
|
routesCmd.AddCommand(enableRouteCmd)
|
||||||
}
|
}
|
||||||
@ -44,19 +48,25 @@ var listRoutesCmd = &cobra.Command{
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatalf("Error initializing: %s", err)
|
log.Fatalf("Error initializing: %s", err)
|
||||||
}
|
}
|
||||||
routes, err := h.GetNodeRoutes(n, args[0])
|
|
||||||
|
|
||||||
if strings.HasPrefix(o, "json") {
|
|
||||||
JsonOutput(routes, err, o)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
|
availableRoutes, err := h.GetAdvertisedNodeRoutes(n, args[0])
|
||||||
if err != nil {
|
if err != nil {
|
||||||
fmt.Println(err)
|
fmt.Println(err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
fmt.Println(routes)
|
if strings.HasPrefix(o, "json") {
|
||||||
|
// TODO: Add enable/disabled information to this interface
|
||||||
|
JsonOutput(availableRoutes, err, o)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
d := h.RoutesToPtables(n, args[0], *availableRoutes)
|
||||||
|
|
||||||
|
err = pterm.DefaultTable.WithHasHeader().WithData(d).Render()
|
||||||
|
if err != nil {
|
||||||
|
log.Fatal(err)
|
||||||
|
}
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -64,25 +74,66 @@ var enableRouteCmd = &cobra.Command{
|
|||||||
Use: "enable node-name route",
|
Use: "enable node-name route",
|
||||||
Short: "Allows exposing a route declared by this node to the rest of the nodes",
|
Short: "Allows exposing a route declared by this node to the rest of the nodes",
|
||||||
Args: func(cmd *cobra.Command, args []string) error {
|
Args: func(cmd *cobra.Command, args []string) error {
|
||||||
|
all, err := cmd.Flags().GetBool("all")
|
||||||
|
if err != nil {
|
||||||
|
log.Fatalf("Error getting namespace: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if all {
|
||||||
|
if len(args) < 1 {
|
||||||
|
return fmt.Errorf("Missing parameters")
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
} else {
|
||||||
if len(args) < 2 {
|
if len(args) < 2 {
|
||||||
return fmt.Errorf("Missing parameters")
|
return fmt.Errorf("Missing parameters")
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
|
}
|
||||||
},
|
},
|
||||||
Run: func(cmd *cobra.Command, args []string) {
|
Run: func(cmd *cobra.Command, args []string) {
|
||||||
n, err := cmd.Flags().GetString("namespace")
|
n, err := cmd.Flags().GetString("namespace")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatalf("Error getting namespace: %s", err)
|
log.Fatalf("Error getting namespace: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
o, _ := cmd.Flags().GetString("output")
|
o, _ := cmd.Flags().GetString("output")
|
||||||
|
|
||||||
|
all, err := cmd.Flags().GetBool("all")
|
||||||
|
if err != nil {
|
||||||
|
log.Fatalf("Error getting namespace: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
h, err := getHeadscaleApp()
|
h, err := getHeadscaleApp()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatalf("Error initializing: %s", err)
|
log.Fatalf("Error initializing: %s", err)
|
||||||
}
|
}
|
||||||
route, err := h.EnableNodeRoute(n, args[0], args[1])
|
|
||||||
|
if all {
|
||||||
|
availableRoutes, err := h.GetAdvertisedNodeRoutes(n, args[0])
|
||||||
|
if err != nil {
|
||||||
|
fmt.Println(err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, availableRoute := range *availableRoutes {
|
||||||
|
err = h.EnableNodeRoute(n, args[0], availableRoute.String())
|
||||||
|
if err != nil {
|
||||||
|
fmt.Println(err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
if strings.HasPrefix(o, "json") {
|
if strings.HasPrefix(o, "json") {
|
||||||
JsonOutput(route, err, o)
|
JsonOutput(availableRoute, err, o)
|
||||||
|
} else {
|
||||||
|
fmt.Printf("Enabled route %s\n", availableRoute)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
err = h.EnableNodeRoute(n, args[0], args[1])
|
||||||
|
|
||||||
|
if strings.HasPrefix(o, "json") {
|
||||||
|
JsonOutput(args[1], err, o)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -90,6 +141,7 @@ var enableRouteCmd = &cobra.Command{
|
|||||||
fmt.Println(err)
|
fmt.Println(err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
fmt.Printf("Enabled route %s\n", route)
|
fmt.Printf("Enabled route %s\n", args[1])
|
||||||
|
}
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
@ -39,7 +39,9 @@ func LoadConfig(path string) error {
|
|||||||
|
|
||||||
viper.SetDefault("ip_prefix", "100.64.0.0/10")
|
viper.SetDefault("ip_prefix", "100.64.0.0/10")
|
||||||
|
|
||||||
viper.SetDefault("log_level", "debug")
|
viper.SetDefault("log_level", "info")
|
||||||
|
|
||||||
|
viper.SetDefault("dns_config", nil)
|
||||||
|
|
||||||
err := viper.ReadInConfig()
|
err := viper.ReadInConfig()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -70,6 +72,45 @@ func LoadConfig(path string) error {
|
|||||||
} else {
|
} else {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
func GetDNSConfig() *tailcfg.DNSConfig {
|
||||||
|
if viper.IsSet("dns_config") {
|
||||||
|
dnsConfig := &tailcfg.DNSConfig{}
|
||||||
|
|
||||||
|
if viper.IsSet("dns_config.nameservers") {
|
||||||
|
nameserversStr := viper.GetStringSlice("dns_config.nameservers")
|
||||||
|
|
||||||
|
nameservers := make([]netaddr.IP, len(nameserversStr))
|
||||||
|
resolvers := make([]tailcfg.DNSResolver, len(nameserversStr))
|
||||||
|
|
||||||
|
for index, nameserverStr := range nameserversStr {
|
||||||
|
nameserver, err := netaddr.ParseIP(nameserverStr)
|
||||||
|
if err != nil {
|
||||||
|
log.Error().
|
||||||
|
Str("func", "getDNSConfig").
|
||||||
|
Err(err).
|
||||||
|
Msgf("Could not parse nameserver IP: %s", nameserverStr)
|
||||||
|
}
|
||||||
|
|
||||||
|
nameservers[index] = nameserver
|
||||||
|
resolvers[index] = tailcfg.DNSResolver{
|
||||||
|
Addr: nameserver.String(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
dnsConfig.Nameservers = nameservers
|
||||||
|
dnsConfig.Resolvers = resolvers
|
||||||
|
}
|
||||||
|
if viper.IsSet("dns_config.domains") {
|
||||||
|
dnsConfig.Domains = viper.GetStringSlice("dns_config.domains")
|
||||||
|
}
|
||||||
|
|
||||||
|
return dnsConfig
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func absPath(path string) string {
|
func absPath(path string) string {
|
||||||
@ -126,6 +167,8 @@ func getHeadscaleApp() (*headscale.Headscale, error) {
|
|||||||
|
|
||||||
TLSCertPath: absPath(viper.GetString("tls_cert_path")),
|
TLSCertPath: absPath(viper.GetString("tls_cert_path")),
|
||||||
TLSKeyPath: absPath(viper.GetString("tls_key_path")),
|
TLSKeyPath: absPath(viper.GetString("tls_key_path")),
|
||||||
|
|
||||||
|
DNSConfig: GetDNSConfig(),
|
||||||
}
|
}
|
||||||
|
|
||||||
h, err := headscale.NewHeadscale(cfg)
|
h, err := headscale.NewHeadscale(cfg)
|
||||||
|
@ -58,7 +58,7 @@ func (*Suite) TestPostgresConfigLoading(c *check.C) {
|
|||||||
c.Assert(viper.GetString("db_port"), check.Equals, "5432")
|
c.Assert(viper.GetString("db_port"), check.Equals, "5432")
|
||||||
c.Assert(viper.GetString("tls_letsencrypt_hostname"), check.Equals, "")
|
c.Assert(viper.GetString("tls_letsencrypt_hostname"), check.Equals, "")
|
||||||
c.Assert(viper.GetString("tls_letsencrypt_listen"), check.Equals, ":http")
|
c.Assert(viper.GetString("tls_letsencrypt_listen"), check.Equals, ":http")
|
||||||
c.Assert(viper.GetString("tls_letsencrypt_challenge_type"), check.Equals, "HTTP-01")
|
c.Assert(viper.GetStringSlice("dns_config.nameservers")[0], check.Equals, "1.1.1.1")
|
||||||
}
|
}
|
||||||
|
|
||||||
func (*Suite) TestSqliteConfigLoading(c *check.C) {
|
func (*Suite) TestSqliteConfigLoading(c *check.C) {
|
||||||
@ -92,6 +92,37 @@ func (*Suite) TestSqliteConfigLoading(c *check.C) {
|
|||||||
c.Assert(viper.GetString("tls_letsencrypt_hostname"), check.Equals, "")
|
c.Assert(viper.GetString("tls_letsencrypt_hostname"), check.Equals, "")
|
||||||
c.Assert(viper.GetString("tls_letsencrypt_listen"), check.Equals, ":http")
|
c.Assert(viper.GetString("tls_letsencrypt_listen"), check.Equals, ":http")
|
||||||
c.Assert(viper.GetString("tls_letsencrypt_challenge_type"), check.Equals, "HTTP-01")
|
c.Assert(viper.GetString("tls_letsencrypt_challenge_type"), check.Equals, "HTTP-01")
|
||||||
|
c.Assert(viper.GetStringSlice("dns_config.nameservers")[0], check.Equals, "1.1.1.1")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (*Suite) TestDNSConfigLoading(c *check.C) {
|
||||||
|
tmpDir, err := ioutil.TempDir("", "headscale")
|
||||||
|
if err != nil {
|
||||||
|
c.Fatal(err)
|
||||||
|
}
|
||||||
|
defer os.RemoveAll(tmpDir)
|
||||||
|
|
||||||
|
path, err := os.Getwd()
|
||||||
|
if err != nil {
|
||||||
|
c.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Symlink the example config file
|
||||||
|
err = os.Symlink(filepath.Clean(path+"/../../config.json.sqlite.example"), filepath.Join(tmpDir, "config.json"))
|
||||||
|
if err != nil {
|
||||||
|
c.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Load example config, it should load without validation errors
|
||||||
|
err = cli.LoadConfig(tmpDir)
|
||||||
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
|
dnsConfig := cli.GetDNSConfig()
|
||||||
|
fmt.Println(dnsConfig)
|
||||||
|
|
||||||
|
c.Assert(dnsConfig.Nameservers[0].String(), check.Equals, "1.1.1.1")
|
||||||
|
|
||||||
|
c.Assert(dnsConfig.Resolvers[0].Addr, check.Equals, "1.1.1.1")
|
||||||
}
|
}
|
||||||
|
|
||||||
func writeConfig(c *check.C, tmpDir string, configYaml []byte) {
|
func writeConfig(c *check.C, tmpDir string, configYaml []byte) {
|
||||||
|
@ -16,5 +16,10 @@
|
|||||||
"tls_letsencrypt_challenge_type": "HTTP-01",
|
"tls_letsencrypt_challenge_type": "HTTP-01",
|
||||||
"tls_cert_path": "",
|
"tls_cert_path": "",
|
||||||
"tls_key_path": "",
|
"tls_key_path": "",
|
||||||
"acl_policy_path": ""
|
"acl_policy_path": "",
|
||||||
|
"dns_config": {
|
||||||
|
"nameservers": [
|
||||||
|
"1.1.1.1"
|
||||||
|
]
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
@ -12,5 +12,10 @@
|
|||||||
"tls_letsencrypt_challenge_type": "HTTP-01",
|
"tls_letsencrypt_challenge_type": "HTTP-01",
|
||||||
"tls_cert_path": "",
|
"tls_cert_path": "",
|
||||||
"tls_key_path": "",
|
"tls_key_path": "",
|
||||||
"acl_policy_path": ""
|
"acl_policy_path": "",
|
||||||
|
"dns_config": {
|
||||||
|
"nameservers": [
|
||||||
|
"1.1.1.1"
|
||||||
|
]
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
5
db.go
5
db.go
@ -44,6 +44,11 @@ func (h *Headscale) initDB() error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
err = db.AutoMigrate(&SharedMachine{})
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
err = h.setValue("db_version", dbVersion)
|
err = h.setValue("db_version", dbVersion)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
# This file contains some of the official Tailscale DERP servers,
|
# This file contains some of the official Tailscale DERP servers,
|
||||||
# shamelessly taken from https://github.com/tailscale/tailscale/blob/main/net/dnsfallback/dns-fallback-servers.json
|
# shamelessly taken from https://github.com/tailscale/tailscale/blob/main/net/dnsfallback/dns-fallback-servers.json
|
||||||
#
|
#
|
||||||
# If you plan to somehow use headscale, please deploy your own DERP infra
|
# If you plan to somehow use headscale, please deploy your own DERP infra: https://tailscale.com/kb/1118/custom-derp-servers/
|
||||||
regions:
|
regions:
|
||||||
1:
|
1:
|
||||||
regionid: 1
|
regionid: 1
|
||||||
|
@ -4,10 +4,13 @@ package headscale
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"io/ioutil"
|
||||||
"log"
|
"log"
|
||||||
"net/http"
|
"net/http"
|
||||||
"os"
|
"os"
|
||||||
|
"path"
|
||||||
"strings"
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
@ -20,23 +23,48 @@ import (
|
|||||||
"inet.af/netaddr"
|
"inet.af/netaddr"
|
||||||
)
|
)
|
||||||
|
|
||||||
type IntegrationTestSuite struct {
|
|
||||||
suite.Suite
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestIntegrationTestSuite(t *testing.T) {
|
|
||||||
suite.Run(t, new(IntegrationTestSuite))
|
|
||||||
}
|
|
||||||
|
|
||||||
var integrationTmpDir string
|
var integrationTmpDir string
|
||||||
var ih Headscale
|
var ih Headscale
|
||||||
|
|
||||||
var pool dockertest.Pool
|
var pool dockertest.Pool
|
||||||
var network dockertest.Network
|
var network dockertest.Network
|
||||||
var headscale dockertest.Resource
|
var headscale dockertest.Resource
|
||||||
var tailscaleCount int = 5
|
var tailscaleCount int = 25
|
||||||
var tailscales map[string]dockertest.Resource
|
var tailscales map[string]dockertest.Resource
|
||||||
|
|
||||||
|
type IntegrationTestSuite struct {
|
||||||
|
suite.Suite
|
||||||
|
stats *suite.SuiteInformation
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestIntegrationTestSuite(t *testing.T) {
|
||||||
|
s := new(IntegrationTestSuite)
|
||||||
|
suite.Run(t, s)
|
||||||
|
|
||||||
|
// HandleStats, which allows us to check if we passed and save logs
|
||||||
|
// is called after TearDown, so we cannot tear down containers before
|
||||||
|
// we have potentially saved the logs.
|
||||||
|
for _, tailscale := range tailscales {
|
||||||
|
if err := pool.Purge(&tailscale); err != nil {
|
||||||
|
log.Printf("Could not purge resource: %s\n", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if !s.stats.Passed() {
|
||||||
|
err := saveLog(&headscale, "test_output")
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("Could not save log: %s\n", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if err := pool.Purge(&headscale); err != nil {
|
||||||
|
log.Printf("Could not purge resource: %s\n", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := network.Close(); err != nil {
|
||||||
|
log.Printf("Could not close network: %s\n", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func executeCommand(resource *dockertest.Resource, cmd []string) (string, error) {
|
func executeCommand(resource *dockertest.Resource, cmd []string) (string, error) {
|
||||||
var stdout bytes.Buffer
|
var stdout bytes.Buffer
|
||||||
var stderr bytes.Buffer
|
var stderr bytes.Buffer
|
||||||
@ -62,6 +90,48 @@ func executeCommand(resource *dockertest.Resource, cmd []string) (string, error)
|
|||||||
return stdout.String(), nil
|
return stdout.String(), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func saveLog(resource *dockertest.Resource, basePath string) error {
|
||||||
|
err := os.MkdirAll(basePath, os.ModePerm)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
var stdout bytes.Buffer
|
||||||
|
var stderr bytes.Buffer
|
||||||
|
|
||||||
|
err = pool.Client.Logs(
|
||||||
|
docker.LogsOptions{
|
||||||
|
Context: context.TODO(),
|
||||||
|
Container: resource.Container.ID,
|
||||||
|
OutputStream: &stdout,
|
||||||
|
ErrorStream: &stderr,
|
||||||
|
Tail: "all",
|
||||||
|
RawTerminal: false,
|
||||||
|
Stdout: true,
|
||||||
|
Stderr: true,
|
||||||
|
Follow: false,
|
||||||
|
Timestamps: false,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
fmt.Printf("Saving logs for %s to %s\n", resource.Container.Name, basePath)
|
||||||
|
|
||||||
|
err = ioutil.WriteFile(path.Join(basePath, resource.Container.Name+".stdout.log"), []byte(stdout.String()), 0644)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
err = ioutil.WriteFile(path.Join(basePath, resource.Container.Name+".stderr.log"), []byte(stdout.String()), 0644)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func dockerRestartPolicy(config *docker.HostConfig) {
|
func dockerRestartPolicy(config *docker.HostConfig) {
|
||||||
// set AutoRemove to true so that stopped container goes away by itself
|
// set AutoRemove to true so that stopped container goes away by itself
|
||||||
config.AutoRemove = true
|
config.AutoRemove = true
|
||||||
@ -115,7 +185,6 @@ func (s *IntegrationTestSuite) SetupSuite() {
|
|||||||
PortBindings: map[docker.Port][]docker.PortBinding{
|
PortBindings: map[docker.Port][]docker.PortBinding{
|
||||||
"8080/tcp": []docker.PortBinding{{HostPort: "8080"}},
|
"8080/tcp": []docker.PortBinding{{HostPort: "8080"}},
|
||||||
},
|
},
|
||||||
Env: []string{},
|
|
||||||
}
|
}
|
||||||
|
|
||||||
fmt.Println("Creating headscale container")
|
fmt.Println("Creating headscale container")
|
||||||
@ -134,7 +203,6 @@ func (s *IntegrationTestSuite) SetupSuite() {
|
|||||||
Name: hostname,
|
Name: hostname,
|
||||||
Networks: []*dockertest.Network{&network},
|
Networks: []*dockertest.Network{&network},
|
||||||
Cmd: []string{"tailscaled", "--tun=userspace-networking", "--socks5-server=localhost:1055"},
|
Cmd: []string{"tailscaled", "--tun=userspace-networking", "--socks5-server=localhost:1055"},
|
||||||
Env: []string{},
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if pts, err := pool.BuildAndRunWithBuildOptions(tailscaleBuildOptions, tailscaleOptions, dockerRestartPolicy); err == nil {
|
if pts, err := pool.BuildAndRunWithBuildOptions(tailscaleBuildOptions, tailscaleOptions, dockerRestartPolicy); err == nil {
|
||||||
@ -145,7 +213,6 @@ func (s *IntegrationTestSuite) SetupSuite() {
|
|||||||
fmt.Printf("Created %s container\n", hostname)
|
fmt.Printf("Created %s container\n", hostname)
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO: Replace this logic with something that can be detected on Github Actions
|
|
||||||
fmt.Println("Waiting for headscale to be ready")
|
fmt.Println("Waiting for headscale to be ready")
|
||||||
hostEndpoint := fmt.Sprintf("localhost:%s", headscale.GetPort("8080/tcp"))
|
hostEndpoint := fmt.Sprintf("localhost:%s", headscale.GetPort("8080/tcp"))
|
||||||
|
|
||||||
@ -197,23 +264,14 @@ func (s *IntegrationTestSuite) SetupSuite() {
|
|||||||
|
|
||||||
// The nodes need a bit of time to get their updated maps from headscale
|
// The nodes need a bit of time to get their updated maps from headscale
|
||||||
// TODO: See if we can have a more deterministic wait here.
|
// TODO: See if we can have a more deterministic wait here.
|
||||||
time.Sleep(20 * time.Second)
|
time.Sleep(60 * time.Second)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *IntegrationTestSuite) TearDownSuite() {
|
func (s *IntegrationTestSuite) TearDownSuite() {
|
||||||
if err := pool.Purge(&headscale); err != nil {
|
|
||||||
log.Printf("Could not purge resource: %s\n", err)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, tailscale := range tailscales {
|
func (s *IntegrationTestSuite) HandleStats(suiteName string, stats *suite.SuiteInformation) {
|
||||||
if err := pool.Purge(&tailscale); err != nil {
|
s.stats = stats
|
||||||
log.Printf("Could not purge resource: %s\n", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := network.Close(); err != nil {
|
|
||||||
log.Printf("Could not close network: %s\n", err)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *IntegrationTestSuite) TestListNodes() {
|
func (s *IntegrationTestSuite) TestListNodes() {
|
||||||
@ -295,7 +353,15 @@ func (s *IntegrationTestSuite) TestPingAllPeers() {
|
|||||||
s.T().Run(fmt.Sprintf("%s-%s", hostname, peername), func(t *testing.T) {
|
s.T().Run(fmt.Sprintf("%s-%s", hostname, peername), func(t *testing.T) {
|
||||||
// We currently cant ping ourselves, so skip that.
|
// We currently cant ping ourselves, so skip that.
|
||||||
if peername != hostname {
|
if peername != hostname {
|
||||||
command := []string{"tailscale", "ping", "--timeout=1s", "--c=1", ip.String()}
|
// We are only interested in "direct ping" which means what we
|
||||||
|
// might need a couple of more attempts before reaching the node.
|
||||||
|
command := []string{
|
||||||
|
"tailscale", "ping",
|
||||||
|
"--timeout=1s",
|
||||||
|
"--c=20",
|
||||||
|
"--until-direct=true",
|
||||||
|
ip.String(),
|
||||||
|
}
|
||||||
|
|
||||||
fmt.Printf("Pinging from %s (%s) to %s (%s)\n", hostname, ips[hostname], peername, ip)
|
fmt.Printf("Pinging from %s (%s) to %s (%s)\n", hostname, ips[hostname], peername, ip)
|
||||||
result, err := executeCommand(
|
result, err := executeCommand(
|
||||||
|
@ -7,5 +7,5 @@
|
|||||||
"db_type": "sqlite3",
|
"db_type": "sqlite3",
|
||||||
"db_path": "/tmp/integration_test_db.sqlite3",
|
"db_path": "/tmp/integration_test_db.sqlite3",
|
||||||
"acl_policy_path": "",
|
"acl_policy_path": "",
|
||||||
"log_level": "trace"
|
"log_level": "debug"
|
||||||
}
|
}
|
||||||
|
161
machine.go
161
machine.go
@ -2,6 +2,7 @@ package headscale
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"sort"
|
"sort"
|
||||||
"strconv"
|
"strconv"
|
||||||
@ -32,6 +33,7 @@ type Machine struct {
|
|||||||
AuthKey *PreAuthKey
|
AuthKey *PreAuthKey
|
||||||
|
|
||||||
LastSeen *time.Time
|
LastSeen *time.Time
|
||||||
|
LastSuccessfulUpdate *time.Time
|
||||||
Expiry *time.Time
|
Expiry *time.Time
|
||||||
|
|
||||||
HostInfo datatypes.JSON
|
HostInfo datatypes.JSON
|
||||||
@ -48,7 +50,9 @@ func (m Machine) isAlreadyRegistered() bool {
|
|||||||
return m.Registered
|
return m.Registered
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m Machine) toNode() (*tailcfg.Node, error) {
|
// toNode converts a Machine into a Tailscale Node. includeRoutes is false for shared nodes
|
||||||
|
// as per the expected behaviour in the official SaaS
|
||||||
|
func (m Machine) toNode(includeRoutes bool) (*tailcfg.Node, error) {
|
||||||
nKey, err := wgkey.ParseHex(m.NodeKey)
|
nKey, err := wgkey.ParseHex(m.NodeKey)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@ -83,6 +87,7 @@ func (m Machine) toNode() (*tailcfg.Node, error) {
|
|||||||
allowedIPs := []netaddr.IPPrefix{}
|
allowedIPs := []netaddr.IPPrefix{}
|
||||||
allowedIPs = append(allowedIPs, ip) // we append the node own IP, as it is required by the clients
|
allowedIPs = append(allowedIPs, ip) // we append the node own IP, as it is required by the clients
|
||||||
|
|
||||||
|
if includeRoutes {
|
||||||
routesStr := []string{}
|
routesStr := []string{}
|
||||||
if len(m.EnabledRoutes) != 0 {
|
if len(m.EnabledRoutes) != 0 {
|
||||||
allwIps, err := m.EnabledRoutes.MarshalJSON()
|
allwIps, err := m.EnabledRoutes.MarshalJSON()
|
||||||
@ -95,13 +100,14 @@ func (m Machine) toNode() (*tailcfg.Node, error) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, aip := range routesStr {
|
for _, routeStr := range routesStr {
|
||||||
ip, err := netaddr.ParseIPPrefix(aip)
|
ip, err := netaddr.ParseIPPrefix(routeStr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
allowedIPs = append(allowedIPs, ip)
|
allowedIPs = append(allowedIPs, ip)
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
endpoints := []string{}
|
endpoints := []string{}
|
||||||
if len(m.Endpoints) != 0 {
|
if len(m.Endpoints) != 0 {
|
||||||
@ -134,13 +140,20 @@ func (m Machine) toNode() (*tailcfg.Node, error) {
|
|||||||
derp = "127.3.3.40:0" // Zero means disconnected or unknown.
|
derp = "127.3.3.40:0" // Zero means disconnected or unknown.
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var keyExpiry time.Time
|
||||||
|
if m.Expiry != nil {
|
||||||
|
keyExpiry = *m.Expiry
|
||||||
|
} else {
|
||||||
|
keyExpiry = time.Time{}
|
||||||
|
}
|
||||||
|
|
||||||
n := tailcfg.Node{
|
n := tailcfg.Node{
|
||||||
ID: tailcfg.NodeID(m.ID), // this is the actual ID
|
ID: tailcfg.NodeID(m.ID), // this is the actual ID
|
||||||
StableID: tailcfg.StableNodeID(strconv.FormatUint(m.ID, 10)), // in headscale, unlike tailcontrol server, IDs are permanent
|
StableID: tailcfg.StableNodeID(strconv.FormatUint(m.ID, 10)), // in headscale, unlike tailcontrol server, IDs are permanent
|
||||||
Name: hostinfo.Hostname,
|
Name: hostinfo.Hostname,
|
||||||
User: tailcfg.UserID(m.NamespaceID),
|
User: tailcfg.UserID(m.NamespaceID),
|
||||||
Key: tailcfg.NodeKey(nKey),
|
Key: tailcfg.NodeKey(nKey),
|
||||||
KeyExpiry: *m.Expiry,
|
KeyExpiry: keyExpiry,
|
||||||
Machine: tailcfg.MachineKey(mKey),
|
Machine: tailcfg.MachineKey(mKey),
|
||||||
DiscoKey: discoKey,
|
DiscoKey: discoKey,
|
||||||
Addresses: addrs,
|
Addresses: addrs,
|
||||||
@ -163,6 +176,7 @@ func (h *Headscale) getPeers(m Machine) (*[]*tailcfg.Node, error) {
|
|||||||
Str("func", "getPeers").
|
Str("func", "getPeers").
|
||||||
Str("machine", m.Name).
|
Str("machine", m.Name).
|
||||||
Msg("Finding peers")
|
Msg("Finding peers")
|
||||||
|
|
||||||
machines := []Machine{}
|
machines := []Machine{}
|
||||||
if err := h.db.Where("namespace_id = ? AND machine_key <> ? AND registered",
|
if err := h.db.Where("namespace_id = ? AND machine_key <> ? AND registered",
|
||||||
m.NamespaceID, m.MachineKey).Find(&machines).Error; err != nil {
|
m.NamespaceID, m.MachineKey).Find(&machines).Error; err != nil {
|
||||||
@ -170,9 +184,23 @@ func (h *Headscale) getPeers(m Machine) (*[]*tailcfg.Node, error) {
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// We fetch here machines that are shared to the `Namespace` of the machine we are getting peers for
|
||||||
|
sharedMachines := []SharedMachine{}
|
||||||
|
if err := h.db.Preload("Namespace").Preload("Machine").Where("namespace_id = ?",
|
||||||
|
m.NamespaceID).Find(&sharedMachines).Error; err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
peers := []*tailcfg.Node{}
|
peers := []*tailcfg.Node{}
|
||||||
for _, mn := range machines {
|
for _, mn := range machines {
|
||||||
peer, err := mn.toNode()
|
peer, err := mn.toNode(true)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
peers = append(peers, peer)
|
||||||
|
}
|
||||||
|
for _, sharedMachine := range sharedMachines {
|
||||||
|
peer, err := sharedMachine.Machine.toNode(false) // shared nodes do not expose their routes
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@ -199,18 +227,27 @@ func (h *Headscale) GetMachine(namespace string, name string) (*Machine, error)
|
|||||||
return &m, nil
|
return &m, nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return nil, fmt.Errorf("not found")
|
return nil, fmt.Errorf("machine not found")
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetMachineByID finds a Machine by ID and returns the Machine struct
|
// GetMachineByID finds a Machine by ID and returns the Machine struct
|
||||||
func (h *Headscale) GetMachineByID(id uint64) (*Machine, error) {
|
func (h *Headscale) GetMachineByID(id uint64) (*Machine, error) {
|
||||||
m := Machine{}
|
m := Machine{}
|
||||||
if result := h.db.Find(&Machine{ID: id}).First(&m); result.Error != nil {
|
if result := h.db.Preload("Namespace").Find(&Machine{ID: id}).First(&m); result.Error != nil {
|
||||||
return nil, result.Error
|
return nil, result.Error
|
||||||
}
|
}
|
||||||
return &m, nil
|
return &m, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// UpdateMachine takes a Machine struct pointer (typically already loaded from database
|
||||||
|
// and updates it with the latest data from the database.
|
||||||
|
func (h *Headscale) UpdateMachine(m *Machine) error {
|
||||||
|
if result := h.db.Find(m).First(&m); result.Error != nil {
|
||||||
|
return result.Error
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// DeleteMachine softs deletes a Machine from the database
|
// DeleteMachine softs deletes a Machine from the database
|
||||||
func (h *Headscale) DeleteMachine(m *Machine) error {
|
func (h *Headscale) DeleteMachine(m *Machine) error {
|
||||||
m.Registered = false
|
m.Registered = false
|
||||||
@ -249,23 +286,119 @@ func (m *Machine) GetHostInfo() (*tailcfg.Hostinfo, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (h *Headscale) notifyChangesToPeers(m *Machine) {
|
func (h *Headscale) notifyChangesToPeers(m *Machine) {
|
||||||
peers, _ := h.getPeers(*m)
|
peers, err := h.getPeers(*m)
|
||||||
|
if err != nil {
|
||||||
|
log.Error().
|
||||||
|
Str("func", "notifyChangesToPeers").
|
||||||
|
Str("machine", m.Name).
|
||||||
|
Msgf("Error getting peers: %s", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
for _, p := range *peers {
|
for _, p := range *peers {
|
||||||
pUp, ok := h.clientsPolling.Load(uint64(p.ID))
|
|
||||||
if ok {
|
|
||||||
log.Info().
|
log.Info().
|
||||||
Str("func", "notifyChangesToPeers").
|
Str("func", "notifyChangesToPeers").
|
||||||
Str("machine", m.Name).
|
Str("machine", m.Name).
|
||||||
Str("peer", m.Name).
|
Str("peer", p.Name).
|
||||||
Str("address", p.Addresses[0].String()).
|
Str("address", p.Addresses[0].String()).
|
||||||
Msgf("Notifying peer %s (%s)", p.Name, p.Addresses[0])
|
Msgf("Notifying peer %s (%s)", p.Name, p.Addresses[0])
|
||||||
pUp.(chan []byte) <- []byte{}
|
err := h.sendRequestOnUpdateChannel(p)
|
||||||
} else {
|
if err != nil {
|
||||||
log.Info().
|
log.Info().
|
||||||
Str("func", "notifyChangesToPeers").
|
Str("func", "notifyChangesToPeers").
|
||||||
Str("machine", m.Name).
|
Str("machine", m.Name).
|
||||||
Str("peer", m.Name).
|
Str("peer", p.Name).
|
||||||
Msgf("Peer %s does not appear to be polling", p.Name)
|
Msgf("Peer %s does not appear to be polling", p.Name)
|
||||||
}
|
}
|
||||||
|
log.Trace().
|
||||||
|
Str("func", "notifyChangesToPeers").
|
||||||
|
Str("machine", m.Name).
|
||||||
|
Str("peer", p.Name).
|
||||||
|
Str("address", p.Addresses[0].String()).
|
||||||
|
Msgf("Notified peer %s (%s)", p.Name, p.Addresses[0])
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (h *Headscale) getOrOpenUpdateChannel(m *Machine) <-chan struct{} {
|
||||||
|
var updateChan chan struct{}
|
||||||
|
if storedChan, ok := h.clientsUpdateChannels.Load(m.ID); ok {
|
||||||
|
if unwrapped, ok := storedChan.(chan struct{}); ok {
|
||||||
|
updateChan = unwrapped
|
||||||
|
} else {
|
||||||
|
log.Error().
|
||||||
|
Str("handler", "openUpdateChannel").
|
||||||
|
Str("machine", m.Name).
|
||||||
|
Msg("Failed to convert update channel to struct{}")
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
log.Debug().
|
||||||
|
Str("handler", "openUpdateChannel").
|
||||||
|
Str("machine", m.Name).
|
||||||
|
Msg("Update channel not found, creating")
|
||||||
|
|
||||||
|
updateChan = make(chan struct{})
|
||||||
|
h.clientsUpdateChannels.Store(m.ID, updateChan)
|
||||||
|
}
|
||||||
|
return updateChan
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *Headscale) closeUpdateChannel(m *Machine) {
|
||||||
|
h.clientsUpdateChannelMutex.Lock()
|
||||||
|
defer h.clientsUpdateChannelMutex.Unlock()
|
||||||
|
|
||||||
|
if storedChan, ok := h.clientsUpdateChannels.Load(m.ID); ok {
|
||||||
|
if unwrapped, ok := storedChan.(chan struct{}); ok {
|
||||||
|
close(unwrapped)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
h.clientsUpdateChannels.Delete(m.ID)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *Headscale) sendRequestOnUpdateChannel(m *tailcfg.Node) error {
|
||||||
|
h.clientsUpdateChannelMutex.Lock()
|
||||||
|
defer h.clientsUpdateChannelMutex.Unlock()
|
||||||
|
|
||||||
|
pUp, ok := h.clientsUpdateChannels.Load(uint64(m.ID))
|
||||||
|
if ok {
|
||||||
|
log.Info().
|
||||||
|
Str("func", "requestUpdate").
|
||||||
|
Str("machine", m.Name).
|
||||||
|
Msgf("Notifying peer %s", m.Name)
|
||||||
|
|
||||||
|
if update, ok := pUp.(chan struct{}); ok {
|
||||||
|
log.Trace().
|
||||||
|
Str("func", "requestUpdate").
|
||||||
|
Str("machine", m.Name).
|
||||||
|
Msgf("Update channel is %#v", update)
|
||||||
|
|
||||||
|
update <- struct{}{}
|
||||||
|
|
||||||
|
log.Trace().
|
||||||
|
Str("func", "requestUpdate").
|
||||||
|
Str("machine", m.Name).
|
||||||
|
Msgf("Notified machine %s", m.Name)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
log.Info().
|
||||||
|
Str("func", "requestUpdate").
|
||||||
|
Str("machine", m.Name).
|
||||||
|
Msgf("Machine %s does not appear to be polling", m.Name)
|
||||||
|
return errors.New("machine does not seem to be polling")
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *Headscale) isOutdated(m *Machine) bool {
|
||||||
|
err := h.UpdateMachine(m)
|
||||||
|
if err != nil {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
lastChange := h.getLastStateChange(m.Namespace.Name)
|
||||||
|
log.Trace().
|
||||||
|
Str("func", "keepAlive").
|
||||||
|
Str("machine", m.Name).
|
||||||
|
Time("last_successful_update", *m.LastSuccessfulUpdate).
|
||||||
|
Time("last_state_change", lastChange).
|
||||||
|
Msgf("Checking if %s is missing updates", m.Name)
|
||||||
|
return m.LastSuccessfulUpdate.Before(lastChange)
|
||||||
|
}
|
||||||
|
@ -91,12 +91,34 @@ func (h *Headscale) ListMachinesInNamespace(name string) (*[]Machine, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
machines := []Machine{}
|
machines := []Machine{}
|
||||||
if err := h.db.Preload("AuthKey").Where(&Machine{NamespaceID: n.ID}).Find(&machines).Error; err != nil {
|
if err := h.db.Preload("AuthKey").Preload("Namespace").Where(&Machine{NamespaceID: n.ID}).Find(&machines).Error; err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
return &machines, nil
|
return &machines, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ListSharedMachinesInNamespace returns all the machines that are shared to the specified namespace
|
||||||
|
func (h *Headscale) ListSharedMachinesInNamespace(name string) (*[]Machine, error) {
|
||||||
|
namespace, err := h.GetNamespace(name)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
sharedMachines := []SharedMachine{}
|
||||||
|
if err := h.db.Preload("Namespace").Where(&SharedMachine{NamespaceID: namespace.ID}).Find(&sharedMachines).Error; err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
machines := []Machine{}
|
||||||
|
for _, sharedMachine := range sharedMachines {
|
||||||
|
machine, err := h.GetMachineByID(sharedMachine.MachineID) // otherwise not everything comes filled
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
machines = append(machines, *machine)
|
||||||
|
}
|
||||||
|
return &machines, nil
|
||||||
|
}
|
||||||
|
|
||||||
// SetMachineNamespace assigns a Machine to a namespace
|
// SetMachineNamespace assigns a Machine to a namespace
|
||||||
func (h *Headscale) SetMachineNamespace(m *Machine, namespaceName string) error {
|
func (h *Headscale) SetMachineNamespace(m *Machine, namespaceName string) error {
|
||||||
n, err := h.GetNamespace(namespaceName)
|
n, err := h.GetNamespace(namespaceName)
|
||||||
|
382
poll.go
382
poll.go
@ -1,38 +1,225 @@
|
|||||||
package headscale
|
package headscale
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
"io"
|
"io"
|
||||||
|
"net/http"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"github.com/rs/zerolog/log"
|
"github.com/rs/zerolog/log"
|
||||||
|
"gorm.io/datatypes"
|
||||||
|
"gorm.io/gorm"
|
||||||
"tailscale.com/tailcfg"
|
"tailscale.com/tailcfg"
|
||||||
"tailscale.com/types/wgkey"
|
"tailscale.com/types/wgkey"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// PollNetMapHandler takes care of /machine/:id/map
|
||||||
|
//
|
||||||
|
// This is the busiest endpoint, as it keeps the HTTP long poll that updates
|
||||||
|
// the clients when something in the network changes.
|
||||||
|
//
|
||||||
|
// The clients POST stuff like HostInfo and their Endpoints here, but
|
||||||
|
// only after their first request (marked with the ReadOnly field).
|
||||||
|
//
|
||||||
|
// At this moment the updates are sent in a quite horrendous way, but they kinda work.
|
||||||
|
func (h *Headscale) PollNetMapHandler(c *gin.Context) {
|
||||||
|
log.Trace().
|
||||||
|
Str("handler", "PollNetMap").
|
||||||
|
Str("id", c.Param("id")).
|
||||||
|
Msg("PollNetMapHandler called")
|
||||||
|
body, _ := io.ReadAll(c.Request.Body)
|
||||||
|
mKeyStr := c.Param("id")
|
||||||
|
mKey, err := wgkey.ParseHex(mKeyStr)
|
||||||
|
if err != nil {
|
||||||
|
log.Error().
|
||||||
|
Str("handler", "PollNetMap").
|
||||||
|
Err(err).
|
||||||
|
Msg("Cannot parse client key")
|
||||||
|
c.String(http.StatusBadRequest, "")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
req := tailcfg.MapRequest{}
|
||||||
|
err = decode(body, &req, &mKey, h.privateKey)
|
||||||
|
if err != nil {
|
||||||
|
log.Error().
|
||||||
|
Str("handler", "PollNetMap").
|
||||||
|
Err(err).
|
||||||
|
Msg("Cannot decode message")
|
||||||
|
c.String(http.StatusBadRequest, "")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var m Machine
|
||||||
|
if result := h.db.Preload("Namespace").First(&m, "machine_key = ?", mKey.HexString()); errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||||
|
log.Warn().
|
||||||
|
Str("handler", "PollNetMap").
|
||||||
|
Msgf("Ignoring request, cannot find machine with key %s", mKey.HexString())
|
||||||
|
c.String(http.StatusUnauthorized, "")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
log.Trace().
|
||||||
|
Str("handler", "PollNetMap").
|
||||||
|
Str("id", c.Param("id")).
|
||||||
|
Str("machine", m.Name).
|
||||||
|
Msg("Found machine in database")
|
||||||
|
|
||||||
|
hostinfo, _ := json.Marshal(req.Hostinfo)
|
||||||
|
m.Name = req.Hostinfo.Hostname
|
||||||
|
m.HostInfo = datatypes.JSON(hostinfo)
|
||||||
|
m.DiscoKey = wgkey.Key(req.DiscoKey).HexString()
|
||||||
|
now := time.Now().UTC()
|
||||||
|
|
||||||
|
// From Tailscale client:
|
||||||
|
//
|
||||||
|
// ReadOnly is whether the client just wants to fetch the MapResponse,
|
||||||
|
// without updating their Endpoints. The Endpoints field will be ignored and
|
||||||
|
// LastSeen will not be updated and peers will not be notified of changes.
|
||||||
|
//
|
||||||
|
// The intended use is for clients to discover the DERP map at start-up
|
||||||
|
// before their first real endpoint update.
|
||||||
|
if !req.ReadOnly {
|
||||||
|
endpoints, _ := json.Marshal(req.Endpoints)
|
||||||
|
m.Endpoints = datatypes.JSON(endpoints)
|
||||||
|
m.LastSeen = &now
|
||||||
|
}
|
||||||
|
h.db.Save(&m)
|
||||||
|
|
||||||
|
data, err := h.getMapResponse(mKey, req, m)
|
||||||
|
if err != nil {
|
||||||
|
log.Error().
|
||||||
|
Str("handler", "PollNetMap").
|
||||||
|
Str("id", c.Param("id")).
|
||||||
|
Str("machine", m.Name).
|
||||||
|
Err(err).
|
||||||
|
Msg("Failed to get Map response")
|
||||||
|
c.String(http.StatusInternalServerError, ":(")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// We update our peers if the client is not sending ReadOnly in the MapRequest
|
||||||
|
// so we don't distribute its initial request (it comes with
|
||||||
|
// empty endpoints to peers)
|
||||||
|
|
||||||
|
// Details on the protocol can be found in https://github.com/tailscale/tailscale/blob/main/tailcfg/tailcfg.go#L696
|
||||||
|
log.Debug().
|
||||||
|
Str("handler", "PollNetMap").
|
||||||
|
Str("id", c.Param("id")).
|
||||||
|
Str("machine", m.Name).
|
||||||
|
Bool("readOnly", req.ReadOnly).
|
||||||
|
Bool("omitPeers", req.OmitPeers).
|
||||||
|
Bool("stream", req.Stream).
|
||||||
|
Msg("Client map request processed")
|
||||||
|
|
||||||
|
if req.ReadOnly {
|
||||||
|
log.Info().
|
||||||
|
Str("handler", "PollNetMap").
|
||||||
|
Str("machine", m.Name).
|
||||||
|
Msg("Client is starting up. Probably interested in a DERP map")
|
||||||
|
c.Data(200, "application/json; charset=utf-8", *data)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// There has been an update to _any_ of the nodes that the other nodes would
|
||||||
|
// need to know about
|
||||||
|
h.setLastStateChangeToNow(m.Namespace.Name)
|
||||||
|
|
||||||
|
// The request is not ReadOnly, so we need to set up channels for updating
|
||||||
|
// peers via longpoll
|
||||||
|
|
||||||
|
// Only create update channel if it has not been created
|
||||||
|
log.Trace().
|
||||||
|
Str("handler", "PollNetMap").
|
||||||
|
Str("id", c.Param("id")).
|
||||||
|
Str("machine", m.Name).
|
||||||
|
Msg("Loading or creating update channel")
|
||||||
|
updateChan := h.getOrOpenUpdateChannel(&m)
|
||||||
|
|
||||||
|
pollDataChan := make(chan []byte)
|
||||||
|
// defer close(pollData)
|
||||||
|
|
||||||
|
keepAliveChan := make(chan []byte)
|
||||||
|
|
||||||
|
cancelKeepAlive := make(chan struct{})
|
||||||
|
defer close(cancelKeepAlive)
|
||||||
|
|
||||||
|
if req.OmitPeers && !req.Stream {
|
||||||
|
log.Info().
|
||||||
|
Str("handler", "PollNetMap").
|
||||||
|
Str("machine", m.Name).
|
||||||
|
Msg("Client sent endpoint update and is ok with a response without peer list")
|
||||||
|
c.Data(200, "application/json; charset=utf-8", *data)
|
||||||
|
|
||||||
|
// It sounds like we should update the nodes when we have received a endpoint update
|
||||||
|
// even tho the comments in the tailscale code dont explicitly say so.
|
||||||
|
go h.notifyChangesToPeers(&m)
|
||||||
|
return
|
||||||
|
} else if req.OmitPeers && req.Stream {
|
||||||
|
log.Warn().
|
||||||
|
Str("handler", "PollNetMap").
|
||||||
|
Str("machine", m.Name).
|
||||||
|
Msg("Ignoring request, don't know how to handle it")
|
||||||
|
c.String(http.StatusBadRequest, "")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Info().
|
||||||
|
Str("handler", "PollNetMap").
|
||||||
|
Str("machine", m.Name).
|
||||||
|
Msg("Client is ready to access the tailnet")
|
||||||
|
log.Info().
|
||||||
|
Str("handler", "PollNetMap").
|
||||||
|
Str("machine", m.Name).
|
||||||
|
Msg("Sending initial map")
|
||||||
|
go func() { pollDataChan <- *data }()
|
||||||
|
|
||||||
|
log.Info().
|
||||||
|
Str("handler", "PollNetMap").
|
||||||
|
Str("machine", m.Name).
|
||||||
|
Msg("Notifying peers")
|
||||||
|
go h.notifyChangesToPeers(&m)
|
||||||
|
|
||||||
|
h.PollNetMapStream(c, m, req, mKey, pollDataChan, keepAliveChan, updateChan, cancelKeepAlive)
|
||||||
|
log.Trace().
|
||||||
|
Str("handler", "PollNetMap").
|
||||||
|
Str("id", c.Param("id")).
|
||||||
|
Str("machine", m.Name).
|
||||||
|
Msg("Finished stream, closing PollNetMap session")
|
||||||
|
}
|
||||||
|
|
||||||
|
// PollNetMapStream takes care of /machine/:id/map
|
||||||
|
// stream logic, ensuring we communicate updates and data
|
||||||
|
// to the connected clients.
|
||||||
func (h *Headscale) PollNetMapStream(
|
func (h *Headscale) PollNetMapStream(
|
||||||
c *gin.Context,
|
c *gin.Context,
|
||||||
m Machine,
|
m Machine,
|
||||||
req tailcfg.MapRequest,
|
req tailcfg.MapRequest,
|
||||||
mKey wgkey.Key,
|
mKey wgkey.Key,
|
||||||
pollData chan []byte,
|
pollDataChan chan []byte,
|
||||||
update chan []byte,
|
keepAliveChan chan []byte,
|
||||||
cancelKeepAlive chan []byte,
|
updateChan <-chan struct{},
|
||||||
|
cancelKeepAlive chan struct{},
|
||||||
) {
|
) {
|
||||||
|
go h.scheduledPollWorker(cancelKeepAlive, keepAliveChan, mKey, req, m)
|
||||||
go h.keepAlive(cancelKeepAlive, pollData, mKey, req, m)
|
|
||||||
|
|
||||||
c.Stream(func(w io.Writer) bool {
|
c.Stream(func(w io.Writer) bool {
|
||||||
log.Trace().
|
log.Trace().
|
||||||
Str("handler", "PollNetMapStream").
|
Str("handler", "PollNetMapStream").
|
||||||
Str("machine", m.Name).
|
Str("machine", m.Name).
|
||||||
Msg("Waiting for data to stream...")
|
Msg("Waiting for data to stream...")
|
||||||
select {
|
|
||||||
|
|
||||||
case data := <-pollData:
|
|
||||||
log.Trace().
|
log.Trace().
|
||||||
Str("handler", "PollNetMapStream").
|
Str("handler", "PollNetMapStream").
|
||||||
Str("machine", m.Name).
|
Str("machine", m.Name).
|
||||||
|
Msgf("pollData is %#v, keepAliveChan is %#v, updateChan is %#v", pollDataChan, keepAliveChan, updateChan)
|
||||||
|
|
||||||
|
select {
|
||||||
|
case data := <-pollDataChan:
|
||||||
|
log.Trace().
|
||||||
|
Str("handler", "PollNetMapStream").
|
||||||
|
Str("machine", m.Name).
|
||||||
|
Str("channel", "pollData").
|
||||||
Int("bytes", len(data)).
|
Int("bytes", len(data)).
|
||||||
Msg("Sending data received via pollData channel")
|
Msg("Sending data received via pollData channel")
|
||||||
_, err := w.Write(data)
|
_, err := w.Write(data)
|
||||||
@ -40,34 +227,104 @@ func (h *Headscale) PollNetMapStream(
|
|||||||
log.Error().
|
log.Error().
|
||||||
Str("handler", "PollNetMapStream").
|
Str("handler", "PollNetMapStream").
|
||||||
Str("machine", m.Name).
|
Str("machine", m.Name).
|
||||||
|
Str("channel", "pollData").
|
||||||
Err(err).
|
Err(err).
|
||||||
Msg("Cannot write data")
|
Msg("Cannot write data")
|
||||||
}
|
}
|
||||||
log.Trace().
|
log.Trace().
|
||||||
Str("handler", "PollNetMapStream").
|
Str("handler", "PollNetMapStream").
|
||||||
Str("machine", m.Name).
|
Str("machine", m.Name).
|
||||||
|
Str("channel", "pollData").
|
||||||
Int("bytes", len(data)).
|
Int("bytes", len(data)).
|
||||||
Msg("Data from pollData channel written successfully")
|
Msg("Data from pollData channel written successfully")
|
||||||
|
// TODO: Abstract away all the database calls, this can cause race conditions
|
||||||
|
// when an outdated machine object is kept alive, e.g. db is update from
|
||||||
|
// command line, but then overwritten.
|
||||||
|
err = h.UpdateMachine(&m)
|
||||||
|
if err != nil {
|
||||||
|
log.Error().
|
||||||
|
Str("handler", "PollNetMapStream").
|
||||||
|
Str("machine", m.Name).
|
||||||
|
Str("channel", "pollData").
|
||||||
|
Err(err).
|
||||||
|
Msg("Cannot update machine from database")
|
||||||
|
}
|
||||||
|
now := time.Now().UTC()
|
||||||
|
m.LastSeen = &now
|
||||||
|
m.LastSuccessfulUpdate = &now
|
||||||
|
h.db.Save(&m)
|
||||||
|
log.Trace().
|
||||||
|
Str("handler", "PollNetMapStream").
|
||||||
|
Str("machine", m.Name).
|
||||||
|
Str("channel", "pollData").
|
||||||
|
Int("bytes", len(data)).
|
||||||
|
Msg("Machine updated successfully after sending pollData")
|
||||||
|
return true
|
||||||
|
|
||||||
|
case data := <-keepAliveChan:
|
||||||
|
log.Trace().
|
||||||
|
Str("handler", "PollNetMapStream").
|
||||||
|
Str("machine", m.Name).
|
||||||
|
Str("channel", "keepAlive").
|
||||||
|
Int("bytes", len(data)).
|
||||||
|
Msg("Sending keep alive message")
|
||||||
|
_, err := w.Write(data)
|
||||||
|
if err != nil {
|
||||||
|
log.Error().
|
||||||
|
Str("handler", "PollNetMapStream").
|
||||||
|
Str("machine", m.Name).
|
||||||
|
Str("channel", "keepAlive").
|
||||||
|
Err(err).
|
||||||
|
Msg("Cannot write keep alive message")
|
||||||
|
}
|
||||||
|
log.Trace().
|
||||||
|
Str("handler", "PollNetMapStream").
|
||||||
|
Str("machine", m.Name).
|
||||||
|
Str("channel", "keepAlive").
|
||||||
|
Int("bytes", len(data)).
|
||||||
|
Msg("Keep alive sent successfully")
|
||||||
|
// TODO: Abstract away all the database calls, this can cause race conditions
|
||||||
|
// when an outdated machine object is kept alive, e.g. db is update from
|
||||||
|
// command line, but then overwritten.
|
||||||
|
err = h.UpdateMachine(&m)
|
||||||
|
if err != nil {
|
||||||
|
log.Error().
|
||||||
|
Str("handler", "PollNetMapStream").
|
||||||
|
Str("machine", m.Name).
|
||||||
|
Str("channel", "keepAlive").
|
||||||
|
Err(err).
|
||||||
|
Msg("Cannot update machine from database")
|
||||||
|
}
|
||||||
now := time.Now().UTC()
|
now := time.Now().UTC()
|
||||||
m.LastSeen = &now
|
m.LastSeen = &now
|
||||||
h.db.Save(&m)
|
h.db.Save(&m)
|
||||||
log.Trace().
|
log.Trace().
|
||||||
Str("handler", "PollNetMapStream").
|
Str("handler", "PollNetMapStream").
|
||||||
Str("machine", m.Name).
|
Str("machine", m.Name).
|
||||||
|
Str("channel", "keepAlive").
|
||||||
Int("bytes", len(data)).
|
Int("bytes", len(data)).
|
||||||
Msg("Machine updated successfully after sending pollData")
|
Msg("Machine updated successfully after sending keep alive")
|
||||||
return true
|
return true
|
||||||
|
|
||||||
case <-update:
|
case <-updateChan:
|
||||||
|
log.Trace().
|
||||||
|
Str("handler", "PollNetMapStream").
|
||||||
|
Str("machine", m.Name).
|
||||||
|
Str("channel", "update").
|
||||||
|
Msg("Received a request for update")
|
||||||
|
if h.isOutdated(&m) {
|
||||||
log.Debug().
|
log.Debug().
|
||||||
Str("handler", "PollNetMapStream").
|
Str("handler", "PollNetMapStream").
|
||||||
Str("machine", m.Name).
|
Str("machine", m.Name).
|
||||||
Msg("Received a request for update")
|
Time("last_successful_update", *m.LastSuccessfulUpdate).
|
||||||
|
Time("last_state_change", h.getLastStateChange(m.Namespace.Name)).
|
||||||
|
Msgf("There has been updates since the last successful update to %s", m.Name)
|
||||||
data, err := h.getMapResponse(mKey, req, m)
|
data, err := h.getMapResponse(mKey, req, m)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error().
|
log.Error().
|
||||||
Str("handler", "PollNetMapStream").
|
Str("handler", "PollNetMapStream").
|
||||||
Str("machine", m.Name).
|
Str("machine", m.Name).
|
||||||
|
Str("channel", "update").
|
||||||
Err(err).
|
Err(err).
|
||||||
Msg("Could not get the map update")
|
Msg("Could not get the map update")
|
||||||
}
|
}
|
||||||
@ -76,9 +333,43 @@ func (h *Headscale) PollNetMapStream(
|
|||||||
log.Error().
|
log.Error().
|
||||||
Str("handler", "PollNetMapStream").
|
Str("handler", "PollNetMapStream").
|
||||||
Str("machine", m.Name).
|
Str("machine", m.Name).
|
||||||
|
Str("channel", "update").
|
||||||
Err(err).
|
Err(err).
|
||||||
Msg("Could not write the map response")
|
Msg("Could not write the map response")
|
||||||
}
|
}
|
||||||
|
log.Trace().
|
||||||
|
Str("handler", "PollNetMapStream").
|
||||||
|
Str("machine", m.Name).
|
||||||
|
Str("channel", "update").
|
||||||
|
Msg("Updated Map has been sent")
|
||||||
|
|
||||||
|
// Keep track of the last successful update,
|
||||||
|
// we sometimes end in a state were the update
|
||||||
|
// is not picked up by a client and we use this
|
||||||
|
// to determine if we should "force" an update.
|
||||||
|
// TODO: Abstract away all the database calls, this can cause race conditions
|
||||||
|
// when an outdated machine object is kept alive, e.g. db is update from
|
||||||
|
// command line, but then overwritten.
|
||||||
|
err = h.UpdateMachine(&m)
|
||||||
|
if err != nil {
|
||||||
|
log.Error().
|
||||||
|
Str("handler", "PollNetMapStream").
|
||||||
|
Str("machine", m.Name).
|
||||||
|
Str("channel", "update").
|
||||||
|
Err(err).
|
||||||
|
Msg("Cannot update machine from database")
|
||||||
|
}
|
||||||
|
now := time.Now().UTC()
|
||||||
|
m.LastSuccessfulUpdate = &now
|
||||||
|
h.db.Save(&m)
|
||||||
|
} else {
|
||||||
|
log.Trace().
|
||||||
|
Str("handler", "PollNetMapStream").
|
||||||
|
Str("machine", m.Name).
|
||||||
|
Time("last_successful_update", *m.LastSuccessfulUpdate).
|
||||||
|
Time("last_state_change", h.getLastStateChange(m.Namespace.Name)).
|
||||||
|
Msgf("%s is up to date", m.Name)
|
||||||
|
}
|
||||||
return true
|
return true
|
||||||
|
|
||||||
case <-c.Request.Context().Done():
|
case <-c.Request.Context().Done():
|
||||||
@ -86,13 +377,78 @@ func (h *Headscale) PollNetMapStream(
|
|||||||
Str("handler", "PollNetMapStream").
|
Str("handler", "PollNetMapStream").
|
||||||
Str("machine", m.Name).
|
Str("machine", m.Name).
|
||||||
Msg("The client has closed the connection")
|
Msg("The client has closed the connection")
|
||||||
|
// TODO: Abstract away all the database calls, this can cause race conditions
|
||||||
|
// when an outdated machine object is kept alive, e.g. db is update from
|
||||||
|
// command line, but then overwritten.
|
||||||
|
err := h.UpdateMachine(&m)
|
||||||
|
if err != nil {
|
||||||
|
log.Error().
|
||||||
|
Str("handler", "PollNetMapStream").
|
||||||
|
Str("machine", m.Name).
|
||||||
|
Str("channel", "Done").
|
||||||
|
Err(err).
|
||||||
|
Msg("Cannot update machine from database")
|
||||||
|
}
|
||||||
now := time.Now().UTC()
|
now := time.Now().UTC()
|
||||||
m.LastSeen = &now
|
m.LastSeen = &now
|
||||||
h.db.Save(&m)
|
h.db.Save(&m)
|
||||||
cancelKeepAlive <- []byte{}
|
|
||||||
h.clientsPolling.Delete(m.ID)
|
cancelKeepAlive <- struct{}{}
|
||||||
close(update)
|
|
||||||
|
h.closeUpdateChannel(&m)
|
||||||
|
|
||||||
|
close(pollDataChan)
|
||||||
|
|
||||||
|
close(keepAliveChan)
|
||||||
|
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (h *Headscale) scheduledPollWorker(
|
||||||
|
cancelChan <-chan struct{},
|
||||||
|
keepAliveChan chan<- []byte,
|
||||||
|
mKey wgkey.Key,
|
||||||
|
req tailcfg.MapRequest,
|
||||||
|
m Machine,
|
||||||
|
) {
|
||||||
|
keepAliveTicker := time.NewTicker(60 * time.Second)
|
||||||
|
updateCheckerTicker := time.NewTicker(30 * time.Second)
|
||||||
|
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-cancelChan:
|
||||||
|
return
|
||||||
|
|
||||||
|
case <-keepAliveTicker.C:
|
||||||
|
data, err := h.getMapKeepAliveResponse(mKey, req, m)
|
||||||
|
if err != nil {
|
||||||
|
log.Error().
|
||||||
|
Str("func", "keepAlive").
|
||||||
|
Err(err).
|
||||||
|
Msg("Error generating the keep alive msg")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Debug().
|
||||||
|
Str("func", "keepAlive").
|
||||||
|
Str("machine", m.Name).
|
||||||
|
Msg("Sending keepalive")
|
||||||
|
keepAliveChan <- *data
|
||||||
|
|
||||||
|
case <-updateCheckerTicker.C:
|
||||||
|
// Send an update request regardless of outdated or not, if data is sent
|
||||||
|
// to the node is determined in the updateChan consumer block
|
||||||
|
n, _ := m.toNode(true)
|
||||||
|
err := h.sendRequestOnUpdateChannel(n)
|
||||||
|
if err != nil {
|
||||||
|
log.Error().
|
||||||
|
Str("func", "keepAlive").
|
||||||
|
Str("machine", m.Name).
|
||||||
|
Err(err).
|
||||||
|
Msgf("Failed to send update request to %s", m.Name)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
115
routes.go
115
routes.go
@ -2,55 +2,142 @@ package headscale
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
"fmt"
|
||||||
|
"strconv"
|
||||||
|
|
||||||
|
"github.com/pterm/pterm"
|
||||||
"gorm.io/datatypes"
|
"gorm.io/datatypes"
|
||||||
"inet.af/netaddr"
|
"inet.af/netaddr"
|
||||||
)
|
)
|
||||||
|
|
||||||
// GetNodeRoutes returns the subnet routes advertised by a node (identified by
|
// GetAdvertisedNodeRoutes returns the subnet routes advertised by a node (identified by
|
||||||
// namespace and node name)
|
// namespace and node name)
|
||||||
func (h *Headscale) GetNodeRoutes(namespace string, nodeName string) (*[]netaddr.IPPrefix, error) {
|
func (h *Headscale) GetAdvertisedNodeRoutes(namespace string, nodeName string) (*[]netaddr.IPPrefix, error) {
|
||||||
m, err := h.GetMachine(namespace, nodeName)
|
m, err := h.GetMachine(namespace, nodeName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
hi, err := m.GetHostInfo()
|
hostInfo, err := m.GetHostInfo()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
return &hi.RoutableIPs, nil
|
return &hostInfo.RoutableIPs, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// EnableNodeRoute enables a subnet route advertised by a node (identified by
|
// GetEnabledNodeRoutes returns the subnet routes enabled by a node (identified by
|
||||||
// namespace and node name)
|
// namespace and node name)
|
||||||
func (h *Headscale) EnableNodeRoute(namespace string, nodeName string, routeStr string) (*netaddr.IPPrefix, error) {
|
func (h *Headscale) GetEnabledNodeRoutes(namespace string, nodeName string) ([]netaddr.IPPrefix, error) {
|
||||||
m, err := h.GetMachine(namespace, nodeName)
|
m, err := h.GetMachine(namespace, nodeName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
hi, err := m.GetHostInfo()
|
|
||||||
|
data, err := m.EnabledRoutes.MarshalJSON()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
routesStr := []string{}
|
||||||
|
err = json.Unmarshal(data, &routesStr)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
routes := make([]netaddr.IPPrefix, len(routesStr))
|
||||||
|
for index, routeStr := range routesStr {
|
||||||
route, err := netaddr.ParseIPPrefix(routeStr)
|
route, err := netaddr.ParseIPPrefix(routeStr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
routes[index] = route
|
||||||
|
}
|
||||||
|
|
||||||
|
return routes, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsNodeRouteEnabled checks if a certain route has been enabled
|
||||||
|
func (h *Headscale) IsNodeRouteEnabled(namespace string, nodeName string, routeStr string) bool {
|
||||||
|
route, err := netaddr.ParseIPPrefix(routeStr)
|
||||||
|
if err != nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
enabledRoutes, err := h.GetEnabledNodeRoutes(namespace, nodeName)
|
||||||
|
if err != nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, enabledRoute := range enabledRoutes {
|
||||||
|
if route == enabledRoute {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// EnableNodeRoute enables a subnet route advertised by a node (identified by
|
||||||
|
// namespace and node name)
|
||||||
|
func (h *Headscale) EnableNodeRoute(namespace string, nodeName string, routeStr string) error {
|
||||||
|
m, err := h.GetMachine(namespace, nodeName)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
route, err := netaddr.ParseIPPrefix(routeStr)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
availableRoutes, err := h.GetAdvertisedNodeRoutes(namespace, nodeName)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
enabledRoutes, err := h.GetEnabledNodeRoutes(namespace, nodeName)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
available := false
|
||||||
|
for _, availableRoute := range *availableRoutes {
|
||||||
|
// If the route is available, and not yet enabled, add it to the new routing table
|
||||||
|
if route == availableRoute {
|
||||||
|
available = true
|
||||||
|
if !h.IsNodeRouteEnabled(namespace, nodeName, routeStr) {
|
||||||
|
enabledRoutes = append(enabledRoutes, route)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if !available {
|
||||||
|
return fmt.Errorf("route (%s) is not available on node %s", nodeName, routeStr)
|
||||||
|
}
|
||||||
|
|
||||||
|
routes, err := json.Marshal(enabledRoutes)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
for _, rIP := range hi.RoutableIPs {
|
|
||||||
if rIP == route {
|
|
||||||
routes, _ := json.Marshal([]string{routeStr}) // TODO: only one for the time being, so overwriting the rest
|
|
||||||
m.EnabledRoutes = datatypes.JSON(routes)
|
m.EnabledRoutes = datatypes.JSON(routes)
|
||||||
h.db.Save(&m)
|
h.db.Save(&m)
|
||||||
|
|
||||||
err = h.RequestMapUpdates(m.NamespaceID)
|
err = h.RequestMapUpdates(m.NamespaceID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return err
|
||||||
}
|
}
|
||||||
return &rIP, nil
|
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// RoutesToPtables converts the list of routes to a nice table
|
||||||
|
func (h *Headscale) RoutesToPtables(namespace string, nodeName string, availableRoutes []netaddr.IPPrefix) pterm.TableData {
|
||||||
|
d := pterm.TableData{{"Route", "Enabled"}}
|
||||||
|
|
||||||
|
for _, route := range availableRoutes {
|
||||||
|
enabled := h.IsNodeRouteEnabled(namespace, nodeName, route.String())
|
||||||
|
|
||||||
|
d = append(d, []string{route.String(), strconv.FormatBool(enabled)})
|
||||||
}
|
}
|
||||||
return nil, errors.New("could not find routable range")
|
return d
|
||||||
}
|
}
|
||||||
|
@ -16,7 +16,7 @@ func (s *Suite) TestGetRoutes(c *check.C) {
|
|||||||
pak, err := h.CreatePreAuthKey(n.Name, false, false, nil)
|
pak, err := h.CreatePreAuthKey(n.Name, false, false, nil)
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
_, err = h.GetMachine("test", "testmachine")
|
_, err = h.GetMachine("test", "test_get_route_machine")
|
||||||
c.Assert(err, check.NotNil)
|
c.Assert(err, check.NotNil)
|
||||||
|
|
||||||
route, err := netaddr.ParseIPPrefix("10.0.0.0/24")
|
route, err := netaddr.ParseIPPrefix("10.0.0.0/24")
|
||||||
@ -33,7 +33,7 @@ func (s *Suite) TestGetRoutes(c *check.C) {
|
|||||||
MachineKey: "foo",
|
MachineKey: "foo",
|
||||||
NodeKey: "bar",
|
NodeKey: "bar",
|
||||||
DiscoKey: "faa",
|
DiscoKey: "faa",
|
||||||
Name: "testmachine",
|
Name: "test_get_route_machine",
|
||||||
NamespaceID: n.ID,
|
NamespaceID: n.ID,
|
||||||
Registered: true,
|
Registered: true,
|
||||||
RegisterMethod: "authKey",
|
RegisterMethod: "authKey",
|
||||||
@ -42,14 +42,87 @@ func (s *Suite) TestGetRoutes(c *check.C) {
|
|||||||
}
|
}
|
||||||
h.db.Save(&m)
|
h.db.Save(&m)
|
||||||
|
|
||||||
r, err := h.GetNodeRoutes("test", "testmachine")
|
r, err := h.GetAdvertisedNodeRoutes("test", "test_get_route_machine")
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
c.Assert(len(*r), check.Equals, 1)
|
c.Assert(len(*r), check.Equals, 1)
|
||||||
|
|
||||||
_, err = h.EnableNodeRoute("test", "testmachine", "192.168.0.0/24")
|
err = h.EnableNodeRoute("test", "test_get_route_machine", "192.168.0.0/24")
|
||||||
c.Assert(err, check.NotNil)
|
c.Assert(err, check.NotNil)
|
||||||
|
|
||||||
_, err = h.EnableNodeRoute("test", "testmachine", "10.0.0.0/24")
|
err = h.EnableNodeRoute("test", "test_get_route_machine", "10.0.0.0/24")
|
||||||
|
c.Assert(err, check.IsNil)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Suite) TestGetEnableRoutes(c *check.C) {
|
||||||
|
n, err := h.CreateNamespace("test")
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
|
pak, err := h.CreatePreAuthKey(n.Name, false, false, nil)
|
||||||
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
|
_, err = h.GetMachine("test", "test_enable_route_machine")
|
||||||
|
c.Assert(err, check.NotNil)
|
||||||
|
|
||||||
|
route, err := netaddr.ParseIPPrefix(
|
||||||
|
"10.0.0.0/24",
|
||||||
|
)
|
||||||
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
|
route2, err := netaddr.ParseIPPrefix(
|
||||||
|
"150.0.10.0/25",
|
||||||
|
)
|
||||||
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
|
hi := tailcfg.Hostinfo{
|
||||||
|
RoutableIPs: []netaddr.IPPrefix{route, route2},
|
||||||
|
}
|
||||||
|
hostinfo, err := json.Marshal(hi)
|
||||||
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
|
m := Machine{
|
||||||
|
ID: 0,
|
||||||
|
MachineKey: "foo",
|
||||||
|
NodeKey: "bar",
|
||||||
|
DiscoKey: "faa",
|
||||||
|
Name: "test_enable_route_machine",
|
||||||
|
NamespaceID: n.ID,
|
||||||
|
Registered: true,
|
||||||
|
RegisterMethod: "authKey",
|
||||||
|
AuthKeyID: uint(pak.ID),
|
||||||
|
HostInfo: datatypes.JSON(hostinfo),
|
||||||
|
}
|
||||||
|
h.db.Save(&m)
|
||||||
|
|
||||||
|
availableRoutes, err := h.GetAdvertisedNodeRoutes("test", "test_enable_route_machine")
|
||||||
|
c.Assert(err, check.IsNil)
|
||||||
|
c.Assert(len(*availableRoutes), check.Equals, 2)
|
||||||
|
|
||||||
|
enabledRoutes, err := h.GetEnabledNodeRoutes("test", "test_enable_route_machine")
|
||||||
|
c.Assert(err, check.IsNil)
|
||||||
|
c.Assert(len(enabledRoutes), check.Equals, 0)
|
||||||
|
|
||||||
|
err = h.EnableNodeRoute("test", "test_enable_route_machine", "192.168.0.0/24")
|
||||||
|
c.Assert(err, check.NotNil)
|
||||||
|
|
||||||
|
err = h.EnableNodeRoute("test", "test_enable_route_machine", "10.0.0.0/24")
|
||||||
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
|
enabledRoutes1, err := h.GetEnabledNodeRoutes("test", "test_enable_route_machine")
|
||||||
|
c.Assert(err, check.IsNil)
|
||||||
|
c.Assert(len(enabledRoutes1), check.Equals, 1)
|
||||||
|
|
||||||
|
// Adding it twice will just let it pass through
|
||||||
|
err = h.EnableNodeRoute("test", "test_enable_route_machine", "10.0.0.0/24")
|
||||||
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
|
enabledRoutes2, err := h.GetEnabledNodeRoutes("test", "test_enable_route_machine")
|
||||||
|
c.Assert(err, check.IsNil)
|
||||||
|
c.Assert(len(enabledRoutes2), check.Equals, 1)
|
||||||
|
|
||||||
|
err = h.EnableNodeRoute("test", "test_enable_route_machine", "150.0.10.0/25")
|
||||||
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
|
enabledRoutes3, err := h.GetEnabledNodeRoutes("test", "test_enable_route_machine")
|
||||||
|
c.Assert(err, check.IsNil)
|
||||||
|
c.Assert(len(enabledRoutes3), check.Equals, 2)
|
||||||
}
|
}
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
#!/bin/bash
|
#!/usr/bin/env bash
|
||||||
|
|
||||||
set -e -o pipefail
|
set -e -o pipefail
|
||||||
commit="$1"
|
commit="$1"
|
||||||
|
37
sharing.go
Normal file
37
sharing.go
Normal file
@ -0,0 +1,37 @@
|
|||||||
|
package headscale
|
||||||
|
|
||||||
|
import "gorm.io/gorm"
|
||||||
|
|
||||||
|
const errorSameNamespace = Error("Destination namespace same as origin")
|
||||||
|
const errorMachineAlreadyShared = Error("Node already shared to this namespace")
|
||||||
|
|
||||||
|
// SharedMachine is a join table to support sharing nodes between namespaces
|
||||||
|
type SharedMachine struct {
|
||||||
|
gorm.Model
|
||||||
|
MachineID uint64
|
||||||
|
Machine Machine
|
||||||
|
NamespaceID uint
|
||||||
|
Namespace Namespace
|
||||||
|
}
|
||||||
|
|
||||||
|
// AddSharedMachineToNamespace adds a machine as a shared node to a namespace
|
||||||
|
func (h *Headscale) AddSharedMachineToNamespace(m *Machine, ns *Namespace) error {
|
||||||
|
if m.NamespaceID == ns.ID {
|
||||||
|
return errorSameNamespace
|
||||||
|
}
|
||||||
|
|
||||||
|
sharedMachine := SharedMachine{}
|
||||||
|
if err := h.db.Where("machine_id = ? AND namespace_id", m.ID, ns.ID).First(&sharedMachine).Error; err == nil {
|
||||||
|
return errorMachineAlreadyShared
|
||||||
|
}
|
||||||
|
|
||||||
|
sharedMachine = SharedMachine{
|
||||||
|
MachineID: m.ID,
|
||||||
|
Machine: *m,
|
||||||
|
NamespaceID: ns.ID,
|
||||||
|
Namespace: *ns,
|
||||||
|
}
|
||||||
|
h.db.Save(&sharedMachine)
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
359
sharing_test.go
Normal file
359
sharing_test.go
Normal file
@ -0,0 +1,359 @@
|
|||||||
|
package headscale
|
||||||
|
|
||||||
|
import (
|
||||||
|
"gopkg.in/check.v1"
|
||||||
|
"tailscale.com/tailcfg"
|
||||||
|
)
|
||||||
|
|
||||||
|
func (s *Suite) TestBasicSharedNodesInNamespace(c *check.C) {
|
||||||
|
n1, err := h.CreateNamespace("shared1")
|
||||||
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
|
n2, err := h.CreateNamespace("shared2")
|
||||||
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
|
pak1, err := h.CreatePreAuthKey(n1.Name, false, false, nil)
|
||||||
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
|
pak2, err := h.CreatePreAuthKey(n2.Name, false, false, nil)
|
||||||
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
|
_, err = h.GetMachine(n1.Name, "test_get_shared_nodes_1")
|
||||||
|
c.Assert(err, check.NotNil)
|
||||||
|
|
||||||
|
m1 := Machine{
|
||||||
|
ID: 0,
|
||||||
|
MachineKey: "686824e749f3b7f2a5927ee6c1e422aee5292592d9179a271ed7b3e659b44a66",
|
||||||
|
NodeKey: "686824e749f3b7f2a5927ee6c1e422aee5292592d9179a271ed7b3e659b44a66",
|
||||||
|
DiscoKey: "686824e749f3b7f2a5927ee6c1e422aee5292592d9179a271ed7b3e659b44a66",
|
||||||
|
Name: "test_get_shared_nodes_1",
|
||||||
|
NamespaceID: n1.ID,
|
||||||
|
Registered: true,
|
||||||
|
RegisterMethod: "authKey",
|
||||||
|
IPAddress: "100.64.0.1",
|
||||||
|
AuthKeyID: uint(pak1.ID),
|
||||||
|
}
|
||||||
|
h.db.Save(&m1)
|
||||||
|
|
||||||
|
_, err = h.GetMachine(n1.Name, m1.Name)
|
||||||
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
|
m2 := Machine{
|
||||||
|
ID: 1,
|
||||||
|
MachineKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
|
||||||
|
NodeKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
|
||||||
|
DiscoKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
|
||||||
|
Name: "test_get_shared_nodes_2",
|
||||||
|
NamespaceID: n2.ID,
|
||||||
|
Registered: true,
|
||||||
|
RegisterMethod: "authKey",
|
||||||
|
IPAddress: "100.64.0.2",
|
||||||
|
AuthKeyID: uint(pak2.ID),
|
||||||
|
}
|
||||||
|
h.db.Save(&m2)
|
||||||
|
|
||||||
|
_, err = h.GetMachine(n2.Name, m2.Name)
|
||||||
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
|
p1s, err := h.getPeers(m1)
|
||||||
|
c.Assert(err, check.IsNil)
|
||||||
|
c.Assert(len(*p1s), check.Equals, 0)
|
||||||
|
|
||||||
|
err = h.AddSharedMachineToNamespace(&m2, n1)
|
||||||
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
|
p1sAfter, err := h.getPeers(m1)
|
||||||
|
c.Assert(err, check.IsNil)
|
||||||
|
c.Assert(len(*p1sAfter), check.Equals, 1)
|
||||||
|
c.Assert((*p1sAfter)[0].ID, check.Equals, tailcfg.NodeID(m2.ID))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Suite) TestSameNamespace(c *check.C) {
|
||||||
|
n1, err := h.CreateNamespace("shared1")
|
||||||
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
|
n2, err := h.CreateNamespace("shared2")
|
||||||
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
|
pak1, err := h.CreatePreAuthKey(n1.Name, false, false, nil)
|
||||||
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
|
pak2, err := h.CreatePreAuthKey(n2.Name, false, false, nil)
|
||||||
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
|
_, err = h.GetMachine(n1.Name, "test_get_shared_nodes_1")
|
||||||
|
c.Assert(err, check.NotNil)
|
||||||
|
|
||||||
|
m1 := Machine{
|
||||||
|
ID: 0,
|
||||||
|
MachineKey: "686824e749f3b7f2a5927ee6c1e422aee5292592d9179a271ed7b3e659b44a66",
|
||||||
|
NodeKey: "686824e749f3b7f2a5927ee6c1e422aee5292592d9179a271ed7b3e659b44a66",
|
||||||
|
DiscoKey: "686824e749f3b7f2a5927ee6c1e422aee5292592d9179a271ed7b3e659b44a66",
|
||||||
|
Name: "test_get_shared_nodes_1",
|
||||||
|
NamespaceID: n1.ID,
|
||||||
|
Registered: true,
|
||||||
|
RegisterMethod: "authKey",
|
||||||
|
IPAddress: "100.64.0.1",
|
||||||
|
AuthKeyID: uint(pak1.ID),
|
||||||
|
}
|
||||||
|
h.db.Save(&m1)
|
||||||
|
|
||||||
|
_, err = h.GetMachine(n1.Name, m1.Name)
|
||||||
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
|
m2 := Machine{
|
||||||
|
ID: 1,
|
||||||
|
MachineKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
|
||||||
|
NodeKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
|
||||||
|
DiscoKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
|
||||||
|
Name: "test_get_shared_nodes_2",
|
||||||
|
NamespaceID: n2.ID,
|
||||||
|
Registered: true,
|
||||||
|
RegisterMethod: "authKey",
|
||||||
|
IPAddress: "100.64.0.2",
|
||||||
|
AuthKeyID: uint(pak2.ID),
|
||||||
|
}
|
||||||
|
h.db.Save(&m2)
|
||||||
|
|
||||||
|
_, err = h.GetMachine(n2.Name, m2.Name)
|
||||||
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
|
p1s, err := h.getPeers(m1)
|
||||||
|
c.Assert(err, check.IsNil)
|
||||||
|
c.Assert(len(*p1s), check.Equals, 0)
|
||||||
|
|
||||||
|
err = h.AddSharedMachineToNamespace(&m1, n1)
|
||||||
|
c.Assert(err, check.Equals, errorSameNamespace)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Suite) TestAlreadyShared(c *check.C) {
|
||||||
|
n1, err := h.CreateNamespace("shared1")
|
||||||
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
|
n2, err := h.CreateNamespace("shared2")
|
||||||
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
|
pak1, err := h.CreatePreAuthKey(n1.Name, false, false, nil)
|
||||||
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
|
pak2, err := h.CreatePreAuthKey(n2.Name, false, false, nil)
|
||||||
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
|
_, err = h.GetMachine(n1.Name, "test_get_shared_nodes_1")
|
||||||
|
c.Assert(err, check.NotNil)
|
||||||
|
|
||||||
|
m1 := Machine{
|
||||||
|
ID: 0,
|
||||||
|
MachineKey: "686824e749f3b7f2a5927ee6c1e422aee5292592d9179a271ed7b3e659b44a66",
|
||||||
|
NodeKey: "686824e749f3b7f2a5927ee6c1e422aee5292592d9179a271ed7b3e659b44a66",
|
||||||
|
DiscoKey: "686824e749f3b7f2a5927ee6c1e422aee5292592d9179a271ed7b3e659b44a66",
|
||||||
|
Name: "test_get_shared_nodes_1",
|
||||||
|
NamespaceID: n1.ID,
|
||||||
|
Registered: true,
|
||||||
|
RegisterMethod: "authKey",
|
||||||
|
IPAddress: "100.64.0.1",
|
||||||
|
AuthKeyID: uint(pak1.ID),
|
||||||
|
}
|
||||||
|
h.db.Save(&m1)
|
||||||
|
|
||||||
|
_, err = h.GetMachine(n1.Name, m1.Name)
|
||||||
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
|
m2 := Machine{
|
||||||
|
ID: 1,
|
||||||
|
MachineKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
|
||||||
|
NodeKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
|
||||||
|
DiscoKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
|
||||||
|
Name: "test_get_shared_nodes_2",
|
||||||
|
NamespaceID: n2.ID,
|
||||||
|
Registered: true,
|
||||||
|
RegisterMethod: "authKey",
|
||||||
|
IPAddress: "100.64.0.2",
|
||||||
|
AuthKeyID: uint(pak2.ID),
|
||||||
|
}
|
||||||
|
h.db.Save(&m2)
|
||||||
|
|
||||||
|
_, err = h.GetMachine(n2.Name, m2.Name)
|
||||||
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
|
p1s, err := h.getPeers(m1)
|
||||||
|
c.Assert(err, check.IsNil)
|
||||||
|
c.Assert(len(*p1s), check.Equals, 0)
|
||||||
|
|
||||||
|
err = h.AddSharedMachineToNamespace(&m2, n1)
|
||||||
|
c.Assert(err, check.IsNil)
|
||||||
|
err = h.AddSharedMachineToNamespace(&m2, n1)
|
||||||
|
c.Assert(err, check.Equals, errorMachineAlreadyShared)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Suite) TestDoNotIncludeRoutesOnShared(c *check.C) {
|
||||||
|
n1, err := h.CreateNamespace("shared1")
|
||||||
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
|
n2, err := h.CreateNamespace("shared2")
|
||||||
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
|
pak1, err := h.CreatePreAuthKey(n1.Name, false, false, nil)
|
||||||
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
|
pak2, err := h.CreatePreAuthKey(n2.Name, false, false, nil)
|
||||||
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
|
_, err = h.GetMachine(n1.Name, "test_get_shared_nodes_1")
|
||||||
|
c.Assert(err, check.NotNil)
|
||||||
|
|
||||||
|
m1 := Machine{
|
||||||
|
ID: 0,
|
||||||
|
MachineKey: "686824e749f3b7f2a5927ee6c1e422aee5292592d9179a271ed7b3e659b44a66",
|
||||||
|
NodeKey: "686824e749f3b7f2a5927ee6c1e422aee5292592d9179a271ed7b3e659b44a66",
|
||||||
|
DiscoKey: "686824e749f3b7f2a5927ee6c1e422aee5292592d9179a271ed7b3e659b44a66",
|
||||||
|
Name: "test_get_shared_nodes_1",
|
||||||
|
NamespaceID: n1.ID,
|
||||||
|
Registered: true,
|
||||||
|
RegisterMethod: "authKey",
|
||||||
|
IPAddress: "100.64.0.1",
|
||||||
|
AuthKeyID: uint(pak1.ID),
|
||||||
|
}
|
||||||
|
h.db.Save(&m1)
|
||||||
|
|
||||||
|
_, err = h.GetMachine(n1.Name, m1.Name)
|
||||||
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
|
m2 := Machine{
|
||||||
|
ID: 1,
|
||||||
|
MachineKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
|
||||||
|
NodeKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
|
||||||
|
DiscoKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
|
||||||
|
Name: "test_get_shared_nodes_2",
|
||||||
|
NamespaceID: n2.ID,
|
||||||
|
Registered: true,
|
||||||
|
RegisterMethod: "authKey",
|
||||||
|
IPAddress: "100.64.0.2",
|
||||||
|
AuthKeyID: uint(pak2.ID),
|
||||||
|
}
|
||||||
|
h.db.Save(&m2)
|
||||||
|
|
||||||
|
_, err = h.GetMachine(n2.Name, m2.Name)
|
||||||
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
|
p1s, err := h.getPeers(m1)
|
||||||
|
c.Assert(err, check.IsNil)
|
||||||
|
c.Assert(len(*p1s), check.Equals, 0)
|
||||||
|
|
||||||
|
err = h.AddSharedMachineToNamespace(&m2, n1)
|
||||||
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
|
p1sAfter, err := h.getPeers(m1)
|
||||||
|
c.Assert(err, check.IsNil)
|
||||||
|
c.Assert(len(*p1sAfter), check.Equals, 1)
|
||||||
|
c.Assert(len((*p1sAfter)[0].AllowedIPs), check.Equals, 1)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Suite) TestComplexSharingAcrossNamespaces(c *check.C) {
|
||||||
|
n1, err := h.CreateNamespace("shared1")
|
||||||
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
|
n2, err := h.CreateNamespace("shared2")
|
||||||
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
|
n3, err := h.CreateNamespace("shared3")
|
||||||
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
|
pak1, err := h.CreatePreAuthKey(n1.Name, false, false, nil)
|
||||||
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
|
pak2, err := h.CreatePreAuthKey(n2.Name, false, false, nil)
|
||||||
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
|
pak3, err := h.CreatePreAuthKey(n3.Name, false, false, nil)
|
||||||
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
|
pak4, err := h.CreatePreAuthKey(n1.Name, false, false, nil)
|
||||||
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
|
_, err = h.GetMachine(n1.Name, "test_get_shared_nodes_1")
|
||||||
|
c.Assert(err, check.NotNil)
|
||||||
|
|
||||||
|
m1 := Machine{
|
||||||
|
ID: 0,
|
||||||
|
MachineKey: "686824e749f3b7f2a5927ee6c1e422aee5292592d9179a271ed7b3e659b44a66",
|
||||||
|
NodeKey: "686824e749f3b7f2a5927ee6c1e422aee5292592d9179a271ed7b3e659b44a66",
|
||||||
|
DiscoKey: "686824e749f3b7f2a5927ee6c1e422aee5292592d9179a271ed7b3e659b44a66",
|
||||||
|
Name: "test_get_shared_nodes_1",
|
||||||
|
NamespaceID: n1.ID,
|
||||||
|
Registered: true,
|
||||||
|
RegisterMethod: "authKey",
|
||||||
|
IPAddress: "100.64.0.1",
|
||||||
|
AuthKeyID: uint(pak1.ID),
|
||||||
|
}
|
||||||
|
h.db.Save(&m1)
|
||||||
|
|
||||||
|
_, err = h.GetMachine(n1.Name, m1.Name)
|
||||||
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
|
m2 := Machine{
|
||||||
|
ID: 1,
|
||||||
|
MachineKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
|
||||||
|
NodeKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
|
||||||
|
DiscoKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
|
||||||
|
Name: "test_get_shared_nodes_2",
|
||||||
|
NamespaceID: n2.ID,
|
||||||
|
Registered: true,
|
||||||
|
RegisterMethod: "authKey",
|
||||||
|
IPAddress: "100.64.0.2",
|
||||||
|
AuthKeyID: uint(pak2.ID),
|
||||||
|
}
|
||||||
|
h.db.Save(&m2)
|
||||||
|
|
||||||
|
_, err = h.GetMachine(n2.Name, m2.Name)
|
||||||
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
|
m3 := Machine{
|
||||||
|
ID: 2,
|
||||||
|
MachineKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
|
||||||
|
NodeKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
|
||||||
|
DiscoKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
|
||||||
|
Name: "test_get_shared_nodes_3",
|
||||||
|
NamespaceID: n3.ID,
|
||||||
|
Registered: true,
|
||||||
|
RegisterMethod: "authKey",
|
||||||
|
IPAddress: "100.64.0.3",
|
||||||
|
AuthKeyID: uint(pak3.ID),
|
||||||
|
}
|
||||||
|
h.db.Save(&m3)
|
||||||
|
|
||||||
|
_, err = h.GetMachine(n3.Name, m3.Name)
|
||||||
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
|
m4 := Machine{
|
||||||
|
ID: 3,
|
||||||
|
MachineKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
|
||||||
|
NodeKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
|
||||||
|
DiscoKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
|
||||||
|
Name: "test_get_shared_nodes_4",
|
||||||
|
NamespaceID: n1.ID,
|
||||||
|
Registered: true,
|
||||||
|
RegisterMethod: "authKey",
|
||||||
|
IPAddress: "100.64.0.4",
|
||||||
|
AuthKeyID: uint(pak4.ID),
|
||||||
|
}
|
||||||
|
h.db.Save(&m4)
|
||||||
|
|
||||||
|
_, err = h.GetMachine(n1.Name, m4.Name)
|
||||||
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
|
p1s, err := h.getPeers(m1)
|
||||||
|
c.Assert(err, check.IsNil)
|
||||||
|
c.Assert(len(*p1s), check.Equals, 1) // nodes 1 and 4
|
||||||
|
|
||||||
|
err = h.AddSharedMachineToNamespace(&m2, n1)
|
||||||
|
c.Assert(err, check.IsNil)
|
||||||
|
|
||||||
|
p1sAfter, err := h.getPeers(m1)
|
||||||
|
c.Assert(err, check.IsNil)
|
||||||
|
c.Assert(len(*p1sAfter), check.Equals, 2) // nodes 1, 2, 4
|
||||||
|
|
||||||
|
pAlone, err := h.getPeers(m3)
|
||||||
|
c.Assert(err, check.IsNil)
|
||||||
|
c.Assert(len(*pAlone), check.Equals, 0) // node 3 is alone
|
||||||
|
}
|
Loading…
Reference in New Issue
Block a user