diff --git a/api.go b/api.go index 9760da0d..a9711e0e 100644 --- a/api.go +++ b/api.go @@ -107,13 +107,14 @@ var registerWebAPITemplate = template.Must( `)) // RegisterWebAPI shows a simple message in the browser to point to the CLI -// Listens in /register. +// Listens in /register/:nkey. func (h *Headscale) RegisterWebAPI( writer http.ResponseWriter, req *http.Request, ) { - nodeKeyStr := req.URL.Query().Get("key") - if nodeKeyStr == "" { + vars := mux.Vars(req) + nodeKeyStr, ok := vars["nkey"] + if !ok || nodeKeyStr == "" { writer.Header().Set("Content-Type", "text/plain; charset=utf-8") writer.WriteHeader(http.StatusBadRequest) _, err := writer.Write([]byte("Wrong params")) @@ -206,8 +207,6 @@ func (h *Headscale) RegistrationHandler( now := time.Now().UTC() machine, err := h.GetMachineByMachineKey(machineKey) if errors.Is(err, gorm.ErrRecordNotFound) { - log.Info().Str("machine", registerRequest.Hostinfo.Hostname).Msg("New machine") - machineKeyStr := MachinePublicKeyStripPrefix(machineKey) // If the machine has AuthKey set, handle registration via PreAuthKeys @@ -217,6 +216,38 @@ func (h *Headscale) RegistrationHandler( return } + // Check if the node is waiting for interactive login + // + // TODO(juan): We could use this field to improve our protocol implementation, + // and hold the request until the client closes it, or the interactive + // login is completed (i.e., the user registers the machine). + // This is not implemented yet, as it is no strictly required. The only side-effect + // is that the client will hammer headscale with requests until it gets a + // successful RegisterResponse. + if registerRequest.Followup != "" { + if _, ok := h.registrationCache.Get(NodePublicKeyStripPrefix(registerRequest.NodeKey)); ok { + log.Debug(). + Caller(). + Str("machine", registerRequest.Hostinfo.Hostname). + Str("NodeKey", registerRequest.NodeKey.ShortString()). + Str("OldNodeKey", registerRequest.OldNodeKey.ShortString()). + Str("Followup", registerRequest.Followup). + Msg("Machine is waiting for interactive login") + + h.handleMachineRegistrationNew(writer, req, machineKey, registerRequest) + + return + } + } + + log.Info(). + Caller(). + Str("machine", registerRequest.Hostinfo.Hostname). + Str("NodeKey", registerRequest.NodeKey.ShortString()). + Str("OldNodeKey", registerRequest.OldNodeKey.ShortString()). + Str("Followup", registerRequest.Followup). + Msg("New machine not yet in the database") + givenName, err := h.GenerateGivenName(registerRequest.Hostinfo.Hostname) if err != nil { log.Error(). @@ -645,7 +676,7 @@ func (h *Headscale) handleMachineRegistrationNew( // The machine registration is new, redirect the client to the registration URL log.Debug(). Str("machine", registerRequest.Hostinfo.Hostname). - Msg("The node is sending us a new NodeKey, sending auth url") + Msg("The node seems to be new, sending auth url") if h.cfg.OIDC.Issuer != "" { resp.AuthURL = fmt.Sprintf( "%s/oidc/register/%s", @@ -653,8 +684,8 @@ func (h *Headscale) handleMachineRegistrationNew( machineKey.String(), ) } else { - resp.AuthURL = fmt.Sprintf("%s/register?key=%s", - strings.TrimSuffix(h.cfg.ServerURL, "/"), MachinePublicKeyStripPrefix(machineKey)) + resp.AuthURL = fmt.Sprintf("%s/register/%s", + strings.TrimSuffix(h.cfg.ServerURL, "/"), NodePublicKeyStripPrefix(registerRequest.NodeKey)) } respBody, err := encode(resp, &machineKey, h.privateKey) diff --git a/app.go b/app.go index bd88dedf..60258e6e 100644 --- a/app.go +++ b/app.go @@ -417,7 +417,7 @@ func (h *Headscale) createRouter(grpcMux *runtime.ServeMux) *mux.Router { router.HandleFunc("/health", h.HealthHandler).Methods(http.MethodGet) router.HandleFunc("/key", h.KeyHandler).Methods(http.MethodGet) - router.HandleFunc("/register", h.RegisterWebAPI).Methods(http.MethodGet) + router.HandleFunc("/register/{nkey}", h.RegisterWebAPI).Methods(http.MethodGet) router.HandleFunc("/machine/{mkey}/map", h.PollNetMapHandler).Methods(http.MethodPost) router.HandleFunc("/machine/{mkey}", h.RegistrationHandler).Methods(http.MethodPost) router.HandleFunc("/oidc/register/{mkey}", h.RegisterOIDC).Methods(http.MethodGet)