Replace machine key with node key in preparation for Noise in auth related stuff

This commit is contained in:
Juan Font Alonso 2022-08-10 15:35:26 +02:00
parent e950b3be29
commit e91c378bd4
4 changed files with 33 additions and 33 deletions

8
api.go
View File

@ -112,8 +112,8 @@ func (h *Headscale) RegisterWebAPI(
writer http.ResponseWriter, writer http.ResponseWriter,
req *http.Request, req *http.Request,
) { ) {
machineKeyStr := req.URL.Query().Get("key") nodeKeyStr := req.URL.Query().Get("key")
if machineKeyStr == "" { if nodeKeyStr == "" {
writer.Header().Set("Content-Type", "text/plain; charset=utf-8") writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
writer.WriteHeader(http.StatusBadRequest) writer.WriteHeader(http.StatusBadRequest)
_, err := writer.Write([]byte("Wrong params")) _, err := writer.Write([]byte("Wrong params"))
@ -129,7 +129,7 @@ func (h *Headscale) RegisterWebAPI(
var content bytes.Buffer var content bytes.Buffer
if err := registerWebAPITemplate.Execute(&content, registerWebAPITemplateConfig{ if err := registerWebAPITemplate.Execute(&content, registerWebAPITemplateConfig{
Key: machineKeyStr, Key: nodeKeyStr,
}); err != nil { }); err != nil {
log.Error(). log.Error().
Str("func", "RegisterWebAPI"). Str("func", "RegisterWebAPI").
@ -251,7 +251,7 @@ func (h *Headscale) RegistrationHandler(
} }
h.registrationCache.Set( h.registrationCache.Set(
machineKeyStr, newMachine.NodeKey,
newMachine, newMachine,
registerCacheExpiration, registerCacheExpiration,
) )

View File

@ -108,7 +108,7 @@ var registerNodeCmd = &cobra.Command{
if err != nil { if err != nil {
ErrorOutput( ErrorOutput(
err, err,
fmt.Sprintf("Error getting machine key from flag: %s", err), fmt.Sprintf("Error getting node key from flag: %s", err),
output, output,
) )

View File

@ -159,7 +159,7 @@ func (api headscaleV1APIServer) RegisterMachine(
) (*v1.RegisterMachineResponse, error) { ) (*v1.RegisterMachineResponse, error) {
log.Trace(). log.Trace().
Str("namespace", request.GetNamespace()). Str("namespace", request.GetNamespace()).
Str("machine_key", request.GetKey()). Str("node_key", request.GetKey()).
Msg("Registering machine") Msg("Registering machine")
machine, err := api.h.RegisterMachineFromAuthCallback( machine, err := api.h.RegisterMachineFromAuthCallback(
@ -199,7 +199,7 @@ func (api headscaleV1APIServer) SetTags(
err := validateTag(tag) err := validateTag(tag)
if err != nil { if err != nil {
return &v1.SetTagsResponse{ return &v1.SetTagsResponse{
Machine: nil, Machine: nil,
}, status.Error(codes.InvalidArgument, err.Error()) }, status.Error(codes.InvalidArgument, err.Error())
} }
} }

52
oidc.go
View File

@ -27,7 +27,7 @@ const (
errOIDCAllowedDomains = Error("authenticated principal does not match any allowed domain") errOIDCAllowedDomains = Error("authenticated principal does not match any allowed domain")
errOIDCAllowedUsers = Error("authenticated principal does not match any allowed user") errOIDCAllowedUsers = Error("authenticated principal does not match any allowed user")
errOIDCInvalidMachineState = Error("requested machine state key expired before authorisation completed") errOIDCInvalidMachineState = Error("requested machine state key expired before authorisation completed")
errOIDCMachineKeyMissing = Error("could not get machine key from cache") errOIDCNodeKeyMissing = Error("could not get node key from cache")
) )
type IDTokenClaims struct { type IDTokenClaims struct {
@ -68,26 +68,26 @@ func (h *Headscale) initOIDC() error {
} }
// RegisterOIDC redirects to the OIDC provider for authentication // RegisterOIDC redirects to the OIDC provider for authentication
// Puts machine key in cache so the callback can retrieve it using the oidc state param // Puts node key in cache so the callback can retrieve it using the oidc state param
// Listens in /oidc/register/:mKey. // Listens in /oidc/register/:mKey.
func (h *Headscale) RegisterOIDC( func (h *Headscale) RegisterOIDC(
writer http.ResponseWriter, writer http.ResponseWriter,
req *http.Request, req *http.Request,
) { ) {
vars := mux.Vars(req) vars := mux.Vars(req)
machineKeyStr, ok := vars["mkey"] nodeKeyStr, ok := vars["nkey"]
if !ok || machineKeyStr == "" { if !ok || nodeKeyStr == "" {
log.Error(). log.Error().
Caller(). Caller().
Msg("Missing machine key in URL") Msg("Missing node key in URL")
http.Error(writer, "Missing machine key in URL", http.StatusBadRequest) http.Error(writer, "Missing node key in URL", http.StatusBadRequest)
return return
} }
log.Trace(). log.Trace().
Caller(). Caller().
Str("machine_key", machineKeyStr). Str("node_key", nodeKeyStr).
Msg("Received oidc register call") Msg("Received oidc register call")
randomBlob := make([]byte, randomByteSize) randomBlob := make([]byte, randomByteSize)
@ -102,8 +102,8 @@ func (h *Headscale) RegisterOIDC(
stateStr := hex.EncodeToString(randomBlob)[:32] stateStr := hex.EncodeToString(randomBlob)[:32]
// place the machine key into the state cache, so it can be retrieved later // place the node key into the state cache, so it can be retrieved later
h.registrationCache.Set(stateStr, machineKeyStr, registerCacheExpiration) h.registrationCache.Set(stateStr, nodeKeyStr, registerCacheExpiration)
// Add any extra parameter provided in the configuration to the Authorize Endpoint request // Add any extra parameter provided in the configuration to the Authorize Endpoint request
extras := make([]oauth2.AuthCodeOption, 0, len(h.cfg.OIDC.ExtraParams)) extras := make([]oauth2.AuthCodeOption, 0, len(h.cfg.OIDC.ExtraParams))
@ -178,7 +178,7 @@ func (h *Headscale) OIDCCallback(
return return
} }
machineKey, machineExists, err := h.validateMachineForOIDCCallback(writer, state, claims) nodeKey, machineExists, err := h.validateMachineForOIDCCallback(writer, state, claims)
if err != nil || machineExists { if err != nil || machineExists {
return return
} }
@ -196,7 +196,7 @@ func (h *Headscale) OIDCCallback(
return return
} }
if err := h.registerMachineForOIDCCallback(writer, namespace, machineKey); err != nil { if err := h.registerMachineForOIDCCallback(writer, namespace, nodeKey); err != nil {
return return
} }
@ -401,7 +401,7 @@ func (h *Headscale) validateMachineForOIDCCallback(
writer http.ResponseWriter, writer http.ResponseWriter,
state string, state string,
claims *IDTokenClaims, claims *IDTokenClaims,
) (*key.MachinePublic, bool, error) { ) (*key.NodePublic, bool, error) {
// retrieve machinekey from state cache // retrieve machinekey from state cache
machineKeyIf, machineKeyFound := h.registrationCache.Get(state) machineKeyIf, machineKeyFound := h.registrationCache.Get(state)
if !machineKeyFound { if !machineKeyFound {
@ -420,14 +420,14 @@ func (h *Headscale) validateMachineForOIDCCallback(
return nil, false, errOIDCInvalidMachineState return nil, false, errOIDCInvalidMachineState
} }
var machineKey key.MachinePublic var nodeKey key.NodePublic
machineKeyFromCache, machineKeyOK := machineKeyIf.(string) nodeKeyFromCache, nodeKeyOK := machineKeyIf.(string)
err := machineKey.UnmarshalText( err := nodeKey.UnmarshalText(
[]byte(MachinePublicKeyEnsurePrefix(machineKeyFromCache)), []byte(NodePublicKeyEnsurePrefix(nodeKeyFromCache)),
) )
if err != nil { if err != nil {
log.Error(). log.Error().
Msg("could not parse machine public key") Msg("could not parse node public key")
writer.Header().Set("Content-Type", "text/plain; charset=utf-8") writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
writer.WriteHeader(http.StatusBadRequest) writer.WriteHeader(http.StatusBadRequest)
_, werr := writer.Write([]byte("could not parse public key")) _, werr := writer.Write([]byte("could not parse public key"))
@ -441,11 +441,11 @@ func (h *Headscale) validateMachineForOIDCCallback(
return nil, false, err return nil, false, err
} }
if !machineKeyOK { if !nodeKeyOK {
log.Error().Msg("could not get machine key from cache") log.Error().Msg("could not get node key from cache")
writer.Header().Set("Content-Type", "text/plain; charset=utf-8") writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
writer.WriteHeader(http.StatusInternalServerError) writer.WriteHeader(http.StatusInternalServerError)
_, err := writer.Write([]byte("could not get machine key from cache")) _, err := writer.Write([]byte("could not get node key from cache"))
if err != nil { if err != nil {
log.Error(). log.Error().
Caller(). Caller().
@ -453,14 +453,14 @@ func (h *Headscale) validateMachineForOIDCCallback(
Msg("Failed to write response") Msg("Failed to write response")
} }
return nil, false, errOIDCMachineKeyMissing return nil, false, errOIDCNodeKeyMissing
} }
// retrieve machine information if it exist // retrieve machine information if it exist
// The error is not important, because if it does not // The error is not important, because if it does not
// exist, then this is a new machine and we will move // exist, then this is a new machine and we will move
// on to registration. // on to registration.
machine, _ := h.GetMachineByMachineKey(machineKey) machine, _ := h.GetMachineByNodeKey(nodeKey)
if machine != nil { if machine != nil {
log.Trace(). log.Trace().
@ -516,7 +516,7 @@ func (h *Headscale) validateMachineForOIDCCallback(
return nil, true, nil return nil, true, nil
} }
return &machineKey, false, nil return &nodeKey, false, nil
} }
func getNamespaceName( func getNamespaceName(
@ -596,12 +596,12 @@ func (h *Headscale) findOrCreateNewNamespaceForOIDCCallback(
func (h *Headscale) registerMachineForOIDCCallback( func (h *Headscale) registerMachineForOIDCCallback(
writer http.ResponseWriter, writer http.ResponseWriter,
namespace *Namespace, namespace *Namespace,
machineKey *key.MachinePublic, nodeKey *key.NodePublic,
) error { ) error {
machineKeyStr := MachinePublicKeyStripPrefix(*machineKey) nodeKeyStr := NodePublicKeyStripPrefix(*nodeKey)
if _, err := h.RegisterMachineFromAuthCallback( if _, err := h.RegisterMachineFromAuthCallback(
machineKeyStr, nodeKeyStr,
namespace.Name, namespace.Name,
RegisterMethodOIDC, RegisterMethodOIDC,
); err != nil { ); err != nil {