Compare commits

...

7 Commits

Author SHA1 Message Date
Juan Font Alonso
96b02f7d89 Updated test config to work with TS2021 2022-06-16 00:27:53 +02:00
Juan Font Alonso
7078d36dc6 Added MapPoll to Noise protocol 2022-06-16 00:21:46 +02:00
Juan Font Alonso
670c7d9144 TS2021: Add Noise endpoint for node registration 2022-06-12 14:33:47 +02:00
Juan Font Alonso
e8205e8d5a TS2021: Use NodeKey for everything, as MachineKey is deprecated in TS2021 2022-06-12 12:30:56 +02:00
Juan Font Alonso
b40b4e8d45 Added Noise upgrade handler and Noise mux 2022-06-11 19:08:35 +02:00
Juan Font Alonso
304987b4ff TS2021: Convert /key handler to send the Noise key too 2022-06-11 19:00:49 +02:00
Juan Font Alonso
c908627e68 Generate and read the Noise private key 2022-06-11 18:53:11 +02:00
14 changed files with 1415 additions and 40 deletions

34
api.go
View File

@@ -9,6 +9,7 @@ import (
"html/template" "html/template"
"io" "io"
"net/http" "net/http"
"strconv"
"strings" "strings"
"time" "time"
@@ -28,11 +29,40 @@ const (
ErrRegisterMethodCLIDoesNotSupportExpire = Error( ErrRegisterMethodCLIDoesNotSupportExpire = Error(
"machines registered with CLI does not support expire", "machines registered with CLI does not support expire",
) )
// The CapabilityVersion is used by Tailscale clients to indicate
// their codebase version. Tailscale clients can communicate over TS2021
// from CapabilityVersion 28.
// See https://github.com/tailscale/tailscale/blob/main/tailcfg/tailcfg.go
NoiseCapabilityVersion = 28
) )
// KeyHandler provides the Headscale pub key // KeyHandler provides the Headscale pub key
// Listens in /key. // Listens in /key.
func (h *Headscale) KeyHandler(ctx *gin.Context) { func (h *Headscale) KeyHandler(ctx *gin.Context) {
// New Tailscale clients send a 'v' parameter to indicate the CurrentCapabilityVersion
clientCapabilityStr := ctx.Query("v")
if clientCapabilityStr != "" {
clientCapabilityVersion, err := strconv.Atoi(clientCapabilityStr)
if err != nil {
ctx.String(http.StatusBadRequest, "Invalid version")
return
}
if clientCapabilityVersion >= NoiseCapabilityVersion {
// Tailscale has a different key for the TS2021 protocol
resp := tailcfg.OverTLSPublicKeyResponse{
LegacyPublicKey: h.privateKey.Public(),
PublicKey: h.noisePrivateKey.Public(),
}
ctx.JSON(http.StatusOK, resp)
return
}
}
// Old clients don't send a 'v' parameter, so we send the legacy public key
ctx.Data( ctx.Data(
http.StatusOK, http.StatusOK,
"text/plain; charset=utf-8", "text/plain; charset=utf-8",
@@ -516,11 +546,11 @@ func (h *Headscale) handleMachineRegistrationNew(
resp.AuthURL = fmt.Sprintf( resp.AuthURL = fmt.Sprintf(
"%s/oidc/register/%s", "%s/oidc/register/%s",
strings.TrimSuffix(h.cfg.ServerURL, "/"), strings.TrimSuffix(h.cfg.ServerURL, "/"),
machineKey.String(), NodePublicKeyStripPrefix(registerRequest.NodeKey),
) )
} else { } else {
resp.AuthURL = fmt.Sprintf("%s/register?key=%s", resp.AuthURL = fmt.Sprintf("%s/register?key=%s",
strings.TrimSuffix(h.cfg.ServerURL, "/"), MachinePublicKeyStripPrefix(machineKey)) strings.TrimSuffix(h.cfg.ServerURL, "/"), NodePublicKeyStripPrefix(registerRequest.NodeKey))
} }
respBody, err := encode(resp, &machineKey, h.privateKey) respBody, err := encode(resp, &machineKey, h.privateKey)

51
app.go
View File

@@ -71,12 +71,15 @@ const (
// Headscale represents the base app of the service. // Headscale represents the base app of the service.
type Headscale struct { type Headscale struct {
cfg *Config cfg *Config
db *gorm.DB db *gorm.DB
dbString string dbString string
dbType string dbType string
dbDebug bool dbDebug bool
privateKey *key.MachinePrivate privateKey *key.MachinePrivate
noisePrivateKey *key.MachinePrivate
noiseMux *http.ServeMux
DERPMap *tailcfg.DERPMap DERPMap *tailcfg.DERPMap
DERPServer *DERPServer DERPServer *DERPServer
@@ -116,11 +119,20 @@ func LookupTLSClientAuthMode(mode string) (tls.ClientAuthType, bool) {
} }
func NewHeadscale(cfg *Config) (*Headscale, error) { func NewHeadscale(cfg *Config) (*Headscale, error) {
privKey, err := readOrCreatePrivateKey(cfg.PrivateKeyPath) privateKey, err := readOrCreatePrivateKey(cfg.PrivateKeyPath)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to read or create private key: %w", err) return nil, fmt.Errorf("failed to read or create private key: %w", err)
} }
noisePrivateKey, err := readOrCreatePrivateKey(cfg.NoisePrivateKeyPath)
if err != nil {
return nil, fmt.Errorf("failed to read or create noise private key: %w", err)
}
if privateKey.Equal(*noisePrivateKey) {
return nil, fmt.Errorf("private key and noise private key are the same")
}
var dbString string var dbString string
switch cfg.DBtype { switch cfg.DBtype {
case Postgres: case Postgres:
@@ -147,7 +159,8 @@ func NewHeadscale(cfg *Config) (*Headscale, error) {
cfg: cfg, cfg: cfg,
dbType: cfg.DBtype, dbType: cfg.DBtype,
dbString: dbString, dbString: dbString,
privateKey: privKey, privateKey: privateKey,
noisePrivateKey: noisePrivateKey,
aclRules: tailcfg.FilterAllowAll, // default allowall aclRules: tailcfg.FilterAllowAll, // default allowall
registrationCache: registrationCache, registrationCache: registrationCache,
} }
@@ -393,6 +406,7 @@ func (h *Headscale) createPrometheusRouter() *gin.Engine {
func (h *Headscale) createRouter(grpcMux *runtime.ServeMux) *gin.Engine { func (h *Headscale) createRouter(grpcMux *runtime.ServeMux) *gin.Engine {
router := gin.Default() router := gin.Default()
router.POST(ts2021UpgradePath, h.NoiseUpgradeHandler)
router.GET( router.GET(
"/health", "/health",
func(c *gin.Context) { c.JSON(http.StatusOK, gin.H{"healthy": "ok"}) }, func(c *gin.Context) { c.JSON(http.StatusOK, gin.H{"healthy": "ok"}) },
@@ -401,7 +415,7 @@ func (h *Headscale) createRouter(grpcMux *runtime.ServeMux) *gin.Engine {
router.GET("/register", h.RegisterWebAPI) router.GET("/register", h.RegisterWebAPI)
router.POST("/machine/:id/map", h.PollNetMapHandler) router.POST("/machine/:id/map", h.PollNetMapHandler)
router.POST("/machine/:id", h.RegistrationHandler) router.POST("/machine/:id", h.RegistrationHandler)
router.GET("/oidc/register/:mkey", h.RegisterOIDC) router.GET("/oidc/register/:nkey", h.RegisterOIDC)
router.GET("/oidc/callback", h.OIDCCallback) router.GET("/oidc/callback", h.OIDCCallback)
router.GET("/apple", h.AppleConfigMessage) router.GET("/apple", h.AppleConfigMessage)
router.GET("/apple/:platform", h.ApplePlatformConfig) router.GET("/apple/:platform", h.ApplePlatformConfig)
@@ -427,6 +441,15 @@ func (h *Headscale) createRouter(grpcMux *runtime.ServeMux) *gin.Engine {
return router return router
} }
// createNoiseMux builds the HTTP mux holding the endpoints of the new API:
// the machine registration handler and the map poll handler.
func (h *Headscale) createNoiseMux() *http.ServeMux {
	noiseRoutes := []struct {
		pattern string
		handler func(http.ResponseWriter, *http.Request)
	}{
		{pattern: "/machine/register", handler: h.NoiseRegistrationHandler},
		{pattern: "/machine/map", handler: h.NoisePollNetMapHandler},
	}

	noiseMux := http.NewServeMux()
	for _, route := range noiseRoutes {
		noiseMux.HandleFunc(route.pattern, route.handler)
	}

	return noiseMux
}
// Serve launches a GIN server with the Headscale API. // Serve launches a GIN server with the Headscale API.
func (h *Headscale) Serve() error { func (h *Headscale) Serve() error {
var err error var err error
@@ -579,8 +602,14 @@ func (h *Headscale) Serve() error {
// HTTP setup // HTTP setup
// //
// This is the regular router that we expose
// over our main Addr. It also serves the legacy Tailscale API
router := h.createRouter(grpcGatewayMux) router := h.createRouter(grpcGatewayMux)
// This router is served only over the Noise connection,
// and exposes only the new API
h.noiseMux = h.createNoiseMux()
httpServer := &http.Server{ httpServer := &http.Server{
Addr: h.cfg.Addr, Addr: h.cfg.Addr,
Handler: router, Handler: router,
@@ -718,6 +747,10 @@ func (h *Headscale) getTLSSettings() (*tls.Config, error) {
// Configuration via autocert with HTTP-01. This requires listening on // Configuration via autocert with HTTP-01. This requires listening on
// port 80 for the certificate validation in addition to the headscale // port 80 for the certificate validation in addition to the headscale
// service, which can be configured to run on any other port. // service, which can be configured to run on any other port.
httpRouter := gin.Default()
httpRouter.POST(ts2021UpgradePath, h.NoiseUpgradeHandler)
httpRouter.NoRoute(gin.WrapF(h.redirect))
go func() { go func() {
log.Fatal(). log.Fatal().
Caller(). Caller().

View File

@@ -41,6 +41,13 @@ grpc_allow_insecure: false
# autogenerated if it's missing # autogenerated if it's missing
private_key_path: /var/lib/headscale/private.key private_key_path: /var/lib/headscale/private.key
# The Noise private key is used to encrypt the
# traffic between headscale and Tailscale clients when
# using the new Noise-based TS2021 protocol.
# The noise private key file will be
# autogenerated if it's missing
noise_private_key_path: /var/lib/headscale/noise_private.key
# List of IP prefixes to allocate tailaddresses from. # List of IP prefixes to allocate tailaddresses from.
# Each prefix consists of either an IPv4 or IPv6 address, # Each prefix consists of either an IPv4 or IPv6 address,
# and the associated prefix length, delimited by a slash. # and the associated prefix length, delimited by a slash.

View File

@@ -28,6 +28,7 @@ type Config struct {
EphemeralNodeInactivityTimeout time.Duration EphemeralNodeInactivityTimeout time.Duration
IPPrefixes []netaddr.IPPrefix IPPrefixes []netaddr.IPPrefix
PrivateKeyPath string PrivateKeyPath string
NoisePrivateKeyPath string
BaseDomain string BaseDomain string
LogLevel zerolog.Level LogLevel zerolog.Level
DisableUpdateCheck bool DisableUpdateCheck bool
@@ -455,6 +456,9 @@ func GetHeadscaleConfig() (*Config, error) {
PrivateKeyPath: AbsolutePathFromConfigPath( PrivateKeyPath: AbsolutePathFromConfigPath(
viper.GetString("private_key_path"), viper.GetString("private_key_path"),
), ),
NoisePrivateKeyPath: AbsolutePathFromConfigPath(
viper.GetString("noise_private_key_path"),
),
BaseDomain: baseDomain, BaseDomain: baseDomain,
DERP: derpConfig, DERP: derpConfig,

View File

@@ -37,6 +37,7 @@ oidc:
- email - email
strip_email_domain: true strip_email_domain: true
private_key_path: private.key private_key_path: private.key
noise_private_key_path: noise_private.key
server_url: http://headscale:18080 server_url: http://headscale:18080
tls_client_auth_mode: relaxed tls_client_auth_mode: relaxed
tls_letsencrypt_cache_dir: /var/www/.cache tls_letsencrypt_cache_dir: /var/www/.cache

View File

@@ -13,6 +13,7 @@ dns_config:
- 1.1.1.1 - 1.1.1.1
db_path: /tmp/integration_test_db.sqlite3 db_path: /tmp/integration_test_db.sqlite3
private_key_path: private.key private_key_path: private.key
noise_private_key_path: noise_private.key
listen_addr: 0.0.0.0:18080 listen_addr: 0.0.0.0:18080
metrics_listen_addr: 127.0.0.1:19090 metrics_listen_addr: 127.0.0.1:19090
server_url: http://headscale:18080 server_url: http://headscale:18080

View File

@@ -37,6 +37,7 @@ oidc:
- email - email
strip_email_domain: true strip_email_domain: true
private_key_path: private.key private_key_path: private.key
noise_private_key_path: noise_private.key
server_url: http://headscale:8080 server_url: http://headscale:8080
tls_client_auth_mode: relaxed tls_client_auth_mode: relaxed
tls_letsencrypt_cache_dir: /var/www/.cache tls_letsencrypt_cache_dir: /var/www/.cache

View File

@@ -13,8 +13,9 @@ dns_config:
- 1.1.1.1 - 1.1.1.1
db_path: /tmp/integration_test_db.sqlite3 db_path: /tmp/integration_test_db.sqlite3
private_key_path: private.key private_key_path: private.key
listen_addr: 0.0.0.0:8443 noise_private_key_path: noise_private.key
server_url: https://headscale:8443 listen_addr: 0.0.0.0:443
server_url: https://headscale:443
tls_cert_path: "/etc/headscale/tls/server.crt" tls_cert_path: "/etc/headscale/tls/server.crt"
tls_key_path: "/etc/headscale/tls/server.key" tls_key_path: "/etc/headscale/tls/server.key"
tls_client_auth_mode: disabled tls_client_auth_mode: disabled

View File

@@ -349,7 +349,7 @@ func (h *Headscale) GetMachineByID(id uint64) (*Machine, error) {
return &m, nil return &m, nil
} }
// GetMachineByMachineKey finds a Machine by ID and returns the Machine struct. // GetMachineByMachineKey finds a Machine by its MachineKey and returns the Machine struct.
func (h *Headscale) GetMachineByMachineKey( func (h *Headscale) GetMachineByMachineKey(
machineKey key.MachinePublic, machineKey key.MachinePublic,
) (*Machine, error) { ) (*Machine, error) {
@@ -361,6 +361,19 @@ func (h *Headscale) GetMachineByMachineKey(
return &m, nil return &m, nil
} }
// GetMachineByNodeKeys looks up the Machine whose stored node_key matches
// either the given NodeKey or the given old NodeKey. The Namespace of the
// machine is preloaded. The database error is returned as-is when the lookup
// fails.
func (h *Headscale) GetMachineByNodeKeys(
	nodeKey key.NodePublic, oldNodeKey key.NodePublic,
) (*Machine, error) {
	var found Machine

	currentKey := NodePublicKeyStripPrefix(nodeKey)
	previousKey := NodePublicKeyStripPrefix(oldNodeKey)

	err := h.db.
		Preload("Namespace").
		First(&found, "node_key = ? OR node_key = ?", currentKey, previousKey).
		Error
	if err != nil {
		return nil, err
	}

	return &found, nil
}
// UpdateMachineFromDatabase takes a Machine struct pointer (typically already loaded from database // UpdateMachineFromDatabase takes a Machine struct pointer (typically already loaded from database
// and updates it with the latest data from the database. // and updates it with the latest data from the database.
func (h *Headscale) UpdateMachineFromDatabase(machine *Machine) error { func (h *Headscale) UpdateMachineFromDatabase(machine *Machine) error {
@@ -567,11 +580,14 @@ func (machine Machine) toNode(
} }
var machineKey key.MachinePublic var machineKey key.MachinePublic
err = machineKey.UnmarshalText( if machine.MachineKey != "" {
[]byte(MachinePublicKeyEnsurePrefix(machine.MachineKey)), // MachineKey is only used in the legacy protocol
) err = machineKey.UnmarshalText(
if err != nil { []byte(MachinePublicKeyEnsurePrefix(machine.MachineKey)),
return nil, fmt.Errorf("failed to parse machine public key: %w", err) )
if err != nil {
return nil, fmt.Errorf("failed to parse machine public key: %w", err)
}
} }
var discoKey key.DiscoPublic var discoKey key.DiscoPublic
@@ -750,11 +766,11 @@ func getTags(
} }
func (h *Headscale) RegisterMachineFromAuthCallback( func (h *Headscale) RegisterMachineFromAuthCallback(
machineKeyStr string, nodeKeyStr string,
namespaceName string, namespaceName string,
registrationMethod string, registrationMethod string,
) (*Machine, error) { ) (*Machine, error) {
if machineInterface, ok := h.registrationCache.Get(machineKeyStr); ok { if machineInterface, ok := h.registrationCache.Get(nodeKeyStr); ok {
if registrationMachine, ok := machineInterface.(Machine); ok { if registrationMachine, ok := machineInterface.(Machine); ok {
namespace, err := h.GetNamespace(namespaceName) namespace, err := h.GetNamespace(namespaceName)
if err != nil { if err != nil {
@@ -785,7 +801,7 @@ func (h *Headscale) RegisterMachine(machine Machine,
) (*Machine, error) { ) (*Machine, error) {
log.Trace(). log.Trace().
Caller(). Caller().
Str("machine_key", machine.MachineKey). Str("node_key", machine.NodeKey).
Msg("Registering machine") Msg("Registering machine")
log.Trace(). log.Trace().

View File

@@ -11,6 +11,7 @@ import (
"gopkg.in/check.v1" "gopkg.in/check.v1"
"inet.af/netaddr" "inet.af/netaddr"
"tailscale.com/tailcfg" "tailscale.com/tailcfg"
"tailscale.com/types/key"
) )
func (s *Suite) TestGetMachine(c *check.C) { func (s *Suite) TestGetMachine(c *check.C) {
@@ -65,6 +66,35 @@ func (s *Suite) TestGetMachineByID(c *check.C) {
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
} }
// TestGetMachineByNodeKeys stores a machine with a freshly generated NodeKey
// and checks that GetMachineByNodeKeys finds it when given that key together
// with an unrelated old key.
func (s *Suite) TestGetMachineByNodeKeys(c *check.C) {
	ns, err := app.CreateNamespace("test")
	c.Assert(err, check.IsNil)

	preAuthKey, err := app.CreatePreAuthKey(ns.Name, false, false, nil)
	c.Assert(err, check.IsNil)

	// The lookup by ID 0 is expected to fail at this point.
	_, err = app.GetMachineByID(0)
	c.Assert(err, check.NotNil)

	currentKey := key.NewNode().Public()
	previousKey := key.NewNode().Public()

	stored := Machine{
		ID:             0,
		MachineKey:     "foo",
		NodeKey:        NodePublicKeyStripPrefix(currentKey),
		DiscoKey:       "faa",
		Hostname:       "testmachine",
		NamespaceID:    ns.ID,
		RegisterMethod: RegisterMethodAuthKey,
		AuthKeyID:      uint(preAuthKey.ID),
	}
	app.db.Save(&stored)

	_, err = app.GetMachineByNodeKeys(currentKey, previousKey)
	c.Assert(err, check.IsNil)
}
func (s *Suite) TestDeleteMachine(c *check.C) { func (s *Suite) TestDeleteMachine(c *check.C) {
namespace, err := app.CreateNamespace("test") namespace, err := app.CreateNamespace("test")
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)

125
noise.go Normal file
View File

@@ -0,0 +1,125 @@
package headscale
import (
"encoding/base64"
"net/http"
"github.com/gin-gonic/gin"
"github.com/rs/zerolog/log"
"golang.org/x/net/http2"
"golang.org/x/net/http2/h2c"
"tailscale.com/control/controlbase"
"tailscale.com/net/netutil"
)
// Errors returned by getNoiseConnection while upgrading a connection to Noise.
const (
// errWrongConnectionUpgrade is returned when the Upgrade header or the
// handshake header of the request is missing or invalid.
errWrongConnectionUpgrade = Error("wrong connection upgrade")
// errCannotHijack is returned when the underlying connection cannot be
// hijacked, or when flushing the hijacked connection fails.
errCannotHijack = Error("cannot hijack connection")
// errNoiseHandshakeFailed is returned when the server side of the Noise
// handshake (controlbase.Server) returns an error.
errNoiseHandshakeFailed = Error("noise handshake failed")
)
const (
// ts2021UpgradePath is the path that the server listens on for the
// upgrade of the HTTP connection to the Tailscale control protocol.
ts2021UpgradePath = "/ts2021"
// upgradeHeaderValue is the value of the Upgrade HTTP header used to
// indicate the Tailscale control protocol.
upgradeHeaderValue = "tailscale-control-protocol"
// handshakeHeaderName is the HTTP request header that contains the
// base64-encoded initial handshake payload, to save an RTT.
// getNoiseConnection rejects requests that do not carry it.
handshakeHeaderName = "X-Tailscale-Handshake"
)
// NoiseUpgradeHandler upgrades the connection and hijacks the net.Conn
// in order to use the Noise-based TS2021 protocol. Listens in /ts2021.
//
// Once the Noise connection is established, the requests arriving over it are
// served from h.noiseMux through an h2c handler.
func (h *Headscale) NoiseUpgradeHandler(ctx *gin.Context) {
	log.Trace().Caller().Msgf("Noise upgrade handler for client %s", ctx.ClientIP())

	// Under normal circumstances, we should be able to use the controlhttp.AcceptHTTP()
	// function to do this - kindly left there by the Tailscale authors for us to use.
	// (https://github.com/tailscale/tailscale/blob/main/control/controlhttp/server.go)
	//
	// However, Gin seems to be doing something funny/different with its writer (see AcceptHTTP code).
	// This causes problems when the upgrade headers are sent in AcceptHTTP.
	// So have getNoiseConnection() that is essentially an AcceptHTTP but using the native Gin methods.
	noiseConn, err := h.getNoiseConnection(ctx)
	if err != nil {
		log.Error().Err(err).Msg("noise upgrade failed")

		// getNoiseConnection has already written a response to the client or
		// has taken over (and closed) the connection, so a new status code
		// cannot be sent anymore. Only record the error and stop the chain.
		_ = ctx.Error(err)
		ctx.Abort()

		return
	}

	server := http.Server{
		Handler: h2c.NewHandler(h.noiseMux, &http2.Server{}),
	}

	// The listener only ever yields noiseConn, so Serve is expected to return
	// once that single connection has been handed out. Keep the reason
	// visible at trace level instead of dropping it.
	if err := server.Serve(netutil.NewOneConnListener(noiseConn, nil)); err != nil {
		log.Trace().Caller().Err(err).Msg("Noise connection server stopped")
	}
}
// getNoiseConnection is basically AcceptHTTP from tailscale, but more _alla_ Gin
// TODO(juan): Figure out why we need to do this at all.
//
// It validates the Upgrade and handshake headers of the request, answers with
// 101 Switching Protocols, hijacks the underlying connection and runs the
// server side of the Noise handshake on it using h.noisePrivateKey.
//
// On failure one of errWrongConnectionUpgrade, errCannotHijack or
// errNoiseHandshakeFailed is returned. The underlying cause is logged here,
// as the sentinel errors do not carry it.
func (h *Headscale) getNoiseConnection(ctx *gin.Context) (*controlbase.Conn, error) {
	next := ctx.GetHeader("Upgrade")
	if next == "" {
		ctx.String(http.StatusBadRequest, "missing next protocol")

		return nil, errWrongConnectionUpgrade
	}
	if next != upgradeHeaderValue {
		ctx.String(http.StatusBadRequest, "unknown next protocol")

		return nil, errWrongConnectionUpgrade
	}

	initB64 := ctx.GetHeader(handshakeHeaderName)
	if initB64 == "" {
		ctx.String(http.StatusBadRequest, "missing Tailscale handshake header")

		return nil, errWrongConnectionUpgrade
	}

	handshakeInit, err := base64.StdEncoding.DecodeString(initB64)
	if err != nil {
		ctx.String(http.StatusBadRequest, "invalid tailscale handshake header")

		return nil, errWrongConnectionUpgrade
	}

	hijacker, ok := ctx.Writer.(http.Hijacker)
	if !ok {
		// There is no error value to attach here: the writer simply does
		// not implement http.Hijacker.
		log.Error().Caller().Msg("Hijack failed")
		ctx.String(http.StatusInternalServerError, "HTTP does not support general TCP support")

		return nil, errCannotHijack
	}

	// This is what changes from the original AcceptHTTP() function.
	ctx.Header("Upgrade", upgradeHeaderValue)
	ctx.Header("Connection", "upgrade")
	ctx.Status(http.StatusSwitchingProtocols)
	ctx.Writer.WriteHeaderNow()
	// end

	netConn, conn, err := hijacker.Hijack()
	if err != nil {
		// The 101 status has already been written above, so no other HTTP
		// error response can be sent to the client at this point.
		log.Error().Caller().Err(err).Msgf("Hijack failed")

		return nil, errCannotHijack
	}

	if err := conn.Flush(); err != nil {
		log.Error().Caller().Err(err).Msg("Failed to flush the hijacked connection")
		netConn.Close()

		return nil, errCannotHijack
	}

	// Wrap the connection together with the buffered reader returned by Hijack.
	netConn = netutil.NewDrainBufConn(netConn, conn.Reader)

	nc, err := controlbase.Server(ctx.Request.Context(), netConn, *h.noisePrivateKey, handshakeInit)
	if err != nil {
		log.Error().Caller().Err(err).Msg("Noise handshake failed")
		netConn.Close()

		return nil, errNoiseHandshakeFailed
	}

	return nc, nil
}

389
noise_api.go Normal file
View File

@@ -0,0 +1,389 @@
package headscale
import (
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"strings"
"time"
"github.com/rs/zerolog/log"
"gorm.io/gorm"
"tailscale.com/tailcfg"
)
// NoiseRegistrationHandler handles the registration requests of the nodes
// received over the Noise connection.
//
// The body of the request is a JSON-encoded tailcfg.RegisterRequest. The
// machine is looked up by NodeKey / OldNodeKey and, depending on the result,
// the request is handled as:
//   - a new machine (via AuthKey, or via an auth URL when no key is given),
//   - a logout (Expiry in the past),
//   - a valid, not expired registration,
//   - a NodeKey refresh (the stored key matches OldNodeKey),
//   - an expired registration.
//
// Only POST is accepted. Failures to read or parse the body, and database
// errors other than "record not found", are answered with an HTTP error.
func (h *Headscale) NoiseRegistrationHandler(
	w http.ResponseWriter,
	r *http.Request,
) {
	log.Trace().Caller().Msgf("Noise registration handler for client %s", r.RemoteAddr)
	if r.Method != http.MethodPost {
		http.Error(w, "Wrong method", http.StatusMethodNotAllowed)

		return
	}

	body, err := io.ReadAll(r.Body)
	if err != nil {
		log.Error().
			Caller().
			Err(err).
			Msg("Cannot read RegisterRequest body")
		machineRegistrations.WithLabelValues("unknown", "web", "error", "unknown").Inc()
		http.Error(w, "Internal error", http.StatusInternalServerError)

		return
	}

	req := tailcfg.RegisterRequest{}
	if err := json.Unmarshal(body, &req); err != nil {
		log.Error().
			Caller().
			Err(err).
			Msg("Cannot parse RegisterRequest")
		machineRegistrations.WithLabelValues("unknown", "web", "error", "unknown").Inc()
		http.Error(w, "Internal error", http.StatusInternalServerError)

		return
	}

	log.Info().Caller().
		Str("nodekey", req.NodeKey.ShortString()).
		Str("oldnodekey", req.OldNodeKey.ShortString()).Msg("Nodekys!")

	now := time.Now().UTC()
	machine, err := h.GetMachineByNodeKeys(req.NodeKey, req.OldNodeKey)
	if errors.Is(err, gorm.ErrRecordNotFound) {
		// Hostinfo is a pointer filled in by the client; refuse the request
		// instead of dereferencing nil below.
		if req.Hostinfo == nil {
			log.Error().
				Caller().
				Msg("RegisterRequest for a new machine is missing Hostinfo")
			http.Error(w, "Bad request", http.StatusBadRequest)

			return
		}

		log.Info().Str("machine", req.Hostinfo.Hostname).Msg("New machine via Noise")

		// If the machine has AuthKey set, handle registration via PreAuthKeys
		if req.Auth.AuthKey != "" {
			h.handleNoiseAuthKey(w, r, req)

			return
		}

		givenName, err := h.GenerateGivenName(req.Hostinfo.Hostname)
		if err != nil {
			// The event must be finished with Msg, otherwise nothing is logged.
			log.Error().
				Caller().
				Str("func", "RegistrationHandler").
				Str("hostinfo.name", req.Hostinfo.Hostname).
				Err(err).
				Msg("Failed to generate given name for machine")
			http.Error(w, "Internal error", http.StatusInternalServerError)

			return
		}

		// The machine did not have a key to authenticate, which means
		// that we rely on a method that calls back some how (OpenID or CLI)
		// We create the machine and then keep it around until a callback
		// happens
		newMachine := Machine{
			MachineKey: "",
			Hostname:   req.Hostinfo.Hostname,
			GivenName:  givenName,
			NodeKey:    NodePublicKeyStripPrefix(req.NodeKey),
			LastSeen:   &now,
			Expiry:     &time.Time{},
		}

		if !req.Expiry.IsZero() {
			log.Trace().
				Caller().
				Str("machine", req.Hostinfo.Hostname).
				Time("expiry", req.Expiry).
				Msg("Non-zero expiry time requested")
			newMachine.Expiry = &req.Expiry
		}

		h.registrationCache.Set(
			NodePublicKeyStripPrefix(req.NodeKey),
			newMachine,
			registerCacheExpiration,
		)

		h.handleNoiseMachineRegistrationNew(w, r, req)

		return
	}

	// Any other lookup error means we do not know the state of the machine;
	// answer with an error instead of an empty 200 response.
	if err != nil {
		log.Error().
			Caller().
			Err(err).
			Msg("Failed to fetch machine from the database")
		http.Error(w, "Internal error", http.StatusInternalServerError)

		return
	}

	// The machine is already registered, so we need to pass through reauth or key update.

	// If the NodeKey stored in headscale is the same as the key presented in a registration
	// request, then we have a node that is either:
	// - Trying to log out (sending a expiry in the past)
	// - A valid, registered machine, looking for the node map
	// - Expired machine wanting to reauthenticate
	if machine.NodeKey == NodePublicKeyStripPrefix(req.NodeKey) {
		// The client sends an Expiry in the past if the client is requesting to expire the key (aka logout)
		// https://github.com/tailscale/tailscale/blob/main/tailcfg/tailcfg.go#L648
		if !req.Expiry.IsZero() && req.Expiry.UTC().Before(now) {
			h.handleNoiseNodeLogOut(w, r, *machine)

			return
		}

		// If the machine is not expired, and is registered, we have already accepted this machine,
		// let it proceed with a valid registration
		if !machine.isExpired() {
			h.handleNoiseNodeValidRegistration(w, r, *machine)

			return
		}
	}

	// The NodeKey we have matches OldNodeKey, which means this is a refresh after a key expiration
	if machine.NodeKey == NodePublicKeyStripPrefix(req.OldNodeKey) &&
		!machine.isExpired() {
		h.handleNoiseNodeRefreshKey(w, r, req, *machine)

		return
	}

	// The node has expired
	h.handleNoiseNodeExpired(w, r, req, *machine)
}
// handleNoiseAuthKey registers (or refreshes) a machine that presents a
// pre-auth key in its RegisterRequest over Noise.
//
// When the key is not valid a 401 with MachineAuthorized=false is returned.
// When a machine with the given NodeKey / OldNodeKey already exists it is
// refreshed with the new key, otherwise a new machine is registered in the
// namespace of the pre-auth key. On success the key is marked as used and a
// RegisterResponse with MachineAuthorized=true is written.
func (h *Headscale) handleNoiseAuthKey(
	w http.ResponseWriter,
	r *http.Request,
	registerRequest tailcfg.RegisterRequest,
) {
	// Hostinfo is a pointer filled in by the client; refuse the request
	// instead of dereferencing nil below.
	if registerRequest.Hostinfo == nil {
		log.Error().
			Caller().
			Msg("RegisterRequest with AuthKey is missing Hostinfo")
		http.Error(w, "Bad request", http.StatusBadRequest)

		return
	}

	log.Debug().
		Caller().
		Str("machine", registerRequest.Hostinfo.Hostname).
		Msgf("Processing auth key for %s over Noise", registerRequest.Hostinfo.Hostname)
	resp := tailcfg.RegisterResponse{}

	pak, err := h.checkKeyValidity(registerRequest.Auth.AuthKey)
	if err != nil {
		log.Error().
			Caller().
			Str("machine", registerRequest.Hostinfo.Hostname).
			Err(err).
			Msg("Failed authentication via AuthKey")
		resp.MachineAuthorized = false

		w.Header().Set("Content-Type", "application/json")
		w.WriteHeader(http.StatusUnauthorized)
		if err := json.NewEncoder(w).Encode(resp); err != nil {
			log.Error().
				Caller().
				Err(err).
				Msg("Failed to write the RegisterResponse")
		}

		log.Error().
			Caller().
			Str("machine", registerRequest.Hostinfo.Hostname).
			Msg("Failed authentication via AuthKey over Noise")

		if pak != nil {
			machineRegistrations.WithLabelValues("new", RegisterMethodAuthKey, "error", pak.Namespace.Name).
				Inc()
		} else {
			machineRegistrations.WithLabelValues("new", RegisterMethodAuthKey, "error", "unknown").Inc()
		}

		return
	}

	log.Debug().
		Caller().
		Str("machine", registerRequest.Hostinfo.Hostname).
		Msg("Authentication key was valid, proceeding to acquire IP addresses")

	nodeKey := NodePublicKeyStripPrefix(registerRequest.NodeKey)

	// retrieve machine information if it exist
	// The error is not important, because if it does not
	// exist, then this is a new machine and we will move
	// on to registration.
	machine, _ := h.GetMachineByNodeKeys(registerRequest.NodeKey, registerRequest.OldNodeKey)
	if machine != nil {
		log.Trace().
			Caller().
			Str("machine", machine.Hostname).
			Msg("machine already registered, refreshing with new auth key")

		machine.NodeKey = nodeKey
		machine.AuthKeyID = uint(pak.ID)
		h.RefreshMachine(machine, registerRequest.Expiry)
	} else {
		now := time.Now().UTC()

		givenName, err := h.GenerateGivenName(registerRequest.Hostinfo.Hostname)
		if err != nil {
			// The event must be finished with Msg, otherwise nothing is
			// logged; the client also needs an answer.
			log.Error().
				Caller().
				Str("func", "RegistrationHandler").
				Str("hostinfo.name", registerRequest.Hostinfo.Hostname).
				Err(err).
				Msg("Failed to generate given name for machine")
			machineRegistrations.WithLabelValues("new", RegisterMethodAuthKey, "error", pak.Namespace.Name).
				Inc()
			http.Error(w, "Internal error", http.StatusInternalServerError)

			return
		}

		machineToRegister := Machine{
			Hostname:       registerRequest.Hostinfo.Hostname,
			GivenName:      givenName,
			NamespaceID:    pak.Namespace.ID,
			MachineKey:     "",
			RegisterMethod: RegisterMethodAuthKey,
			Expiry:         &registerRequest.Expiry,
			NodeKey:        nodeKey,
			LastSeen:       &now,
			AuthKeyID:      uint(pak.ID),
		}

		machine, err = h.RegisterMachine(
			machineToRegister,
		)
		if err != nil {
			log.Error().
				Caller().
				Err(err).
				Msg("could not register machine")
			machineRegistrations.WithLabelValues("new", RegisterMethodAuthKey, "error", pak.Namespace.Name).
				Inc()
			http.Error(w, "Internal error", http.StatusInternalServerError)

			return
		}
	}

	h.UsePreAuthKey(pak)

	resp.MachineAuthorized = true
	resp.User = *pak.Namespace.toUser()

	machineRegistrations.WithLabelValues("new", RegisterMethodAuthKey, "success", pak.Namespace.Name).
		Inc()

	w.Header().Set("Content-Type", "application/json")
	w.WriteHeader(http.StatusOK)
	if err := json.NewEncoder(w).Encode(resp); err != nil {
		log.Error().
			Caller().
			Err(err).
			Msg("Failed to write the RegisterResponse")
	}

	log.Info().
		Caller().
		Str("machine", registerRequest.Hostinfo.Hostname).
		Str("ips", strings.Join(machine.IPAddresses.ToStringSlice(), ", ")).
		Msg("Successfully authenticated via AuthKey on Noise")
}
// handleNoiseNodeValidRegistration answers a registration request of a machine
// that is already registered and presents its current NodeKey: the response
// carries no auth URL, marks the machine as authorized and includes the user
// and login derived from the namespace of the machine.
func (h *Headscale) handleNoiseNodeValidRegistration(
	w http.ResponseWriter,
	r *http.Request,
	machine Machine,
) {
	// The machine registration is valid, respond with redirect to /map
	log.Debug().
		Str("machine", machine.Hostname).
		Msg("Client is registered and we have the current NodeKey. All clear to /map")

	registerResponse := tailcfg.RegisterResponse{
		AuthURL:           "",
		MachineAuthorized: true,
		User:              *machine.Namespace.toUser(),
		Login:             *machine.Namespace.toLogin(),
	}

	updateCounter := machineRegistrations.WithLabelValues(
		"update", "web", "success", machine.Namespace.Name,
	)
	updateCounter.Inc()

	w.Header().Set("Content-Type", "application/json")
	w.WriteHeader(http.StatusOK)

	encoder := json.NewEncoder(w)
	encoder.Encode(registerResponse)
}
// handleNoiseMachineRegistrationNew answers a registration request of a new
// machine with the URL the user has to visit to complete the registration:
// the OIDC registration endpoint when an OIDC issuer is configured, the
// /register page otherwise. The stripped NodeKey identifies the machine in
// both URLs.
func (h *Headscale) handleNoiseMachineRegistrationNew(
	w http.ResponseWriter,
	r *http.Request,
	registerRequest tailcfg.RegisterRequest,
) {
	// The machine registration is new, redirect the client to the registration URL
	log.Debug().
		Str("machine", registerRequest.Hostinfo.Hostname).
		Msg("The node is sending us a new NodeKey, sending auth url")

	serverURL := strings.TrimSuffix(h.cfg.ServerURL, "/")
	strippedNodeKey := NodePublicKeyStripPrefix(registerRequest.NodeKey)

	var authURL string
	switch {
	case h.cfg.OIDC.Issuer != "":
		authURL = fmt.Sprintf("%s/oidc/register/%s", serverURL, strippedNodeKey)
	default:
		authURL = fmt.Sprintf("%s/register?key=%s", serverURL, strippedNodeKey)
	}

	registerResponse := tailcfg.RegisterResponse{AuthURL: authURL}

	w.Header().Set("Content-Type", "application/json")
	w.WriteHeader(http.StatusOK)

	encoder := json.NewEncoder(w)
	encoder.Encode(registerResponse)
}
// handleNoiseNodeLogOut handles a logout request: the machine is expired and
// the client receives a response without auth URL, with MachineAuthorized set
// to false and the user derived from the namespace of the machine.
func (h *Headscale) handleNoiseNodeLogOut(
	w http.ResponseWriter,
	r *http.Request,
	machine Machine,
) {
	log.Info().
		Str("machine", machine.Hostname).
		Msg("Client requested logout")

	h.ExpireMachine(&machine)

	// Built after ExpireMachine, like the response fields were filled before.
	registerResponse := tailcfg.RegisterResponse{
		AuthURL:           "",
		MachineAuthorized: false,
		User:              *machine.Namespace.toUser(),
	}

	w.Header().Set("Content-Type", "application/json")
	w.WriteHeader(http.StatusOK)

	encoder := json.NewEncoder(w)
	encoder.Encode(registerResponse)
}
// handleNoiseNodeRefreshKey handles a NodeKey refresh: the stored NodeKey of
// the machine matched the OldNodeKey of the request, so it is replaced with
// the new NodeKey and saved.
//
// The client receives a response without auth URL and with the user derived
// from the namespace of the machine. If the new key cannot be saved, an
// Internal error (500) is returned instead of reporting success.
func (h *Headscale) handleNoiseNodeRefreshKey(
	w http.ResponseWriter,
	r *http.Request,
	registerRequest tailcfg.RegisterRequest,
	machine Machine,
) {
	resp := tailcfg.RegisterResponse{}

	log.Debug().
		Str("machine", machine.Hostname).
		Msg("We have the OldNodeKey in the database. This is a key refresh")
	machine.NodeKey = NodePublicKeyStripPrefix(registerRequest.NodeKey)

	// Do not tell the client the refresh worked when the new key was not stored.
	if err := h.db.Save(&machine).Error; err != nil {
		log.Error().
			Caller().
			Str("machine", machine.Hostname).
			Err(err).
			Msg("Failed to persist the refreshed NodeKey")
		http.Error(w, "Internal error", http.StatusInternalServerError)

		return
	}

	resp.AuthURL = ""
	resp.User = *machine.Namespace.toUser()

	w.Header().Set("Content-Type", "application/json")
	w.WriteHeader(http.StatusOK)
	if err := json.NewEncoder(w).Encode(resp); err != nil {
		log.Error().
			Caller().
			Err(err).
			Msg("Failed to write the RegisterResponse")
	}
}
// handleNoiseNodeExpired handles a registration request of a machine that is
// known but expired. When the request carries an AuthKey it is delegated to
// handleNoiseAuthKey; otherwise the client receives the URL to register again
// (the OIDC endpoint when an issuer is configured, /register otherwise), built
// with the stored NodeKey of the machine.
func (h *Headscale) handleNoiseNodeExpired(
	w http.ResponseWriter,
	r *http.Request,
	registerRequest tailcfg.RegisterRequest,
	machine Machine,
) {
	// The client has registered before, but has expired
	log.Debug().
		Caller().
		Str("machine", machine.Hostname).
		Msg("Machine registration has expired. Sending a authurl to register")

	if registerRequest.Auth.AuthKey != "" {
		h.handleNoiseAuthKey(w, r, registerRequest)

		return
	}

	serverURL := strings.TrimSuffix(h.cfg.ServerURL, "/")

	var authURL string
	switch {
	case h.cfg.OIDC.Issuer != "":
		authURL = fmt.Sprintf("%s/oidc/register/%s", serverURL, machine.NodeKey)
	default:
		authURL = fmt.Sprintf("%s/register?key=%s", serverURL, machine.NodeKey)
	}

	registerResponse := tailcfg.RegisterResponse{AuthURL: authURL}

	reauthCounter := machineRegistrations.WithLabelValues(
		"reauth", "web", "success", machine.Namespace.Name,
	)
	reauthCounter.Inc()

	w.Header().Set("Content-Type", "application/json")
	w.WriteHeader(http.StatusOK)

	encoder := json.NewEncoder(w)
	encoder.Encode(registerResponse)
}

737
noise_poll.go Normal file
View File

@@ -0,0 +1,737 @@
package headscale
import (
"context"
"encoding/binary"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"time"
"github.com/klauspost/compress/zstd"
"github.com/rs/zerolog/log"
"gorm.io/gorm"
"tailscale.com/tailcfg"
"tailscale.com/types/key"
)
// NoisePollNetMapHandler takes care of /machine/map using the Noise protocol
//
// This is the busiest endpoint, as it keeps the HTTP long poll that updates
// the clients when something in the network changes.
//
// The clients POST stuff like HostInfo and their Endpoints here, but
// only after their first request (marked with the ReadOnly field).
//
// At this moment the updates are sent in a quite horrendous way, but they kinda work.
func (h *Headscale) NoisePollNetMapHandler(
w http.ResponseWriter,
r *http.Request,
) {
log.Trace().
Str("handler", "NoisePollNetMap").
Msg("PollNetMapHandler called")
body, _ := io.ReadAll(r.Body)
req := tailcfg.MapRequest{}
if err := json.Unmarshal(body, &req); err != nil {
log.Error().
Caller().
Err(err).
Msg("Cannot parse MapRequest")
http.Error(w, "Internal error", http.StatusInternalServerError)
return
}
machine, err := h.GetMachineByNodeKeys(req.NodeKey, key.NodePublic{})
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
log.Warn().
Str("handler", "NoisePollNetMap").
Msgf("Ignoring request, cannot find machine with key %s", req.NodeKey.String())
http.Error(w, "Internal error", http.StatusNotFound)
return
}
log.Error().
Str("handler", "NoisePollNetMap").
Msgf("Failed to fetch machine from the database with node key: %s", req.NodeKey.String())
http.Error(w, "Internal error", http.StatusInternalServerError)
return
}
log.Trace().
Str("handler", "NoisePollNetMap").
Str("machine", machine.Hostname).
Msg("Found machine in database")
machine.Hostname = req.Hostinfo.Hostname
machine.HostInfo = HostInfo(*req.Hostinfo)
machine.DiscoKey = DiscoPublicKeyStripPrefix(req.DiscoKey)
now := time.Now().UTC()
// update ACLRules with peer informations (to update server tags if necessary)
if h.aclPolicy != nil {
err = h.UpdateACLRules()
if err != nil {
log.Error().
Caller().
Str("func", "handleAuthKey").
Str("machine", machine.Hostname).
Err(err)
}
}
// From Tailscale client:
//
// ReadOnly is whether the client just wants to fetch the MapResponse,
// without updating their Endpoints. The Endpoints field will be ignored and
// LastSeen will not be updated and peers will not be notified of changes.
//
// The intended use is for clients to discover the DERP map at start-up
// before their first real endpoint update.
if !req.ReadOnly {
machine.Endpoints = req.Endpoints
machine.LastSeen = &now
}
if err := h.db.Updates(machine).Error; err != nil {
if err != nil {
log.Error().
Str("handler", "NoisePollNetMap").
Str("machine", machine.Hostname).
Err(err).
Msg("Failed to persist/update machine in the database")
http.Error(w, "Internal error", http.StatusInternalServerError)
return
}
}
resp, err := h.getNoiseMapResponse(req, machine)
if err != nil {
log.Error().
Str("handler", "NoisePollNetMap").
Str("machine", machine.Hostname).
Err(err).
Msg("Failed to get Map response")
http.Error(w, "Internal error", http.StatusInternalServerError)
return
}
// We update our peers if the client is not sending ReadOnly in the MapRequest
// so we don't distribute its initial request (it comes with
// empty endpoints to peers)
// Details on the protocol can be found in https://github.com/tailscale/tailscale/blob/main/tailcfg/tailcfg.go#L696
log.Debug().
Str("handler", "NoisePollNetMap").
Str("machine", machine.Hostname).
Bool("readOnly", req.ReadOnly).
Bool("omitPeers", req.OmitPeers).
Bool("stream", req.Stream).
Msg("Noise client map request processed")
if req.ReadOnly {
log.Info().
Str("handler", "NoisePollNetMap").
Str("machine", machine.Hostname).
Msg("Client is starting up. Probably interested in a DERP map")
// w.Header().Set("Content-Type", "application/json")
// w.WriteHeader(http.StatusOK)
_, err = w.Write(resp)
if err != nil {
log.Warn().Msgf("Could not send JSON response: %s", err)
}
if f, ok := w.(http.Flusher); ok {
f.Flush()
}
log.Info().Msgf("Noise client map response sent for %s (len %d)", machine.Hostname, len(resp))
return
}
// There has been an update to _any_ of the nodes that the other nodes would
// need to know about
h.setLastStateChangeToNow(machine.Namespace.Name)
// The request is not ReadOnly, so we need to set up channels for updating
// peers via longpoll
// Only create update channel if it has not been created
log.Trace().
Str("handler", "NoisePollNetMap").
Str("machine", machine.Hostname).
Msg("Loading or creating update channel")
const chanSize = 8
updateChan := make(chan struct{}, chanSize)
pollDataChan := make(chan []byte, chanSize)
defer closeChanWithLog(pollDataChan, machine.Hostname, "pollDataChan")
keepAliveChan := make(chan []byte)
if req.OmitPeers && !req.Stream {
log.Info().
Str("handler", "NoisePollNetMap").
Str("machine", machine.Hostname).
Msg("Client sent endpoint update and is ok with a response without peer list")
w.Write(resp)
if f, ok := w.(http.Flusher); ok {
f.Flush()
}
// It sounds like we should update the nodes when we have received a endpoint update
// even tho the comments in the tailscale code dont explicitly say so.
updateRequestsFromNode.WithLabelValues(machine.Namespace.Name, machine.Hostname, "endpoint-update").
Inc()
updateChan <- struct{}{}
return
} else if req.OmitPeers && req.Stream {
log.Warn().
Str("handler", "NoisePollNetMap").
Str("machine", machine.Hostname).
Msg("Ignoring request, don't know how to handle it")
http.Error(w, "Internal error", http.StatusBadRequest)
return
}
log.Info().
Str("handler", "NoisePollNetMap").
Str("machine", machine.Hostname).
Msg("Client is ready to access the tailnet")
log.Info().
Str("handler", "NoisePollNetMap").
Str("machine", machine.Hostname).
Msg("Sending initial map")
pollDataChan <- resp
log.Info().
Str("handler", "NoisePollNetMap").
Str("machine", machine.Hostname).
Msg("Notifying peers")
updateRequestsFromNode.WithLabelValues(machine.Namespace.Name, machine.Hostname, "full-update").
Inc()
updateChan <- struct{}{}
h.NoisePollNetMapStream(
w,
r,
machine,
req,
pollDataChan,
keepAliveChan,
updateChan,
)
log.Trace().
Str("handler", "NoisePollNetMap").
Str("machine", machine.Hostname).
Msg("Finished stream, closing PollNetMap session")
}
// NoisePollNetMapStream takes care of the /machine/:id/map stream logic,
// ensuring we communicate updates and data to the connected clients.
//
// It starts noiseScheduledPollWorker and then loops, writing to w whatever
// arrives on pollDataChan and keepAliveChan, and a freshly generated map
// response whenever updateChan fires and the machine is outdated.
//
// The function returns (ending the long poll) when:
//   - the request context is cancelled, i.e. the client went away;
//   - a write to the client fails;
//   - the machine can no longer be refreshed from the database;
//   - keepAliveChan or updateChan has been closed by the poll worker.
func (h *Headscale) NoisePollNetMapStream(
	w http.ResponseWriter,
	r *http.Request,
	machine *Machine,
	mapRequest tailcfg.MapRequest,
	pollDataChan chan []byte,
	keepAliveChan chan []byte,
	updateChan chan struct{},
) {
	// Derive the context from the request so that a client disconnect
	// cancels it. The deferred cancel stops the worker on every other
	// exit path.
	ctx := context.WithValue(r.Context(), machineNameContextKey, machine.Hostname)
	ctx, cancel := context.WithCancel(ctx)
	defer cancel()

	go h.noiseScheduledPollWorker(
		ctx,
		updateChan,
		keepAliveChan,
		mapRequest,
		machine,
	)

	flush := func() {
		if f, ok := w.(http.Flusher); ok {
			f.Flush()
		}
	}

	// recordDisconnect refreshes the machine and stores LastSeen when the
	// stream ends without a write failure.
	recordDisconnect := func() {
		// TODO: Abstract away all the database calls, this can cause race conditions
		// when an outdated machine object is kept alive, e.g. db is update from
		// command line, but then overwritten.
		err := h.UpdateMachineFromDatabase(machine)
		if err != nil {
			log.Error().
				Str("handler", "NoisePollNetMapStream").
				Str("machine", machine.Hostname).
				Str("channel", "Done").
				Err(err).
				Msg("Cannot update machine from database")

			return
		}

		now := time.Now().UTC()
		machine.LastSeen = &now
		err = h.TouchMachine(machine)
		if err != nil {
			log.Error().
				Str("handler", "NoisePollNetMapStream").
				Str("machine", machine.Hostname).
				Str("channel", "Done").
				Err(err).
				Msg("Cannot update machine LastSeen")
		}
	}

	for {
		log.Trace().
			Str("handler", "NoisePollNetMapStream").
			Str("machine", machine.Hostname).
			Msg("Waiting for data to stream...")

		log.Trace().
			Str("handler", "NoisePollNetMapStream").
			Str("machine", machine.Hostname).
			Msgf("pollData is %#v, keepAliveChan is %#v, updateChan is %#v", pollDataChan, keepAliveChan, updateChan)

		// NOTE: a bare `break` inside a select only leaves the select, not
		// the surrounding for loop, so terminal conditions use `return`.
		select {
		case data := <-pollDataChan:
			log.Trace().
				Str("handler", "NoisePollNetMapStream").
				Str("machine", machine.Hostname).
				Str("channel", "pollData").
				Int("bytes", len(data)).
				Msg("Sending data received via pollData channel")
			_, err := w.Write(data)
			if err != nil {
				log.Error().
					Str("handler", "NoisePollNetMapStream").
					Str("machine", machine.Hostname).
					Str("channel", "pollData").
					Err(err).
					Msg("Cannot write data")

				return
			}
			flush()

			log.Trace().
				Str("handler", "NoisePollNetMapStream").
				Str("machine", machine.Hostname).
				Str("channel", "pollData").
				Int("bytes", len(data)).
				Msg("Data from pollData channel written successfully")
			// TODO(kradalby): Abstract away all the database calls, this can cause race conditions
			// when an outdated machine object is kept alive, e.g. db is update from
			// command line, but then overwritten.
			err = h.UpdateMachineFromDatabase(machine)
			if err != nil {
				log.Error().
					Str("handler", "NoisePollNetMapStream").
					Str("machine", machine.Hostname).
					Str("channel", "pollData").
					Err(err).
					Msg("Cannot update machine from database")

				// client has been removed from database
				// since the stream opened, terminate connection.
				return
			}

			now := time.Now().UTC()
			machine.LastSeen = &now
			lastStateUpdate.WithLabelValues(machine.Namespace.Name, machine.Hostname).
				Set(float64(now.Unix()))
			machine.LastSuccessfulUpdate = &now

			err = h.TouchMachine(machine)
			if err != nil {
				log.Error().
					Str("handler", "NoisePollNetMapStream").
					Str("machine", machine.Hostname).
					Str("channel", "pollData").
					Err(err).
					Msg("Cannot update machine LastSuccessfulUpdate")
			} else {
				log.Trace().
					Str("handler", "NoisePollNetMapStream").
					Str("machine", machine.Hostname).
					Str("channel", "pollData").
					Int("bytes", len(data)).
					Msg("Machine entry in database updated successfully after sending pollData")
			}

		case data, ok := <-keepAliveChan:
			if !ok {
				// The poll worker closes this channel when it stops;
				// without it the stream cannot be kept alive.
				log.Debug().
					Str("handler", "NoisePollNetMapStream").
					Str("machine", machine.Hostname).
					Str("channel", "keepAlive").
					Msg("Channel closed, ending stream")
				recordDisconnect()

				return
			}

			log.Trace().
				Str("handler", "NoisePollNetMapStream").
				Str("machine", machine.Hostname).
				Str("channel", "keepAlive").
				Int("bytes", len(data)).
				Msg("Sending keep alive message")
			_, err := w.Write(data)
			if err != nil {
				log.Error().
					Str("handler", "NoisePollNetMapStream").
					Str("machine", machine.Hostname).
					Str("channel", "keepAlive").
					Err(err).
					Msg("Cannot write keep alive message")

				return
			}
			flush()

			log.Trace().
				Str("handler", "NoisePollNetMapStream").
				Str("machine", machine.Hostname).
				Str("channel", "keepAlive").
				Int("bytes", len(data)).
				Msg("Keep alive sent successfully")
			// TODO(kradalby): Abstract away all the database calls, this can cause race conditions
			// when an outdated machine object is kept alive, e.g. db is update from
			// command line, but then overwritten.
			err = h.UpdateMachineFromDatabase(machine)
			if err != nil {
				log.Error().
					Str("handler", "NoisePollNetMapStream").
					Str("machine", machine.Hostname).
					Str("channel", "keepAlive").
					Err(err).
					Msg("Cannot update machine from database")

				// client has been removed from database
				// since the stream opened, terminate connection.
				return
			}

			now := time.Now().UTC()
			machine.LastSeen = &now
			err = h.TouchMachine(machine)
			if err != nil {
				log.Error().
					Str("handler", "NoisePollNetMapStream").
					Str("machine", machine.Hostname).
					Str("channel", "keepAlive").
					Err(err).
					Msg("Cannot update machine LastSeen")
			} else {
				log.Trace().
					Str("handler", "NoisePollNetMapStream").
					Str("machine", machine.Hostname).
					Str("channel", "keepAlive").
					Int("bytes", len(data)).
					Msg("Machine updated successfully after sending keep alive")
			}

		case _, ok := <-updateChan:
			if !ok {
				// Closed by the poll worker when it stops.
				log.Debug().
					Str("handler", "NoisePollNetMapStream").
					Str("machine", machine.Hostname).
					Str("channel", "update").
					Msg("Channel closed, ending stream")
				recordDisconnect()

				return
			}

			log.Trace().
				Str("handler", "NoisePollNetMapStream").
				Str("machine", machine.Hostname).
				Str("channel", "update").
				Msg("Received a request for update")
			updateRequestsReceivedOnChannel.WithLabelValues(machine.Namespace.Name, machine.Hostname).
				Inc()

			var lastUpdate time.Time
			if machine.LastSuccessfulUpdate != nil {
				lastUpdate = *machine.LastSuccessfulUpdate
			}

			if !h.isOutdated(machine) {
				log.Trace().
					Str("handler", "NoisePollNetMapStream").
					Str("machine", machine.Hostname).
					Time("last_successful_update", lastUpdate).
					Time("last_state_change", h.getLastStateChange(machine.Namespace.Name)).
					Msgf("%s is up to date", machine.Hostname)

				continue
			}

			log.Debug().
				Str("handler", "NoisePollNetMapStream").
				Str("machine", machine.Hostname).
				Time("last_successful_update", lastUpdate).
				Time("last_state_change", h.getLastStateChange(machine.Namespace.Name)).
				Msgf("There has been updates since the last successful update to %s", machine.Hostname)

			data, err := h.getNoiseMapResponse(mapRequest, machine)
			if err != nil {
				log.Error().
					Str("handler", "NoisePollNetMapStream").
					Str("machine", machine.Hostname).
					Str("channel", "update").
					Err(err).
					Msg("Could not get the map update")

				// Nothing valid to send. Do not write an empty payload
				// nor record a successful update; wait for the next
				// update request instead.
				continue
			}

			_, err = w.Write(data)
			if err != nil {
				log.Error().
					Str("handler", "NoisePollNetMapStream").
					Str("machine", machine.Hostname).
					Str("channel", "update").
					Err(err).
					Msg("Could not write the map response")
				updateRequestsSentToNode.WithLabelValues(machine.Namespace.Name, machine.Hostname, "failed").
					Inc()

				return
			}
			flush()

			log.Trace().
				Str("handler", "NoisePollNetMapStream").
				Str("machine", machine.Hostname).
				Str("channel", "update").
				Msg("Updated Map has been sent")
			updateRequestsSentToNode.WithLabelValues(machine.Namespace.Name, machine.Hostname, "success").
				Inc()

			// Keep track of the last successful update,
			// we sometimes end in a state were the update
			// is not picked up by a client and we use this
			// to determine if we should "force" an update.
			// TODO(kradalby): Abstract away all the database calls, this can cause race conditions
			// when an outdated machine object is kept alive, e.g. db is update from
			// command line, but then overwritten.
			err = h.UpdateMachineFromDatabase(machine)
			if err != nil {
				log.Error().
					Str("handler", "NoisePollNetMapStream").
					Str("machine", machine.Hostname).
					Str("channel", "update").
					Err(err).
					Msg("Cannot update machine from database")

				// client has been removed from database
				// since the stream opened, terminate connection.
				return
			}

			now := time.Now().UTC()
			lastStateUpdate.WithLabelValues(machine.Namespace.Name, machine.Hostname).
				Set(float64(now.Unix()))
			machine.LastSuccessfulUpdate = &now

			err = h.TouchMachine(machine)
			if err != nil {
				log.Error().
					Str("handler", "NoisePollNetMapStream").
					Str("machine", machine.Hostname).
					Str("channel", "update").
					Err(err).
					Msg("Cannot update machine LastSuccessfulUpdate")
			}

		case <-ctx.Done():
			log.Info().
				Str("handler", "NoisePollNetMapStream").
				Str("machine", machine.Hostname).
				Msg("The client has closed the connection")
			recordDisconnect()

			return
		}
	}
}
// noiseScheduledPollWorker drives a poll stream from two tickers until ctx
// is cancelled.
//
// Every keepAliveInterval it generates a keep alive message and sends it on
// keepAliveChan; every updateCheckInterval it sends a signal on updateChan.
// Both channels are closed when the worker returns, which happens when ctx
// is done or when a keep alive message cannot be generated.
func (h *Headscale) noiseScheduledPollWorker(
	ctx context.Context,
	updateChan chan struct{},
	keepAliveChan chan []byte,
	mapRequest tailcfg.MapRequest,
	machine *Machine,
) {
	// Stop the tickers on exit so their resources are released.
	keepAliveTicker := time.NewTicker(keepAliveInterval)
	defer keepAliveTicker.Stop()

	updateCheckerTicker := time.NewTicker(updateCheckInterval)
	defer updateCheckerTicker.Stop()

	machineName := fmt.Sprint(ctx.Value(machineNameContextKey))
	defer closeChanWithLog(updateChan, machineName, "updateChan")
	defer closeChanWithLog(keepAliveChan, machineName, "keepAliveChan")

	for {
		select {
		case <-ctx.Done():
			return

		case <-keepAliveTicker.C:
			data, err := h.getNoiseMapKeepAliveResponse(mapRequest)
			if err != nil {
				log.Error().
					Str("func", "keepAlive").
					Err(err).
					Msg("Error generating the keep alive msg")

				return
			}

			log.Debug().
				Str("func", "keepAlive").
				Str("machine", machine.Hostname).
				Msg("Sending keepalive")

			// Guard the send with ctx so the worker cannot block forever
			// once the receiving side has stopped reading.
			select {
			case keepAliveChan <- data:
			case <-ctx.Done():
				return
			}

		case <-updateCheckerTicker.C:
			log.Debug().
				Str("func", "scheduledPollWorker").
				Str("machine", machine.Hostname).
				Msg("Sending update request")
			updateRequestsFromNode.WithLabelValues(machine.Namespace.Name, machine.Hostname, "scheduled-update").
				Inc()

			select {
			case updateChan <- struct{}{}:
			case <-ctx.Done():
				return
			}
		}
	}
}
// getNoiseMapKeepAliveResponse builds the wire form of a keep alive
// MapResponse (only KeepAlive set to true).
//
// The response is marshalled to JSON, zstd-compressed when req.Compress is
// "zstd", and prefixed with a header of reservedResponseHeaderSize bytes
// whose first four bytes hold the payload length as a little endian uint32.
//
// It returns an error if the response cannot be marshalled or the zstd
// encoder cannot be created.
func (h *Headscale) getNoiseMapKeepAliveResponse(req tailcfg.MapRequest) ([]byte, error) {
	resp := tailcfg.MapResponse{
		KeepAlive: true,
	}

	// The TS2021 protocol does not rely anymore on the machine key to
	// encrypt in a NaCl box the map response. We just send it back
	// unencrypted via the encrypted Noise channel.
	// declare the incoming size on the first 4 bytes
	respBody, err := json.Marshal(resp)
	if err != nil {
		log.Error().
			Caller().
			Err(err).
			Msg("Cannot marshal map response")

		return nil, err
	}

	payload := respBody
	if req.Compress == "zstd" {
		encoder, err := zstd.NewWriter(nil)
		if err != nil {
			log.Error().
				Caller().
				Err(err).
				Msg("Cannot create zstd encoder")

			return nil, err
		}
		payload = encoder.EncodeAll(respBody, nil)
	}

	data := make([]byte, reservedResponseHeaderSize)
	binary.LittleEndian.PutUint32(data, uint32(len(payload)))
	data = append(data, payload...)

	return data, nil
}
// getNoiseMapResponse builds the wire form of a full MapResponse for machine.
//
// The response carries the machine's own node, its valid peers, the DNS
// configuration, base domain, packet filter (ACL rules), DERP map, user
// profiles and debug settings. It is marshalled to JSON, zstd-compressed
// when req.Compress is "zstd", and prefixed with a header of
// reservedResponseHeaderSize bytes whose first four bytes hold the payload
// length as a little endian uint32.
//
// It returns an error if the node or its peers cannot be built, if the
// response cannot be marshalled, or if the zstd encoder cannot be created.
func (h *Headscale) getNoiseMapResponse(
	req tailcfg.MapRequest,
	machine *Machine,
) ([]byte, error) {
	// Hostinfo is a pointer and may be absent from the request; fall back
	// to the stored hostname for logging instead of dereferencing nil.
	hostname := machine.Hostname
	if req.Hostinfo != nil {
		hostname = req.Hostinfo.Hostname
	}

	log.Trace().
		Str("func", "getNoiseMapResponse").
		Str("machine", hostname).
		Msg("Creating Map response")

	node, err := machine.toNode(h.cfg.BaseDomain, h.cfg.DNSConfig, true)
	if err != nil {
		log.Error().
			Caller().
			Str("func", "getNoiseMapResponse").
			Err(err).
			Msg("Cannot convert to node")

		return nil, err
	}

	peers, err := h.getValidPeers(machine)
	if err != nil {
		log.Error().
			Caller().
			Str("func", "getNoiseMapResponse").
			Err(err).
			Msg("Cannot fetch peers")

		return nil, err
	}

	profiles := getMapResponseUserProfiles(*machine, peers)

	nodePeers, err := peers.toNodes(h.cfg.BaseDomain, h.cfg.DNSConfig, true)
	if err != nil {
		log.Error().
			Caller().
			Str("func", "getNoiseMapResponse").
			Err(err).
			Msg("Failed to convert peers to Tailscale nodes")

		return nil, err
	}

	dnsConfig := getMapResponseDNSConfig(
		h.cfg.DNSConfig,
		h.cfg.BaseDomain,
		*machine,
		peers,
	)

	resp := tailcfg.MapResponse{
		KeepAlive:    false,
		Node:         node,
		Peers:        nodePeers,
		DNSConfig:    dnsConfig,
		Domain:       h.cfg.BaseDomain,
		PacketFilter: h.aclRules,
		DERPMap:      h.DERPMap,
		UserProfiles: profiles,
		Debug: &tailcfg.Debug{
			DisableLogTail:      !h.cfg.LogTail.Enabled,
			RandomizeClientPort: h.cfg.RandomizeClientPort,
		},
	}

	log.Trace().
		Str("func", "getNoiseMapResponse").
		Str("machine", hostname).
		Msgf("Generated map response: %s", tailMapResponseToString(resp))

	// The TS2021 protocol does not rely anymore on the machine key to
	// encrypt in a NaCl box the map response. We just send it back
	// unencrypted via the encrypted Noise channel.
	// declare the incoming size on the first 4 bytes
	respBody, err := json.Marshal(resp)
	if err != nil {
		log.Error().
			Caller().
			Err(err).
			Msg("Cannot marshal map response")

		return nil, err
	}

	payload := respBody
	if req.Compress == "zstd" {
		encoder, err := zstd.NewWriter(nil)
		if err != nil {
			log.Error().
				Caller().
				Err(err).
				Msg("Cannot create zstd encoder")

			return nil, err
		}
		payload = encoder.EncodeAll(respBody, nil)
	}

	data := make([]byte, reservedResponseHeaderSize)
	binary.LittleEndian.PutUint32(data, uint32(len(payload)))
	data = append(data, payload...)

	return data, nil
}

36
oidc.go
View File

@@ -62,10 +62,10 @@ func (h *Headscale) initOIDC() error {
// RegisterOIDC redirects to the OIDC provider for authentication // RegisterOIDC redirects to the OIDC provider for authentication
// Puts machine key in cache so the callback can retrieve it using the oidc state param // Puts machine key in cache so the callback can retrieve it using the oidc state param
// Listens in /oidc/register/:mKey. // Listens in /oidc/register/:nKey.
func (h *Headscale) RegisterOIDC(ctx *gin.Context) { func (h *Headscale) RegisterOIDC(ctx *gin.Context) {
machineKeyStr := ctx.Param("mkey") nodeKeyStr := ctx.Param("nkey")
if machineKeyStr == "" { if nodeKeyStr == "" {
ctx.String(http.StatusBadRequest, "Wrong params") ctx.String(http.StatusBadRequest, "Wrong params")
return return
@@ -73,7 +73,7 @@ func (h *Headscale) RegisterOIDC(ctx *gin.Context) {
log.Trace(). log.Trace().
Caller(). Caller().
Str("machine_key", machineKeyStr). Str("node_key", nodeKeyStr).
Msg("Received oidc register call") Msg("Received oidc register call")
randomBlob := make([]byte, randomByteSize) randomBlob := make([]byte, randomByteSize)
@@ -89,7 +89,7 @@ func (h *Headscale) RegisterOIDC(ctx *gin.Context) {
stateStr := hex.EncodeToString(randomBlob)[:32] stateStr := hex.EncodeToString(randomBlob)[:32]
// place the machine key into the state cache, so it can be retrieved later // place the machine key into the state cache, so it can be retrieved later
h.registrationCache.Set(stateStr, machineKeyStr, registerCacheExpiration) h.registrationCache.Set(stateStr, nodeKeyStr, registerCacheExpiration)
// Add any extra parameter provided in the configuration to the Authorize Endpoint request // Add any extra parameter provided in the configuration to the Authorize Endpoint request
extras := make([]oauth2.AuthCodeOption, 0, len(h.cfg.OIDC.ExtraParams)) extras := make([]oauth2.AuthCodeOption, 0, len(h.cfg.OIDC.ExtraParams))
@@ -217,10 +217,10 @@ func (h *Headscale) OIDCCallback(ctx *gin.Context) {
return return
} }
// retrieve machinekey from state cache // retrieve nodekey from state cache
machineKeyIf, machineKeyFound := h.registrationCache.Get(state) nodeKeyIf, nodeKeyFound := h.registrationCache.Get(state)
if !machineKeyFound { if !nodeKeyFound {
log.Error(). log.Error().
Msg("requested machine state key expired before authorisation completed") Msg("requested machine state key expired before authorisation completed")
ctx.String(http.StatusBadRequest, "state has expired") ctx.String(http.StatusBadRequest, "state has expired")
@@ -228,22 +228,22 @@ func (h *Headscale) OIDCCallback(ctx *gin.Context) {
return return
} }
machineKeyFromCache, machineKeyOK := machineKeyIf.(string) nodeKeyFromCache, nodeKeyOK := nodeKeyIf.(string)
var machineKey key.MachinePublic var nodeKey key.NodePublic
err = machineKey.UnmarshalText( err = nodeKey.UnmarshalText(
[]byte(MachinePublicKeyEnsurePrefix(machineKeyFromCache)), []byte(MachinePublicKeyEnsurePrefix(nodeKeyFromCache)),
) )
if err != nil { if err != nil {
log.Error(). log.Error().
Msg("could not parse machine public key") Msg("could not parse node public key")
ctx.String(http.StatusBadRequest, "could not parse public key") ctx.String(http.StatusBadRequest, "could not parse public key")
return return
} }
if !machineKeyOK { if !nodeKeyOK {
log.Error().Msg("could not get machine key from cache") log.Error().Msg("could not get node key from cache")
ctx.String( ctx.String(
http.StatusInternalServerError, http.StatusInternalServerError,
"could not get machine key from cache", "could not get machine key from cache",
@@ -256,7 +256,7 @@ func (h *Headscale) OIDCCallback(ctx *gin.Context) {
// The error is not important, because if it does not // The error is not important, because if it does not
// exist, then this is a new machine and we will move // exist, then this is a new machine and we will move
// on to registration. // on to registration.
machine, _ := h.GetMachineByMachineKey(machineKey) machine, _ := h.GetMachineByNodeKeys(nodeKey, key.NodePublic{})
if machine != nil { if machine != nil {
log.Trace(). log.Trace().
@@ -335,10 +335,10 @@ func (h *Headscale) OIDCCallback(ctx *gin.Context) {
return return
} }
machineKeyStr := MachinePublicKeyStripPrefix(machineKey) nodeKeyStr := NodePublicKeyStripPrefix(nodeKey)
_, err = h.RegisterMachineFromAuthCallback( _, err = h.RegisterMachineFromAuthCallback(
machineKeyStr, nodeKeyStr,
namespace.Name, namespace.Name,
RegisterMethodOIDC, RegisterMethodOIDC,
) )