Compare commits

...

7 Commits
main ... hs2021

Author SHA1 Message Date
Juan Font Alonso
96b02f7d89 Updated test config to work with TS2021 2022-06-16 00:27:53 +02:00
Juan Font Alonso
7078d36dc6 Added MapPoll to Noise protocol 2022-06-16 00:21:46 +02:00
Juan Font Alonso
670c7d9144 TS2021: Add Noise endpoint for node registration 2022-06-12 14:33:47 +02:00
Juan Font Alonso
e8205e8d5a TS2021: Use NodeKey for everything, as MachineKey is deprecated in TS2021 2022-06-12 12:30:56 +02:00
Juan Font Alonso
b40b4e8d45 Added Noise upgrade handler and Noise mux 2022-06-11 19:08:35 +02:00
Juan Font Alonso
304987b4ff TS2021: Convert /key handler to send the Noise key too 2022-06-11 19:00:49 +02:00
Juan Font Alonso
c908627e68 Generate and read the Noise private key 2022-06-11 18:53:11 +02:00
14 changed files with 1415 additions and 40 deletions

34
api.go
View File

@@ -9,6 +9,7 @@ import (
"html/template"
"io"
"net/http"
"strconv"
"strings"
"time"
@@ -28,11 +29,40 @@ const (
ErrRegisterMethodCLIDoesNotSupportExpire = Error(
"machines registered with CLI does not support expire",
)
// The CapabilityVersion is used by Tailscale clients to indicate
// their codebase version. Tailscale clients can communicate over TS2021
// from CapabilityVersion 28.
// See https://github.com/tailscale/tailscale/blob/main/tailcfg/tailcfg.go
NoiseCapabilityVersion = 28
)
// KeyHandler provides the Headscale pub key
// Listens in /key.
func (h *Headscale) KeyHandler(ctx *gin.Context) {
// New Tailscale clients send a 'v' parameter to indicate the CurrentCapabilityVersion
clientCapabilityStr := ctx.Query("v")
if clientCapabilityStr != "" {
clientCapabilityVersion, err := strconv.Atoi(clientCapabilityStr)
if err != nil {
ctx.String(http.StatusBadRequest, "Invalid version")
return
}
if clientCapabilityVersion >= NoiseCapabilityVersion {
// Tailscale has a different key for the TS2021 protocol
resp := tailcfg.OverTLSPublicKeyResponse{
LegacyPublicKey: h.privateKey.Public(),
PublicKey: h.noisePrivateKey.Public(),
}
ctx.JSON(http.StatusOK, resp)
return
}
}
// Old clients don't send a 'v' parameter, so we send the legacy public key
ctx.Data(
http.StatusOK,
"text/plain; charset=utf-8",
@@ -516,11 +546,11 @@ func (h *Headscale) handleMachineRegistrationNew(
resp.AuthURL = fmt.Sprintf(
"%s/oidc/register/%s",
strings.TrimSuffix(h.cfg.ServerURL, "/"),
machineKey.String(),
NodePublicKeyStripPrefix(registerRequest.NodeKey),
)
} else {
resp.AuthURL = fmt.Sprintf("%s/register?key=%s",
strings.TrimSuffix(h.cfg.ServerURL, "/"), MachinePublicKeyStripPrefix(machineKey))
strings.TrimSuffix(h.cfg.ServerURL, "/"), NodePublicKeyStripPrefix(registerRequest.NodeKey))
}
respBody, err := encode(resp, &machineKey, h.privateKey)

51
app.go
View File

@@ -71,12 +71,15 @@ const (
// Headscale represents the base app of the service.
type Headscale struct {
cfg *Config
db *gorm.DB
dbString string
dbType string
dbDebug bool
privateKey *key.MachinePrivate
cfg *Config
db *gorm.DB
dbString string
dbType string
dbDebug bool
privateKey *key.MachinePrivate
noisePrivateKey *key.MachinePrivate
noiseMux *http.ServeMux
DERPMap *tailcfg.DERPMap
DERPServer *DERPServer
@@ -116,11 +119,20 @@ func LookupTLSClientAuthMode(mode string) (tls.ClientAuthType, bool) {
}
func NewHeadscale(cfg *Config) (*Headscale, error) {
privKey, err := readOrCreatePrivateKey(cfg.PrivateKeyPath)
privateKey, err := readOrCreatePrivateKey(cfg.PrivateKeyPath)
if err != nil {
return nil, fmt.Errorf("failed to read or create private key: %w", err)
}
noisePrivateKey, err := readOrCreatePrivateKey(cfg.NoisePrivateKeyPath)
if err != nil {
return nil, fmt.Errorf("failed to read or create noise private key: %w", err)
}
if privateKey.Equal(*noisePrivateKey) {
return nil, fmt.Errorf("private key and noise private key are the same")
}
var dbString string
switch cfg.DBtype {
case Postgres:
@@ -147,7 +159,8 @@ func NewHeadscale(cfg *Config) (*Headscale, error) {
cfg: cfg,
dbType: cfg.DBtype,
dbString: dbString,
privateKey: privKey,
privateKey: privateKey,
noisePrivateKey: noisePrivateKey,
aclRules: tailcfg.FilterAllowAll, // default allowall
registrationCache: registrationCache,
}
@@ -393,6 +406,7 @@ func (h *Headscale) createPrometheusRouter() *gin.Engine {
func (h *Headscale) createRouter(grpcMux *runtime.ServeMux) *gin.Engine {
router := gin.Default()
router.POST(ts2021UpgradePath, h.NoiseUpgradeHandler)
router.GET(
"/health",
func(c *gin.Context) { c.JSON(http.StatusOK, gin.H{"healthy": "ok"}) },
@@ -401,7 +415,7 @@ func (h *Headscale) createRouter(grpcMux *runtime.ServeMux) *gin.Engine {
router.GET("/register", h.RegisterWebAPI)
router.POST("/machine/:id/map", h.PollNetMapHandler)
router.POST("/machine/:id", h.RegistrationHandler)
router.GET("/oidc/register/:mkey", h.RegisterOIDC)
router.GET("/oidc/register/:nkey", h.RegisterOIDC)
router.GET("/oidc/callback", h.OIDCCallback)
router.GET("/apple", h.AppleConfigMessage)
router.GET("/apple/:platform", h.ApplePlatformConfig)
@@ -427,6 +441,15 @@ func (h *Headscale) createRouter(grpcMux *runtime.ServeMux) *gin.Engine {
return router
}
func (h *Headscale) createNoiseMux() *http.ServeMux {
mux := http.NewServeMux()
mux.HandleFunc("/machine/register", h.NoiseRegistrationHandler)
mux.HandleFunc("/machine/map", h.NoisePollNetMapHandler)
return mux
}
// Serve launches a GIN server with the Headscale API.
func (h *Headscale) Serve() error {
var err error
@@ -579,8 +602,14 @@ func (h *Headscale) Serve() error {
// HTTP setup
//
// This is the regular router that we expose
// over our main Addr. It also serves the legacy Tailcale API
router := h.createRouter(grpcGatewayMux)
// This router is served only over the Noise connection,
// and exposes only the new API
h.noiseMux = h.createNoiseMux()
httpServer := &http.Server{
Addr: h.cfg.Addr,
Handler: router,
@@ -718,6 +747,10 @@ func (h *Headscale) getTLSSettings() (*tls.Config, error) {
// Configuration via autocert with HTTP-01. This requires listening on
// port 80 for the certificate validation in addition to the headscale
// service, which can be configured to run on any other port.
httpRouter := gin.Default()
httpRouter.POST(ts2021UpgradePath, h.NoiseUpgradeHandler)
httpRouter.NoRoute(gin.WrapF(h.redirect))
go func() {
log.Fatal().
Caller().

View File

@@ -41,6 +41,13 @@ grpc_allow_insecure: false
# autogenerated if it's missing
private_key_path: /var/lib/headscale/private.key
# The Noise private key is used to encrypt the
# traffic between headscale and Tailscale clients when
# using the new Noise-based TS2021 protocol.
# The noise private key file which will be
# autogenerated if it's missing
noise_private_key_path: /var/lib/headscale/noise_private.key
# List of IP prefixes to allocate tailaddresses from.
# Each prefix consists of either an IPv4 or IPv6 address,
# and the associated prefix length, delimited by a slash.

View File

@@ -28,6 +28,7 @@ type Config struct {
EphemeralNodeInactivityTimeout time.Duration
IPPrefixes []netaddr.IPPrefix
PrivateKeyPath string
NoisePrivateKeyPath string
BaseDomain string
LogLevel zerolog.Level
DisableUpdateCheck bool
@@ -455,6 +456,9 @@ func GetHeadscaleConfig() (*Config, error) {
PrivateKeyPath: AbsolutePathFromConfigPath(
viper.GetString("private_key_path"),
),
NoisePrivateKeyPath: AbsolutePathFromConfigPath(
viper.GetString("noise_private_key_path"),
),
BaseDomain: baseDomain,
DERP: derpConfig,

View File

@@ -37,6 +37,7 @@ oidc:
- email
strip_email_domain: true
private_key_path: private.key
noise_private_key_path: noise_private.key
server_url: http://headscale:18080
tls_client_auth_mode: relaxed
tls_letsencrypt_cache_dir: /var/www/.cache

View File

@@ -13,6 +13,7 @@ dns_config:
- 1.1.1.1
db_path: /tmp/integration_test_db.sqlite3
private_key_path: private.key
noise_private_key_path: noise_private.key
listen_addr: 0.0.0.0:18080
metrics_listen_addr: 127.0.0.1:19090
server_url: http://headscale:18080

View File

@@ -37,6 +37,7 @@ oidc:
- email
strip_email_domain: true
private_key_path: private.key
noise_private_key_path: noise_private.key
server_url: http://headscale:8080
tls_client_auth_mode: relaxed
tls_letsencrypt_cache_dir: /var/www/.cache

View File

@@ -13,8 +13,9 @@ dns_config:
- 1.1.1.1
db_path: /tmp/integration_test_db.sqlite3
private_key_path: private.key
listen_addr: 0.0.0.0:8443
server_url: https://headscale:8443
noise_private_key_path: noise_private.key
listen_addr: 0.0.0.0:443
server_url: https://headscale:443
tls_cert_path: "/etc/headscale/tls/server.crt"
tls_key_path: "/etc/headscale/tls/server.key"
tls_client_auth_mode: disabled

View File

@@ -349,7 +349,7 @@ func (h *Headscale) GetMachineByID(id uint64) (*Machine, error) {
return &m, nil
}
// GetMachineByMachineKey finds a Machine by ID and returns the Machine struct.
// GetMachineByMachineKey finds a Machine by its MachineKey and returns the Machine struct.
func (h *Headscale) GetMachineByMachineKey(
machineKey key.MachinePublic,
) (*Machine, error) {
@@ -361,6 +361,19 @@ func (h *Headscale) GetMachineByMachineKey(
return &m, nil
}
// GetMachineByNodeKeys finds a Machine by its current NodeKey or the old one, and returns the Machine struct.
func (h *Headscale) GetMachineByNodeKeys(
nodeKey key.NodePublic, oldNodeKey key.NodePublic,
) (*Machine, error) {
machine := Machine{}
if result := h.db.Preload("Namespace").First(&machine, "node_key = ? OR node_key = ?",
NodePublicKeyStripPrefix(nodeKey), NodePublicKeyStripPrefix(oldNodeKey)); result.Error != nil {
return nil, result.Error
}
return &machine, nil
}
// UpdateMachineFromDatabase takes a Machine struct pointer (typically already loaded from database
// and updates it with the latest data from the database.
func (h *Headscale) UpdateMachineFromDatabase(machine *Machine) error {
@@ -567,11 +580,14 @@ func (machine Machine) toNode(
}
var machineKey key.MachinePublic
err = machineKey.UnmarshalText(
[]byte(MachinePublicKeyEnsurePrefix(machine.MachineKey)),
)
if err != nil {
return nil, fmt.Errorf("failed to parse machine public key: %w", err)
if machine.MachineKey != "" {
// MachineKey is only used in the legacy protocol
err = machineKey.UnmarshalText(
[]byte(MachinePublicKeyEnsurePrefix(machine.MachineKey)),
)
if err != nil {
return nil, fmt.Errorf("failed to parse machine public key: %w", err)
}
}
var discoKey key.DiscoPublic
@@ -750,11 +766,11 @@ func getTags(
}
func (h *Headscale) RegisterMachineFromAuthCallback(
machineKeyStr string,
nodeKeyStr string,
namespaceName string,
registrationMethod string,
) (*Machine, error) {
if machineInterface, ok := h.registrationCache.Get(machineKeyStr); ok {
if machineInterface, ok := h.registrationCache.Get(nodeKeyStr); ok {
if registrationMachine, ok := machineInterface.(Machine); ok {
namespace, err := h.GetNamespace(namespaceName)
if err != nil {
@@ -785,7 +801,7 @@ func (h *Headscale) RegisterMachine(machine Machine,
) (*Machine, error) {
log.Trace().
Caller().
Str("machine_key", machine.MachineKey).
Str("node_key", machine.NodeKey).
Msg("Registering machine")
log.Trace().

View File

@@ -11,6 +11,7 @@ import (
"gopkg.in/check.v1"
"inet.af/netaddr"
"tailscale.com/tailcfg"
"tailscale.com/types/key"
)
func (s *Suite) TestGetMachine(c *check.C) {
@@ -65,6 +66,35 @@ func (s *Suite) TestGetMachineByID(c *check.C) {
c.Assert(err, check.IsNil)
}
func (s *Suite) TestGetMachineByNodeKeys(c *check.C) {
namespace, err := app.CreateNamespace("test")
c.Assert(err, check.IsNil)
pak, err := app.CreatePreAuthKey(namespace.Name, false, false, nil)
c.Assert(err, check.IsNil)
_, err = app.GetMachineByID(0)
c.Assert(err, check.NotNil)
nodeKey := key.NewNode()
oldNodeKey := key.NewNode()
machine := Machine{
ID: 0,
MachineKey: "foo",
NodeKey: NodePublicKeyStripPrefix(nodeKey.Public()),
DiscoKey: "faa",
Hostname: "testmachine",
NamespaceID: namespace.ID,
RegisterMethod: RegisterMethodAuthKey,
AuthKeyID: uint(pak.ID),
}
app.db.Save(&machine)
_, err = app.GetMachineByNodeKeys(nodeKey.Public(), oldNodeKey.Public())
c.Assert(err, check.IsNil)
}
func (s *Suite) TestDeleteMachine(c *check.C) {
namespace, err := app.CreateNamespace("test")
c.Assert(err, check.IsNil)

125
noise.go Normal file
View File

@@ -0,0 +1,125 @@
package headscale
import (
"encoding/base64"
"net/http"
"github.com/gin-gonic/gin"
"github.com/rs/zerolog/log"
"golang.org/x/net/http2"
"golang.org/x/net/http2/h2c"
"tailscale.com/control/controlbase"
"tailscale.com/net/netutil"
)
const (
errWrongConnectionUpgrade = Error("wrong connection upgrade")
errCannotHijack = Error("cannot hijack connection")
errNoiseHandshakeFailed = Error("noise handshake failed")
)
const (
// ts2021UpgradePath is the path that the server listens on for the WebSockets upgrade.
ts2021UpgradePath = "/ts2021"
// upgradeHeader is the value of the Upgrade HTTP header used to
// indicate the Tailscale control protocol.
upgradeHeaderValue = "tailscale-control-protocol"
// handshakeHeaderName is the HTTP request header that can
// optionally contain base64-encoded initial handshake
// payload, to save an RTT.
handshakeHeaderName = "X-Tailscale-Handshake"
)
// NoiseUpgradeHandler is to upgrade the connection and hijack the net.Conn
// in order to use the Noise-based TS2021 protocol. Listens in /ts2021.
func (h *Headscale) NoiseUpgradeHandler(ctx *gin.Context) {
log.Trace().Caller().Msgf("Noise upgrade handler for client %s", ctx.ClientIP())
// Under normal circumpstances, we should be able to use the controlhttp.AcceptHTTP()
// function to do this - kindly left there by the Tailscale authors for us to use.
// (https://github.com/tailscale/tailscale/blob/main/control/controlhttp/server.go)
//
// However, Gin seems to be doing something funny/different with its writer (see AcceptHTTP code).
// This causes problems when the upgrade headers are sent in AcceptHTTP.
// So have getNoiseConnection() that is essentially an AcceptHTTP but using the native Gin methods.
noiseConn, err := h.getNoiseConnection(ctx)
if err != nil {
log.Error().Err(err).Msg("noise upgrade failed")
ctx.AbortWithError(http.StatusInternalServerError, err)
return
}
server := http.Server{}
server.Handler = h2c.NewHandler(h.noiseMux, &http2.Server{})
server.Serve(netutil.NewOneConnListener(noiseConn, nil))
}
// getNoiseConnection is basically AcceptHTTP from tailscale, but more _alla_ Gin
// TODO(juan): Figure out why we need to do this at all.
func (h *Headscale) getNoiseConnection(ctx *gin.Context) (*controlbase.Conn, error) {
next := ctx.GetHeader("Upgrade")
if next == "" {
ctx.String(http.StatusBadRequest, "missing next protocol")
return nil, errWrongConnectionUpgrade
}
if next != upgradeHeaderValue {
ctx.String(http.StatusBadRequest, "unknown next protocol")
return nil, errWrongConnectionUpgrade
}
initB64 := ctx.GetHeader(handshakeHeaderName)
if initB64 == "" {
ctx.String(http.StatusBadRequest, "missing Tailscale handshake header")
return nil, errWrongConnectionUpgrade
}
init, err := base64.StdEncoding.DecodeString(initB64)
if err != nil {
ctx.String(http.StatusBadRequest, "invalid tailscale handshake header")
return nil, errWrongConnectionUpgrade
}
hijacker, ok := ctx.Writer.(http.Hijacker)
if !ok {
log.Error().Caller().Err(err).Msgf("Hijack failed")
ctx.String(http.StatusInternalServerError, "HTTP does not support general TCP support")
return nil, errCannotHijack
}
// This is what changes from the original AcceptHTTP() function.
ctx.Header("Upgrade", upgradeHeaderValue)
ctx.Header("Connection", "upgrade")
ctx.Status(http.StatusSwitchingProtocols)
ctx.Writer.WriteHeaderNow()
// end
netConn, conn, err := hijacker.Hijack()
if err != nil {
log.Error().Caller().Err(err).Msgf("Hijack failed")
ctx.String(http.StatusInternalServerError, "HTTP does not support general TCP support")
return nil, errCannotHijack
}
if err := conn.Flush(); err != nil {
netConn.Close()
return nil, errCannotHijack
}
netConn = netutil.NewDrainBufConn(netConn, conn.Reader)
nc, err := controlbase.Server(ctx.Request.Context(), netConn, *h.noisePrivateKey, init)
if err != nil {
netConn.Close()
return nil, errNoiseHandshakeFailed
}
return nc, nil
}

389
noise_api.go Normal file
View File

@@ -0,0 +1,389 @@
package headscale
import (
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"strings"
"time"
"github.com/rs/zerolog/log"
"gorm.io/gorm"
"tailscale.com/tailcfg"
)
func (h *Headscale) NoiseRegistrationHandler(
w http.ResponseWriter,
r *http.Request,
) {
log.Trace().Caller().Msgf("Noise registration handler for client %s", r.RemoteAddr)
if r.Method != http.MethodPost {
http.Error(w, "Wrong method", http.StatusMethodNotAllowed)
return
}
body, _ := io.ReadAll(r.Body)
req := tailcfg.RegisterRequest{}
if err := json.Unmarshal(body, &req); err != nil {
log.Error().
Caller().
Err(err).
Msg("Cannot parse RegisterRequest")
machineRegistrations.WithLabelValues("unknown", "web", "error", "unknown").Inc()
http.Error(w, "Internal error", http.StatusInternalServerError)
return
}
log.Info().Caller().
Str("nodekey", req.NodeKey.ShortString()).
Str("oldnodekey", req.OldNodeKey.ShortString()).Msg("Nodekys!")
now := time.Now().UTC()
machine, err := h.GetMachineByNodeKeys(req.NodeKey, req.OldNodeKey)
if errors.Is(err, gorm.ErrRecordNotFound) {
log.Info().Str("machine", req.Hostinfo.Hostname).Msg("New machine via Noise")
// If the machine has AuthKey set, handle registration via PreAuthKeys
if req.Auth.AuthKey != "" {
h.handleNoiseAuthKey(w, r, req)
return
}
givenName, err := h.GenerateGivenName(req.Hostinfo.Hostname)
if err != nil {
log.Error().
Caller().
Str("func", "RegistrationHandler").
Str("hostinfo.name", req.Hostinfo.Hostname).
Err(err)
return
}
// The machine did not have a key to authenticate, which means
// that we rely on a method that calls back some how (OpenID or CLI)
// We create the machine and then keep it around until a callback
// happens
newMachine := Machine{
MachineKey: "",
Hostname: req.Hostinfo.Hostname,
GivenName: givenName,
NodeKey: NodePublicKeyStripPrefix(req.NodeKey),
LastSeen: &now,
Expiry: &time.Time{},
}
if !req.Expiry.IsZero() {
log.Trace().
Caller().
Str("machine", req.Hostinfo.Hostname).
Time("expiry", req.Expiry).
Msg("Non-zero expiry time requested")
newMachine.Expiry = &req.Expiry
}
h.registrationCache.Set(
NodePublicKeyStripPrefix(req.NodeKey),
newMachine,
registerCacheExpiration,
)
h.handleNoiseMachineRegistrationNew(w, r, req)
return
}
// The machine is already registered, so we need to pass through reauth or key update.
if machine != nil {
// If the NodeKey stored in headscale is the same as the key presented in a registration
// request, then we have a node that is either:
// - Trying to log out (sending a expiry in the past)
// - A valid, registered machine, looking for the node map
// - Expired machine wanting to reauthenticate
if machine.NodeKey == NodePublicKeyStripPrefix(req.NodeKey) {
// The client sends an Expiry in the past if the client is requesting to expire the key (aka logout)
// https://github.com/tailscale/tailscale/blob/main/tailcfg/tailcfg.go#L648
if !req.Expiry.IsZero() && req.Expiry.UTC().Before(now) {
h.handleNoiseNodeLogOut(w, r, *machine)
return
}
// If machine is not expired, and is register, we have a already accepted this machine,
// let it proceed with a valid registration
if !machine.isExpired() {
h.handleNoiseNodeValidRegistration(w, r, *machine)
return
}
}
// The NodeKey we have matches OldNodeKey, which means this is a refresh after a key expiration
if machine.NodeKey == NodePublicKeyStripPrefix(req.OldNodeKey) &&
!machine.isExpired() {
h.handleNoiseNodeRefreshKey(w, r, req, *machine)
return
}
// The node has expired
h.handleNoiseNodeExpired(w, r, req, *machine)
return
}
}
func (h *Headscale) handleNoiseAuthKey(
w http.ResponseWriter,
r *http.Request,
registerRequest tailcfg.RegisterRequest,
) {
log.Debug().
Caller().
Str("machine", registerRequest.Hostinfo.Hostname).
Msgf("Processing auth key for %s over Noise", registerRequest.Hostinfo.Hostname)
resp := tailcfg.RegisterResponse{}
pak, err := h.checkKeyValidity(registerRequest.Auth.AuthKey)
if err != nil {
log.Error().
Caller().
Str("machine", registerRequest.Hostinfo.Hostname).
Err(err).
Msg("Failed authentication via AuthKey")
resp.MachineAuthorized = false
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusUnauthorized)
json.NewEncoder(w).Encode(resp)
log.Error().
Caller().
Str("machine", registerRequest.Hostinfo.Hostname).
Msg("Failed authentication via AuthKey over Noise")
if pak != nil {
machineRegistrations.WithLabelValues("new", RegisterMethodAuthKey, "error", pak.Namespace.Name).
Inc()
} else {
machineRegistrations.WithLabelValues("new", RegisterMethodAuthKey, "error", "unknown").Inc()
}
return
}
log.Debug().
Caller().
Str("machine", registerRequest.Hostinfo.Hostname).
Msg("Authentication key was valid, proceeding to acquire IP addresses")
nodeKey := NodePublicKeyStripPrefix(registerRequest.NodeKey)
// retrieve machine information if it exist
// The error is not important, because if it does not
// exist, then this is a new machine and we will move
// on to registration.
machine, _ := h.GetMachineByNodeKeys(registerRequest.NodeKey, registerRequest.OldNodeKey)
if machine != nil {
log.Trace().
Caller().
Str("machine", machine.Hostname).
Msg("machine already registered, refreshing with new auth key")
machine.NodeKey = nodeKey
machine.AuthKeyID = uint(pak.ID)
h.RefreshMachine(machine, registerRequest.Expiry)
} else {
now := time.Now().UTC()
givenName, err := h.GenerateGivenName(registerRequest.Hostinfo.Hostname)
if err != nil {
log.Error().
Caller().
Str("func", "RegistrationHandler").
Str("hostinfo.name", registerRequest.Hostinfo.Hostname).
Err(err)
return
}
machineToRegister := Machine{
Hostname: registerRequest.Hostinfo.Hostname,
GivenName: givenName,
NamespaceID: pak.Namespace.ID,
MachineKey: "",
RegisterMethod: RegisterMethodAuthKey,
Expiry: &registerRequest.Expiry,
NodeKey: nodeKey,
LastSeen: &now,
AuthKeyID: uint(pak.ID),
}
machine, err = h.RegisterMachine(
machineToRegister,
)
if err != nil {
log.Error().
Caller().
Err(err).
Msg("could not register machine")
machineRegistrations.WithLabelValues("new", RegisterMethodAuthKey, "error", pak.Namespace.Name).
Inc()
http.Error(w, "Internal error", http.StatusInternalServerError)
return
}
}
h.UsePreAuthKey(pak)
resp.MachineAuthorized = true
resp.User = *pak.Namespace.toUser()
machineRegistrations.WithLabelValues("new", RegisterMethodAuthKey, "success", pak.Namespace.Name).
Inc()
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
json.NewEncoder(w).Encode(resp)
log.Info().
Caller().
Str("machine", registerRequest.Hostinfo.Hostname).
Str("ips", strings.Join(machine.IPAddresses.ToStringSlice(), ", ")).
Msg("Successfully authenticated via AuthKey on Noise")
}
func (h *Headscale) handleNoiseNodeValidRegistration(
w http.ResponseWriter,
r *http.Request,
machine Machine,
) {
resp := tailcfg.RegisterResponse{}
// The machine registration is valid, respond with redirect to /map
log.Debug().
Str("machine", machine.Hostname).
Msg("Client is registered and we have the current NodeKey. All clear to /map")
resp.AuthURL = ""
resp.MachineAuthorized = true
resp.User = *machine.Namespace.toUser()
resp.Login = *machine.Namespace.toLogin()
machineRegistrations.WithLabelValues("update", "web", "success", machine.Namespace.Name).
Inc()
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
json.NewEncoder(w).Encode(resp)
}
func (h *Headscale) handleNoiseMachineRegistrationNew(
w http.ResponseWriter,
r *http.Request,
registerRequest tailcfg.RegisterRequest,
) {
resp := tailcfg.RegisterResponse{}
// The machine registration is new, redirect the client to the registration URL
log.Debug().
Str("machine", registerRequest.Hostinfo.Hostname).
Msg("The node is sending us a new NodeKey, sending auth url")
if h.cfg.OIDC.Issuer != "" {
resp.AuthURL = fmt.Sprintf(
"%s/oidc/register/%s",
strings.TrimSuffix(h.cfg.ServerURL, "/"),
NodePublicKeyStripPrefix(registerRequest.NodeKey),
)
} else {
resp.AuthURL = fmt.Sprintf("%s/register?key=%s",
strings.TrimSuffix(h.cfg.ServerURL, "/"), NodePublicKeyStripPrefix(registerRequest.NodeKey))
}
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
json.NewEncoder(w).Encode(resp)
}
func (h *Headscale) handleNoiseNodeLogOut(
w http.ResponseWriter,
r *http.Request,
machine Machine,
) {
resp := tailcfg.RegisterResponse{}
log.Info().
Str("machine", machine.Hostname).
Msg("Client requested logout")
h.ExpireMachine(&machine)
resp.AuthURL = ""
resp.MachineAuthorized = false
resp.User = *machine.Namespace.toUser()
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
json.NewEncoder(w).Encode(resp)
}
func (h *Headscale) handleNoiseNodeRefreshKey(
w http.ResponseWriter,
r *http.Request,
registerRequest tailcfg.RegisterRequest,
machine Machine,
) {
resp := tailcfg.RegisterResponse{}
log.Debug().
Str("machine", machine.Hostname).
Msg("We have the OldNodeKey in the database. This is a key refresh")
machine.NodeKey = NodePublicKeyStripPrefix(registerRequest.NodeKey)
h.db.Save(&machine)
resp.AuthURL = ""
resp.User = *machine.Namespace.toUser()
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
json.NewEncoder(w).Encode(resp)
}
func (h *Headscale) handleNoiseNodeExpired(
w http.ResponseWriter,
r *http.Request,
registerRequest tailcfg.RegisterRequest,
machine Machine,
) {
resp := tailcfg.RegisterResponse{}
// The client has registered before, but has expired
log.Debug().
Caller().
Str("machine", machine.Hostname).
Msg("Machine registration has expired. Sending a authurl to register")
if registerRequest.Auth.AuthKey != "" {
h.handleNoiseAuthKey(w, r, registerRequest)
return
}
if h.cfg.OIDC.Issuer != "" {
resp.AuthURL = fmt.Sprintf("%s/oidc/register/%s",
strings.TrimSuffix(h.cfg.ServerURL, "/"), machine.NodeKey)
} else {
resp.AuthURL = fmt.Sprintf("%s/register?key=%s",
strings.TrimSuffix(h.cfg.ServerURL, "/"), machine.NodeKey)
}
machineRegistrations.WithLabelValues("reauth", "web", "success", machine.Namespace.Name).
Inc()
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
json.NewEncoder(w).Encode(resp)
}

737
noise_poll.go Normal file
View File

@@ -0,0 +1,737 @@
package headscale
import (
"context"
"encoding/binary"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"time"
"github.com/klauspost/compress/zstd"
"github.com/rs/zerolog/log"
"gorm.io/gorm"
"tailscale.com/tailcfg"
"tailscale.com/types/key"
)
// NoisePollNetMapHandler takes care of /machine/:id/map using the Noise protocol
//
// This is the busiest endpoint, as it keeps the HTTP long poll that updates
// the clients when something in the network changes.
//
// The clients POST stuff like HostInfo and their Endpoints here, but
// only after their first request (marked with the ReadOnly field).
//
// At this moment the updates are sent in a quite horrendous way, but they kinda work.
func (h *Headscale) NoisePollNetMapHandler(
w http.ResponseWriter,
r *http.Request,
) {
log.Trace().
Str("handler", "NoisePollNetMap").
Msg("PollNetMapHandler called")
body, _ := io.ReadAll(r.Body)
req := tailcfg.MapRequest{}
if err := json.Unmarshal(body, &req); err != nil {
log.Error().
Caller().
Err(err).
Msg("Cannot parse MapRequest")
http.Error(w, "Internal error", http.StatusInternalServerError)
return
}
machine, err := h.GetMachineByNodeKeys(req.NodeKey, key.NodePublic{})
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
log.Warn().
Str("handler", "NoisePollNetMap").
Msgf("Ignoring request, cannot find machine with key %s", req.NodeKey.String())
http.Error(w, "Internal error", http.StatusNotFound)
return
}
log.Error().
Str("handler", "NoisePollNetMap").
Msgf("Failed to fetch machine from the database with node key: %s", req.NodeKey.String())
http.Error(w, "Internal error", http.StatusInternalServerError)
return
}
log.Trace().
Str("handler", "NoisePollNetMap").
Str("machine", machine.Hostname).
Msg("Found machine in database")
machine.Hostname = req.Hostinfo.Hostname
machine.HostInfo = HostInfo(*req.Hostinfo)
machine.DiscoKey = DiscoPublicKeyStripPrefix(req.DiscoKey)
now := time.Now().UTC()
// update ACLRules with peer informations (to update server tags if necessary)
if h.aclPolicy != nil {
err = h.UpdateACLRules()
if err != nil {
log.Error().
Caller().
Str("func", "handleAuthKey").
Str("machine", machine.Hostname).
Err(err)
}
}
// From Tailscale client:
//
// ReadOnly is whether the client just wants to fetch the MapResponse,
// without updating their Endpoints. The Endpoints field will be ignored and
// LastSeen will not be updated and peers will not be notified of changes.
//
// The intended use is for clients to discover the DERP map at start-up
// before their first real endpoint update.
if !req.ReadOnly {
machine.Endpoints = req.Endpoints
machine.LastSeen = &now
}
if err := h.db.Updates(machine).Error; err != nil {
if err != nil {
log.Error().
Str("handler", "NoisePollNetMap").
Str("machine", machine.Hostname).
Err(err).
Msg("Failed to persist/update machine in the database")
http.Error(w, "Internal error", http.StatusInternalServerError)
return
}
}
resp, err := h.getNoiseMapResponse(req, machine)
if err != nil {
log.Error().
Str("handler", "NoisePollNetMap").
Str("machine", machine.Hostname).
Err(err).
Msg("Failed to get Map response")
http.Error(w, "Internal error", http.StatusInternalServerError)
return
}
// We update our peers if the client is not sending ReadOnly in the MapRequest
// so we don't distribute its initial request (it comes with
// empty endpoints to peers)
// Details on the protocol can be found in https://github.com/tailscale/tailscale/blob/main/tailcfg/tailcfg.go#L696
log.Debug().
Str("handler", "NoisePollNetMap").
Str("machine", machine.Hostname).
Bool("readOnly", req.ReadOnly).
Bool("omitPeers", req.OmitPeers).
Bool("stream", req.Stream).
Msg("Noise client map request processed")
if req.ReadOnly {
log.Info().
Str("handler", "NoisePollNetMap").
Str("machine", machine.Hostname).
Msg("Client is starting up. Probably interested in a DERP map")
// w.Header().Set("Content-Type", "application/json")
// w.WriteHeader(http.StatusOK)
_, err = w.Write(resp)
if err != nil {
log.Warn().Msgf("Could not send JSON response: %s", err)
}
if f, ok := w.(http.Flusher); ok {
f.Flush()
}
log.Info().Msgf("Noise client map response sent for %s (len %d)", machine.Hostname, len(resp))
return
}
// There has been an update to _any_ of the nodes that the other nodes would
// need to know about
h.setLastStateChangeToNow(machine.Namespace.Name)
// The request is not ReadOnly, so we need to set up channels for updating
// peers via longpoll
// Only create update channel if it has not been created
log.Trace().
Str("handler", "NoisePollNetMap").
Str("machine", machine.Hostname).
Msg("Loading or creating update channel")
const chanSize = 8
updateChan := make(chan struct{}, chanSize)
pollDataChan := make(chan []byte, chanSize)
defer closeChanWithLog(pollDataChan, machine.Hostname, "pollDataChan")
keepAliveChan := make(chan []byte)
if req.OmitPeers && !req.Stream {
log.Info().
Str("handler", "NoisePollNetMap").
Str("machine", machine.Hostname).
Msg("Client sent endpoint update and is ok with a response without peer list")
w.Write(resp)
if f, ok := w.(http.Flusher); ok {
f.Flush()
}
// It sounds like we should update the nodes when we have received a endpoint update
// even tho the comments in the tailscale code dont explicitly say so.
updateRequestsFromNode.WithLabelValues(machine.Namespace.Name, machine.Hostname, "endpoint-update").
Inc()
updateChan <- struct{}{}
return
} else if req.OmitPeers && req.Stream {
log.Warn().
Str("handler", "NoisePollNetMap").
Str("machine", machine.Hostname).
Msg("Ignoring request, don't know how to handle it")
http.Error(w, "Internal error", http.StatusBadRequest)
return
}
log.Info().
Str("handler", "NoisePollNetMap").
Str("machine", machine.Hostname).
Msg("Client is ready to access the tailnet")
log.Info().
Str("handler", "NoisePollNetMap").
Str("machine", machine.Hostname).
Msg("Sending initial map")
pollDataChan <- resp
log.Info().
Str("handler", "NoisePollNetMap").
Str("machine", machine.Hostname).
Msg("Notifying peers")
updateRequestsFromNode.WithLabelValues(machine.Namespace.Name, machine.Hostname, "full-update").
Inc()
updateChan <- struct{}{}
h.NoisePollNetMapStream(
w,
r,
machine,
req,
pollDataChan,
keepAliveChan,
updateChan,
)
log.Trace().
Str("handler", "NoisePollNetMap").
Str("machine", machine.Hostname).
Msg("Finished stream, closing PollNetMap session")
}
// PollNetMapStream takes care of /machine/:id/map
// stream logic, ensuring we communicate updates and data
// to the connected clients.
func (h *Headscale) NoisePollNetMapStream(
w http.ResponseWriter,
r *http.Request,
machine *Machine,
mapRequest tailcfg.MapRequest,
pollDataChan chan []byte,
keepAliveChan chan []byte,
updateChan chan struct{},
) {
ctx := context.WithValue(context.Background(), machineNameContextKey, machine.Hostname)
ctx, cancel := context.WithCancel(ctx)
defer cancel()
go h.noiseScheduledPollWorker(
ctx,
updateChan,
keepAliveChan,
mapRequest,
machine,
)
for {
log.Trace().
Str("handler", "NoisePollNetMapStream").
Str("machine", machine.Hostname).
Msg("Waiting for data to stream...")
log.Trace().
Str("handler", "NoisePollNetMapStream").
Str("machine", machine.Hostname).
Msgf("pollData is %#v, keepAliveChan is %#v, updateChan is %#v", pollDataChan, keepAliveChan, updateChan)
select {
case data := <-pollDataChan:
log.Trace().
Str("handler", "NoisePollNetMapStream").
Str("machine", machine.Hostname).
Str("channel", "pollData").
Int("bytes", len(data)).
Msg("Sending data received via pollData channel")
_, err := w.Write(data)
if err != nil {
log.Error().
Str("handler", "NoisePollNetMapStream").
Str("machine", machine.Hostname).
Str("channel", "pollData").
Err(err).
Msg("Cannot write data")
break
}
if f, ok := w.(http.Flusher); ok {
f.Flush()
}
log.Trace().
Str("handler", "NoisePollNetMapStream").
Str("machine", machine.Hostname).
Str("channel", "pollData").
Int("bytes", len(data)).
Msg("Data from pollData channel written successfully")
// TODO(kradalby): Abstract away all the database calls, this can cause race conditions
// when an outdated machine object is kept alive, e.g. db is update from
// command line, but then overwritten.
err = h.UpdateMachineFromDatabase(machine)
if err != nil {
log.Error().
Str("handler", "NoisePollNetMapStream").
Str("machine", machine.Hostname).
Str("channel", "pollData").
Err(err).
Msg("Cannot update machine from database")
// client has been removed from database
// since the stream opened, terminate connection.
break
}
now := time.Now().UTC()
machine.LastSeen = &now
lastStateUpdate.WithLabelValues(machine.Namespace.Name, machine.Hostname).
Set(float64(now.Unix()))
machine.LastSuccessfulUpdate = &now
err = h.TouchMachine(machine)
if err != nil {
log.Error().
Str("handler", "NoisePollNetMapStream").
Str("machine", machine.Hostname).
Str("channel", "pollData").
Err(err).
Msg("Cannot update machine LastSuccessfulUpdate")
} else {
log.Trace().
Str("handler", "NoisePollNetMapStream").
Str("machine", machine.Hostname).
Str("channel", "pollData").
Int("bytes", len(data)).
Msg("Machine entry in database updated successfully after sending pollData")
}
break
case data := <-keepAliveChan:
log.Trace().
Str("handler", "NoisePollNetMapStream").
Str("machine", machine.Hostname).
Str("channel", "keepAlive").
Int("bytes", len(data)).
Msg("Sending keep alive message")
_, err := w.Write(data)
if f, ok := w.(http.Flusher); ok {
f.Flush()
}
if err != nil {
log.Error().
Str("handler", "NoisePollNetMapStream").
Str("machine", machine.Hostname).
Str("channel", "keepAlive").
Err(err).
Msg("Cannot write keep alive message")
break
}
log.Trace().
Str("handler", "NoisePollNetMapStream").
Str("machine", machine.Hostname).
Str("channel", "keepAlive").
Int("bytes", len(data)).
Msg("Keep alive sent successfully")
// TODO(kradalby): Abstract away all the database calls, this can cause race conditions
// when an outdated machine object is kept alive, e.g. db is update from
// command line, but then overwritten.
err = h.UpdateMachineFromDatabase(machine)
if err != nil {
log.Error().
Str("handler", "NoisePollNetMapStream").
Str("machine", machine.Hostname).
Str("channel", "keepAlive").
Err(err).
Msg("Cannot update machine from database")
// client has been removed from database
// since the stream opened, terminate connection.
break
}
now := time.Now().UTC()
machine.LastSeen = &now
err = h.TouchMachine(machine)
if err != nil {
log.Error().
Str("handler", "NoisePollNetMapStream").
Str("machine", machine.Hostname).
Str("channel", "keepAlive").
Err(err).
Msg("Cannot update machine LastSeen")
} else {
log.Trace().
Str("handler", "NoisePollNetMapStream").
Str("machine", machine.Hostname).
Str("channel", "keepAlive").
Int("bytes", len(data)).
Msg("Machine updated successfully after sending keep alive")
}
break
case <-updateChan:
log.Trace().
Str("handler", "NoisePollNetMapStream").
Str("machine", machine.Hostname).
Str("channel", "update").
Msg("Received a request for update")
updateRequestsReceivedOnChannel.WithLabelValues(machine.Namespace.Name, machine.Hostname).
Inc()
if h.isOutdated(machine) {
var lastUpdate time.Time
if machine.LastSuccessfulUpdate != nil {
lastUpdate = *machine.LastSuccessfulUpdate
}
log.Debug().
Str("handler", "NoisePollNetMapStream").
Str("machine", machine.Hostname).
Time("last_successful_update", lastUpdate).
Time("last_state_change", h.getLastStateChange(machine.Namespace.Name)).
Msgf("There has been updates since the last successful update to %s", machine.Hostname)
data, err := h.getNoiseMapResponse(mapRequest, machine)
if err != nil {
log.Error().
Str("handler", "NoisePollNetMapStream").
Str("machine", machine.Hostname).
Str("channel", "update").
Err(err).
Msg("Could not get the map update")
}
_, err = w.Write(data)
if err != nil {
log.Error().
Str("handler", "NoisePollNetMapStream").
Str("machine", machine.Hostname).
Str("channel", "update").
Err(err).
Msg("Could not write the map response")
updateRequestsSentToNode.WithLabelValues(machine.Namespace.Name, machine.Hostname, "failed").
Inc()
break
}
if f, ok := w.(http.Flusher); ok {
f.Flush()
}
log.Trace().
Str("handler", "NoisePollNetMapStream").
Str("machine", machine.Hostname).
Str("channel", "update").
Msg("Updated Map has been sent")
updateRequestsSentToNode.WithLabelValues(machine.Namespace.Name, machine.Hostname, "success").
Inc()
// Keep track of the last successful update,
// we sometimes end in a state were the update
// is not picked up by a client and we use this
// to determine if we should "force" an update.
// TODO(kradalby): Abstract away all the database calls, this can cause race conditions
// when an outdated machine object is kept alive, e.g. db is update from
// command line, but then overwritten.
err = h.UpdateMachineFromDatabase(machine)
if err != nil {
log.Error().
Str("handler", "NoisePollNetMapStream").
Str("machine", machine.Hostname).
Str("channel", "update").
Err(err).
Msg("Cannot update machine from database")
// client has been removed from database
// since the stream opened, terminate connection.
break
}
now := time.Now().UTC()
lastStateUpdate.WithLabelValues(machine.Namespace.Name, machine.Hostname).
Set(float64(now.Unix()))
machine.LastSuccessfulUpdate = &now
err = h.TouchMachine(machine)
if err != nil {
log.Error().
Str("handler", "NoisePollNetMapStream").
Str("machine", machine.Hostname).
Str("channel", "update").
Err(err).
Msg("Cannot update machine LastSuccessfulUpdate")
}
} else {
var lastUpdate time.Time
if machine.LastSuccessfulUpdate != nil {
lastUpdate = *machine.LastSuccessfulUpdate
}
log.Trace().
Str("handler", "NoisePollNetMapStream").
Str("machine", machine.Hostname).
Time("last_successful_update", lastUpdate).
Time("last_state_change", h.getLastStateChange(machine.Namespace.Name)).
Msgf("%s is up to date", machine.Hostname)
}
break
case <-ctx.Done():
log.Info().
Str("handler", "NoisePollNetMapStream").
Str("machine", machine.Hostname).
Msg("The client has closed the connection")
// TODO: Abstract away all the database calls, this can cause race conditions
// when an outdated machine object is kept alive, e.g. db is update from
// command line, but then overwritten.
err := h.UpdateMachineFromDatabase(machine)
if err != nil {
log.Error().
Str("handler", "NoisePollNetMapStream").
Str("machine", machine.Hostname).
Str("channel", "Done").
Err(err).
Msg("Cannot update machine from database")
// client has been removed from database
// since the stream opened, terminate connection.
break
}
now := time.Now().UTC()
machine.LastSeen = &now
err = h.TouchMachine(machine)
if err != nil {
log.Error().
Str("handler", "NoisePollNetMapStream").
Str("machine", machine.Hostname).
Str("channel", "Done").
Err(err).
Msg("Cannot update machine LastSeen")
}
break
}
}
}
func (h *Headscale) noiseScheduledPollWorker(
ctx context.Context,
updateChan chan struct{},
keepAliveChan chan []byte,
mapRequest tailcfg.MapRequest,
machine *Machine,
) {
keepAliveTicker := time.NewTicker(keepAliveInterval)
updateCheckerTicker := time.NewTicker(updateCheckInterval)
defer closeChanWithLog(
updateChan,
fmt.Sprint(ctx.Value(machineNameContextKey)),
"updateChan",
)
defer closeChanWithLog(
keepAliveChan,
fmt.Sprint(ctx.Value(machineNameContextKey)),
"updateChan",
)
for {
select {
case <-ctx.Done():
return
case <-keepAliveTicker.C:
data, err := h.getNoiseMapKeepAliveResponse(mapRequest)
if err != nil {
log.Error().
Str("func", "keepAlive").
Err(err).
Msg("Error generating the keep alive msg")
return
}
log.Debug().
Str("func", "keepAlive").
Str("machine", machine.Hostname).
Msg("Sending keepalive")
keepAliveChan <- data
case <-updateCheckerTicker.C:
log.Debug().
Str("func", "scheduledPollWorker").
Str("machine", machine.Hostname).
Msg("Sending update request")
updateRequestsFromNode.WithLabelValues(machine.Namespace.Name, machine.Hostname, "scheduled-update").
Inc()
updateChan <- struct{}{}
}
}
}
func (h *Headscale) getNoiseMapKeepAliveResponse(req tailcfg.MapRequest) ([]byte, error) {
resp := tailcfg.MapResponse{
KeepAlive: true,
}
// The TS2021 protocol does not rely anymore on the machine key to
// encrypt in a NaCl box the map response. We just send it back
// unencrypted via the encrypted Noise channel.
// declare the incoming size on the first 4 bytes
respBody, err := json.Marshal(resp)
if err != nil {
log.Error().
Caller().
Err(err).
Msg("Cannot marshal map response")
}
var srcCompressed []byte
if req.Compress == "zstd" {
encoder, _ := zstd.NewWriter(nil)
srcCompressed = encoder.EncodeAll(respBody, nil)
} else {
srcCompressed = respBody
}
data := make([]byte, reservedResponseHeaderSize)
binary.LittleEndian.PutUint32(data, uint32(len(srcCompressed)))
data = append(data, srcCompressed...)
return data, nil
}
func (h *Headscale) getNoiseMapResponse(
req tailcfg.MapRequest,
machine *Machine,
) ([]byte, error) {
log.Trace().
Str("func", "getNoiseMapResponse").
Str("machine", req.Hostinfo.Hostname).
Msg("Creating Map response")
node, err := machine.toNode(h.cfg.BaseDomain, h.cfg.DNSConfig, true)
if err != nil {
log.Error().
Caller().
Str("func", "getNoiseMapResponse").
Err(err).
Msg("Cannot convert to node")
return nil, err
}
peers, err := h.getValidPeers(machine)
if err != nil {
log.Error().
Caller().
Str("func", "getNoiseMapResponse").
Err(err).
Msg("Cannot fetch peers")
return nil, err
}
profiles := getMapResponseUserProfiles(*machine, peers)
nodePeers, err := peers.toNodes(h.cfg.BaseDomain, h.cfg.DNSConfig, true)
if err != nil {
log.Error().
Caller().
Str("func", "getNoiseMapResponse").
Err(err).
Msg("Failed to convert peers to Tailscale nodes")
return nil, err
}
dnsConfig := getMapResponseDNSConfig(
h.cfg.DNSConfig,
h.cfg.BaseDomain,
*machine,
peers,
)
resp := tailcfg.MapResponse{
KeepAlive: false,
Node: node,
Peers: nodePeers,
DNSConfig: dnsConfig,
Domain: h.cfg.BaseDomain,
PacketFilter: h.aclRules,
DERPMap: h.DERPMap,
UserProfiles: profiles,
Debug: &tailcfg.Debug{
DisableLogTail: !h.cfg.LogTail.Enabled,
RandomizeClientPort: h.cfg.RandomizeClientPort,
},
}
log.Trace().
Str("func", "getNoiseMapResponse").
Str("machine", req.Hostinfo.Hostname).
Msgf("Generated map response: %s", tailMapResponseToString(resp))
// The TS2021 protocol does not rely anymore on the machine key to
// encrypt in a NaCl box the map response. We just send it back
// unencrypted via the encrypted Noise channel.
// declare the incoming size on the first 4 bytes
respBody, err := json.Marshal(resp)
if err != nil {
log.Error().
Caller().
Err(err).
Msg("Cannot marshal map response")
}
var srcCompressed []byte
if req.Compress == "zstd" {
encoder, _ := zstd.NewWriter(nil)
srcCompressed = encoder.EncodeAll(respBody, nil)
} else {
srcCompressed = respBody
}
data := make([]byte, reservedResponseHeaderSize)
binary.LittleEndian.PutUint32(data, uint32(len(srcCompressed)))
data = append(data, srcCompressed...)
return data, nil
}

36
oidc.go
View File

@@ -62,10 +62,10 @@ func (h *Headscale) initOIDC() error {
// RegisterOIDC redirects to the OIDC provider for authentication
// Puts machine key in cache so the callback can retrieve it using the oidc state param
// Listens in /oidc/register/:mKey.
// Listens in /oidc/register/:nKey.
func (h *Headscale) RegisterOIDC(ctx *gin.Context) {
machineKeyStr := ctx.Param("mkey")
if machineKeyStr == "" {
nodeKeyStr := ctx.Param("nkey")
if nodeKeyStr == "" {
ctx.String(http.StatusBadRequest, "Wrong params")
return
@@ -73,7 +73,7 @@ func (h *Headscale) RegisterOIDC(ctx *gin.Context) {
log.Trace().
Caller().
Str("machine_key", machineKeyStr).
Str("node_key", nodeKeyStr).
Msg("Received oidc register call")
randomBlob := make([]byte, randomByteSize)
@@ -89,7 +89,7 @@ func (h *Headscale) RegisterOIDC(ctx *gin.Context) {
stateStr := hex.EncodeToString(randomBlob)[:32]
// place the machine key into the state cache, so it can be retrieved later
h.registrationCache.Set(stateStr, machineKeyStr, registerCacheExpiration)
h.registrationCache.Set(stateStr, nodeKeyStr, registerCacheExpiration)
// Add any extra parameter provided in the configuration to the Authorize Endpoint request
extras := make([]oauth2.AuthCodeOption, 0, len(h.cfg.OIDC.ExtraParams))
@@ -217,10 +217,10 @@ func (h *Headscale) OIDCCallback(ctx *gin.Context) {
return
}
// retrieve machinekey from state cache
machineKeyIf, machineKeyFound := h.registrationCache.Get(state)
// retrieve nodekey from state cache
nodeKeyIf, nodeKeyFound := h.registrationCache.Get(state)
if !machineKeyFound {
if !nodeKeyFound {
log.Error().
Msg("requested machine state key expired before authorisation completed")
ctx.String(http.StatusBadRequest, "state has expired")
@@ -228,22 +228,22 @@ func (h *Headscale) OIDCCallback(ctx *gin.Context) {
return
}
machineKeyFromCache, machineKeyOK := machineKeyIf.(string)
nodeKeyFromCache, nodeKeyOK := nodeKeyIf.(string)
var machineKey key.MachinePublic
err = machineKey.UnmarshalText(
[]byte(MachinePublicKeyEnsurePrefix(machineKeyFromCache)),
var nodeKey key.NodePublic
err = nodeKey.UnmarshalText(
[]byte(MachinePublicKeyEnsurePrefix(nodeKeyFromCache)),
)
if err != nil {
log.Error().
Msg("could not parse machine public key")
Msg("could not parse node public key")
ctx.String(http.StatusBadRequest, "could not parse public key")
return
}
if !machineKeyOK {
log.Error().Msg("could not get machine key from cache")
if !nodeKeyOK {
log.Error().Msg("could not get node key from cache")
ctx.String(
http.StatusInternalServerError,
"could not get machine key from cache",
@@ -256,7 +256,7 @@ func (h *Headscale) OIDCCallback(ctx *gin.Context) {
// The error is not important, because if it does not
// exist, then this is a new machine and we will move
// on to registration.
machine, _ := h.GetMachineByMachineKey(machineKey)
machine, _ := h.GetMachineByNodeKeys(nodeKey, key.NodePublic{})
if machine != nil {
log.Trace().
@@ -335,10 +335,10 @@ func (h *Headscale) OIDCCallback(ctx *gin.Context) {
return
}
machineKeyStr := MachinePublicKeyStripPrefix(machineKey)
nodeKeyStr := NodePublicKeyStripPrefix(nodeKey)
_, err = h.RegisterMachineFromAuthCallback(
machineKeyStr,
nodeKeyStr,
namespace.Name,
RegisterMethodOIDC,
)