TS2021: Use NodeKey for everything, as MachineKey is deprecated in TS2021

This commit is contained in:
Juan Font Alonso 2022-06-12 12:30:56 +02:00
parent b40b4e8d45
commit e8205e8d5a
5 changed files with 76 additions and 30 deletions

4
api.go
View File

@ -546,11 +546,11 @@ func (h *Headscale) handleMachineRegistrationNew(
resp.AuthURL = fmt.Sprintf( resp.AuthURL = fmt.Sprintf(
"%s/oidc/register/%s", "%s/oidc/register/%s",
strings.TrimSuffix(h.cfg.ServerURL, "/"), strings.TrimSuffix(h.cfg.ServerURL, "/"),
machineKey.String(), NodePublicKeyStripPrefix(registerRequest.NodeKey),
) )
} else { } else {
resp.AuthURL = fmt.Sprintf("%s/register?key=%s", resp.AuthURL = fmt.Sprintf("%s/register?key=%s",
strings.TrimSuffix(h.cfg.ServerURL, "/"), MachinePublicKeyStripPrefix(machineKey)) strings.TrimSuffix(h.cfg.ServerURL, "/"), NodePublicKeyStripPrefix(registerRequest.NodeKey))
} }
respBody, err := encode(resp, &machineKey, h.privateKey) respBody, err := encode(resp, &machineKey, h.privateKey)

2
app.go
View File

@ -415,7 +415,7 @@ func (h *Headscale) createRouter(grpcMux *runtime.ServeMux) *gin.Engine {
router.GET("/register", h.RegisterWebAPI) router.GET("/register", h.RegisterWebAPI)
router.POST("/machine/:id/map", h.PollNetMapHandler) router.POST("/machine/:id/map", h.PollNetMapHandler)
router.POST("/machine/:id", h.RegistrationHandler) router.POST("/machine/:id", h.RegistrationHandler)
router.GET("/oidc/register/:mkey", h.RegisterOIDC) router.GET("/oidc/register/:nkey", h.RegisterOIDC)
router.GET("/oidc/callback", h.OIDCCallback) router.GET("/oidc/callback", h.OIDCCallback)
router.GET("/apple", h.AppleConfigMessage) router.GET("/apple", h.AppleConfigMessage)
router.GET("/apple/:platform", h.ApplePlatformConfig) router.GET("/apple/:platform", h.ApplePlatformConfig)

View File

@ -349,7 +349,7 @@ func (h *Headscale) GetMachineByID(id uint64) (*Machine, error) {
return &m, nil return &m, nil
} }
// GetMachineByMachineKey finds a Machine by ID and returns the Machine struct. // GetMachineByMachineKey finds a Machine by its MachineKey and returns the Machine struct.
func (h *Headscale) GetMachineByMachineKey( func (h *Headscale) GetMachineByMachineKey(
machineKey key.MachinePublic, machineKey key.MachinePublic,
) (*Machine, error) { ) (*Machine, error) {
@ -361,6 +361,19 @@ func (h *Headscale) GetMachineByMachineKey(
return &m, nil return &m, nil
} }
// GetMachineByNodeKeys finds a Machine by its current NodeKey or the old one, and returns the Machine struct.
func (h *Headscale) GetMachineByNodeKeys(
nodeKey key.NodePublic, oldNodeKey key.NodePublic,
) (*Machine, error) {
machine := Machine{}
if result := h.db.Preload("Namespace").First(&machine, "node_key = ? OR node_key = ?",
NodePublicKeyStripPrefix(nodeKey), NodePublicKeyStripPrefix(oldNodeKey)); result.Error != nil {
return nil, result.Error
}
return &machine, nil
}
// UpdateMachineFromDatabase takes a Machine struct pointer (typically already loaded from database // UpdateMachineFromDatabase takes a Machine struct pointer (typically already loaded from database
// and updates it with the latest data from the database. // and updates it with the latest data from the database.
func (h *Headscale) UpdateMachineFromDatabase(machine *Machine) error { func (h *Headscale) UpdateMachineFromDatabase(machine *Machine) error {
@ -567,12 +580,15 @@ func (machine Machine) toNode(
} }
var machineKey key.MachinePublic var machineKey key.MachinePublic
if machine.MachineKey != "" {
// MachineKey is only used in the legacy protocol
err = machineKey.UnmarshalText( err = machineKey.UnmarshalText(
[]byte(MachinePublicKeyEnsurePrefix(machine.MachineKey)), []byte(MachinePublicKeyEnsurePrefix(machine.MachineKey)),
) )
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to parse machine public key: %w", err) return nil, fmt.Errorf("failed to parse machine public key: %w", err)
} }
}
var discoKey key.DiscoPublic var discoKey key.DiscoPublic
if machine.DiscoKey != "" { if machine.DiscoKey != "" {
@ -750,11 +766,11 @@ func getTags(
} }
func (h *Headscale) RegisterMachineFromAuthCallback( func (h *Headscale) RegisterMachineFromAuthCallback(
machineKeyStr string, nodeKeyStr string,
namespaceName string, namespaceName string,
registrationMethod string, registrationMethod string,
) (*Machine, error) { ) (*Machine, error) {
if machineInterface, ok := h.registrationCache.Get(machineKeyStr); ok { if machineInterface, ok := h.registrationCache.Get(nodeKeyStr); ok {
if registrationMachine, ok := machineInterface.(Machine); ok { if registrationMachine, ok := machineInterface.(Machine); ok {
namespace, err := h.GetNamespace(namespaceName) namespace, err := h.GetNamespace(namespaceName)
if err != nil { if err != nil {
@ -785,7 +801,7 @@ func (h *Headscale) RegisterMachine(machine Machine,
) (*Machine, error) { ) (*Machine, error) {
log.Trace(). log.Trace().
Caller(). Caller().
Str("machine_key", machine.MachineKey). Str("node_key", machine.NodeKey).
Msg("Registering machine") Msg("Registering machine")
log.Trace(). log.Trace().

View File

@ -11,6 +11,7 @@ import (
"gopkg.in/check.v1" "gopkg.in/check.v1"
"inet.af/netaddr" "inet.af/netaddr"
"tailscale.com/tailcfg" "tailscale.com/tailcfg"
"tailscale.com/types/key"
) )
func (s *Suite) TestGetMachine(c *check.C) { func (s *Suite) TestGetMachine(c *check.C) {
@ -65,6 +66,35 @@ func (s *Suite) TestGetMachineByID(c *check.C) {
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
} }
func (s *Suite) TestGetMachineByNodeKeys(c *check.C) {
namespace, err := app.CreateNamespace("test")
c.Assert(err, check.IsNil)
pak, err := app.CreatePreAuthKey(namespace.Name, false, false, nil)
c.Assert(err, check.IsNil)
_, err = app.GetMachineByID(0)
c.Assert(err, check.NotNil)
nodeKey := key.NewNode()
oldNodeKey := key.NewNode()
machine := Machine{
ID: 0,
MachineKey: "foo",
NodeKey: NodePublicKeyStripPrefix(nodeKey.Public()),
DiscoKey: "faa",
Hostname: "testmachine",
NamespaceID: namespace.ID,
RegisterMethod: RegisterMethodAuthKey,
AuthKeyID: uint(pak.ID),
}
app.db.Save(&machine)
_, err = app.GetMachineByNodeKeys(nodeKey.Public(), oldNodeKey.Public())
c.Assert(err, check.IsNil)
}
func (s *Suite) TestDeleteMachine(c *check.C) { func (s *Suite) TestDeleteMachine(c *check.C) {
namespace, err := app.CreateNamespace("test") namespace, err := app.CreateNamespace("test")
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)

36
oidc.go
View File

@ -62,10 +62,10 @@ func (h *Headscale) initOIDC() error {
// RegisterOIDC redirects to the OIDC provider for authentication // RegisterOIDC redirects to the OIDC provider for authentication
// Puts machine key in cache so the callback can retrieve it using the oidc state param // Puts machine key in cache so the callback can retrieve it using the oidc state param
// Listens in /oidc/register/:mKey. // Listens in /oidc/register/:nKey.
func (h *Headscale) RegisterOIDC(ctx *gin.Context) { func (h *Headscale) RegisterOIDC(ctx *gin.Context) {
machineKeyStr := ctx.Param("mkey") nodeKeyStr := ctx.Param("nkey")
if machineKeyStr == "" { if nodeKeyStr == "" {
ctx.String(http.StatusBadRequest, "Wrong params") ctx.String(http.StatusBadRequest, "Wrong params")
return return
@ -73,7 +73,7 @@ func (h *Headscale) RegisterOIDC(ctx *gin.Context) {
log.Trace(). log.Trace().
Caller(). Caller().
Str("machine_key", machineKeyStr). Str("node_key", nodeKeyStr).
Msg("Received oidc register call") Msg("Received oidc register call")
randomBlob := make([]byte, randomByteSize) randomBlob := make([]byte, randomByteSize)
@ -89,7 +89,7 @@ func (h *Headscale) RegisterOIDC(ctx *gin.Context) {
stateStr := hex.EncodeToString(randomBlob)[:32] stateStr := hex.EncodeToString(randomBlob)[:32]
// place the machine key into the state cache, so it can be retrieved later // place the machine key into the state cache, so it can be retrieved later
h.registrationCache.Set(stateStr, machineKeyStr, registerCacheExpiration) h.registrationCache.Set(stateStr, nodeKeyStr, registerCacheExpiration)
// Add any extra parameter provided in the configuration to the Authorize Endpoint request // Add any extra parameter provided in the configuration to the Authorize Endpoint request
extras := make([]oauth2.AuthCodeOption, 0, len(h.cfg.OIDC.ExtraParams)) extras := make([]oauth2.AuthCodeOption, 0, len(h.cfg.OIDC.ExtraParams))
@ -217,10 +217,10 @@ func (h *Headscale) OIDCCallback(ctx *gin.Context) {
return return
} }
// retrieve machinekey from state cache // retrieve nodekey from state cache
machineKeyIf, machineKeyFound := h.registrationCache.Get(state) nodeKeyIf, nodeKeyFound := h.registrationCache.Get(state)
if !machineKeyFound { if !nodeKeyFound {
log.Error(). log.Error().
Msg("requested machine state key expired before authorisation completed") Msg("requested machine state key expired before authorisation completed")
ctx.String(http.StatusBadRequest, "state has expired") ctx.String(http.StatusBadRequest, "state has expired")
@ -228,22 +228,22 @@ func (h *Headscale) OIDCCallback(ctx *gin.Context) {
return return
} }
machineKeyFromCache, machineKeyOK := machineKeyIf.(string) nodeKeyFromCache, nodeKeyOK := nodeKeyIf.(string)
var machineKey key.MachinePublic var nodeKey key.NodePublic
err = machineKey.UnmarshalText( err = nodeKey.UnmarshalText(
[]byte(MachinePublicKeyEnsurePrefix(machineKeyFromCache)), []byte(MachinePublicKeyEnsurePrefix(nodeKeyFromCache)),
) )
if err != nil { if err != nil {
log.Error(). log.Error().
Msg("could not parse machine public key") Msg("could not parse node public key")
ctx.String(http.StatusBadRequest, "could not parse public key") ctx.String(http.StatusBadRequest, "could not parse public key")
return return
} }
if !machineKeyOK { if !nodeKeyOK {
log.Error().Msg("could not get machine key from cache") log.Error().Msg("could not get node key from cache")
ctx.String( ctx.String(
http.StatusInternalServerError, http.StatusInternalServerError,
"could not get machine key from cache", "could not get machine key from cache",
@ -256,7 +256,7 @@ func (h *Headscale) OIDCCallback(ctx *gin.Context) {
// The error is not important, because if it does not // The error is not important, because if it does not
// exist, then this is a new machine and we will move // exist, then this is a new machine and we will move
// on to registration. // on to registration.
machine, _ := h.GetMachineByMachineKey(machineKey) machine, _ := h.GetMachineByNodeKeys(nodeKey, key.NodePublic{})
if machine != nil { if machine != nil {
log.Trace(). log.Trace().
@ -335,10 +335,10 @@ func (h *Headscale) OIDCCallback(ctx *gin.Context) {
return return
} }
machineKeyStr := MachinePublicKeyStripPrefix(machineKey) nodeKeyStr := NodePublicKeyStripPrefix(nodeKey)
_, err = h.RegisterMachineFromAuthCallback( _, err = h.RegisterMachineFromAuthCallback(
machineKeyStr, nodeKeyStr,
namespace.Name, namespace.Name,
RegisterMethodOIDC, RegisterMethodOIDC,
) )