package headscale import ( "encoding/binary" "encoding/json" "errors" "fmt" "io" "net/http" "strings" "time" "github.com/gin-gonic/gin" "github.com/klauspost/compress/zstd" "github.com/rs/zerolog/log" "go4.org/mem" "gorm.io/gorm" "tailscale.com/tailcfg" "tailscale.com/types/key" ) const ( reservedResponseHeaderSize = 4 RegisterMethodAuthKey = "authKey" RegisterMethodOIDC = "oidc" RegisterMethodCLI = "cli" ErrRegisterMethodCLIDoesNotSupportExpire = Error( "machines registered with CLI does not support expire", ) ) // KeyHandler provides the Headscale pub key // Listens in /key. func (h *Headscale) KeyHandler(ctx *gin.Context) { ctx.Data( http.StatusOK, "text/plain; charset=utf-8", []byte(MachinePublicKeyStripPrefix(*h.publicKey)), ) } // RegisterWebAPI shows a simple message in the browser to point to the CLI // Listens in /register. func (h *Headscale) RegisterWebAPI(ctx *gin.Context) { machineKeyStr := ctx.Query("key") if machineKeyStr == "" { ctx.String(http.StatusBadRequest, "Wrong params") return } ctx.Data(http.StatusOK, "text/html; charset=utf-8", []byte(fmt.Sprintf(`

headscale

Run the command below in the headscale server to add this machine to your network:

headscale -n NAMESPACE nodes register --key %s

`, machineKeyStr))) } // RegistrationHandler handles the actual registration process of a machine // Endpoint /machine/:id. func (h *Headscale) RegistrationHandler(ctx *gin.Context) { body, _ := io.ReadAll(ctx.Request.Body) machineKeyStr := ctx.Param("id") machineKey, err := key.ParseMachinePublicUntyped(mem.S(machineKeyStr)) if err != nil { log.Error(). Caller(). Err(err). Msg("Cannot parse machine key") machineRegistrations.WithLabelValues("unknown", "web", "error", "unknown").Inc() ctx.String(http.StatusInternalServerError, "Sad!") return } req := tailcfg.RegisterRequest{} err = decode(body, &req, &machineKey, h.privateKey) if err != nil { log.Error(). Caller(). Err(err). Msg("Cannot decode message") machineRegistrations.WithLabelValues("unknown", "web", "error", "unknown").Inc() ctx.String(http.StatusInternalServerError, "Very sad!") return } now := time.Now().UTC() machine, err := h.GetMachineByMachineKey(machineKey) if errors.Is(err, gorm.ErrRecordNotFound) { log.Info().Str("machine", req.Hostinfo.Hostname).Msg("New machine") newMachine := Machine{ Expiry: &time.Time{}, MachineKey: MachinePublicKeyStripPrefix(machineKey), Name: req.Hostinfo.Hostname, } if err := h.db.Create(&newMachine).Error; err != nil { log.Error(). Caller(). Err(err). Msg("Could not create row") machineRegistrations.WithLabelValues("unknown", "web", "error", machine.Namespace.Name). Inc() return } machine = &newMachine } if machine.Registered { // If the NodeKey stored in headscale is the same as the key presented in a registration // request, then we have a node that is either: // - Trying to log out (sending a expiry in the past) // - A valid, registered machine, looking for the node map // - Expired machine wanting to reauthenticate if machine.NodeKey == req.NodeKey.String() { // The client sends an Expiry in the past if the client is requesting to expire the key (aka logout) // https://github.com/tailscale/tailscale/blob/main/tailcfg/tailcfg.go#L648 if !req.Expiry.IsZero() && req.Expiry.UTC().Before(now) { h.handleMachineLogOut(ctx, machineKey, *machine) return } // If machine is not expired, and is register, we have a already accepted this machine, // let it proceed with a valid registration if !machine.isExpired() { h.handleMachineValidRegistration(ctx, machineKey, *machine) return } } // The NodeKey we have matches OldNodeKey, which means this is a refresh after a key expiration if machine.NodeKey == req.OldNodeKey.String() && !machine.isExpired() { h.handleMachineRefreshKey(ctx, machineKey, req, *machine) return } // The machine has expired h.handleMachineExpired(ctx, machineKey, req, *machine) return } // If the machine has AuthKey set, handle registration via PreAuthKeys if req.Auth.AuthKey != "" { h.handleAuthKey(ctx, machineKey, req, *machine) return } h.handleMachineRegistrationNew(ctx, machineKey, req, *machine) } func (h *Headscale) getMapResponse( machineKey key.MachinePublic, req tailcfg.MapRequest, machine *Machine, ) ([]byte, error) { log.Trace(). Str("func", "getMapResponse"). Str("machine", req.Hostinfo.Hostname). Msg("Creating Map response") node, err := machine.toNode(h.cfg.BaseDomain, h.cfg.DNSConfig, true) if err != nil { log.Error(). Caller(). Str("func", "getMapResponse"). Err(err). Msg("Cannot convert to node") return nil, err } peers, err := h.getValidPeers(machine) if err != nil { log.Error(). Caller(). Str("func", "getMapResponse"). Err(err). Msg("Cannot fetch peers") return nil, err } profiles := getMapResponseUserProfiles(*machine, peers) nodePeers, err := peers.toNodes(h.cfg.BaseDomain, h.cfg.DNSConfig, true) if err != nil { log.Error(). Caller(). Str("func", "getMapResponse"). Err(err). Msg("Failed to convert peers to Tailscale nodes") return nil, err } dnsConfig := getMapResponseDNSConfig( h.cfg.DNSConfig, h.cfg.BaseDomain, *machine, peers, ) resp := tailcfg.MapResponse{ KeepAlive: false, Node: node, Peers: nodePeers, DNSConfig: dnsConfig, Domain: h.cfg.BaseDomain, PacketFilter: h.aclRules, DERPMap: h.DERPMap, UserProfiles: profiles, } log.Trace(). Str("func", "getMapResponse"). Str("machine", req.Hostinfo.Hostname). // Interface("payload", resp). Msgf("Generated map response: %s", tailMapResponseToString(resp)) var respBody []byte if req.Compress == "zstd" { src, _ := json.Marshal(resp) encoder, _ := zstd.NewWriter(nil) srcCompressed := encoder.EncodeAll(src, nil) respBody = h.privateKey.SealTo(machineKey, srcCompressed) } else { respBody, err = encode(resp, &machineKey, h.privateKey) if err != nil { return nil, err } } // declare the incoming size on the first 4 bytes data := make([]byte, reservedResponseHeaderSize) binary.LittleEndian.PutUint32(data, uint32(len(respBody))) data = append(data, respBody...) return data, nil } func (h *Headscale) getMapKeepAliveResponse( machineKey key.MachinePublic, mapRequest tailcfg.MapRequest, ) ([]byte, error) { mapResponse := tailcfg.MapResponse{ KeepAlive: true, } var respBody []byte var err error if mapRequest.Compress == "zstd" { src, _ := json.Marshal(mapResponse) encoder, _ := zstd.NewWriter(nil) srcCompressed := encoder.EncodeAll(src, nil) respBody = h.privateKey.SealTo(machineKey, srcCompressed) } else { respBody, err = encode(mapResponse, &machineKey, h.privateKey) if err != nil { return nil, err } } data := make([]byte, reservedResponseHeaderSize) binary.LittleEndian.PutUint32(data, uint32(len(respBody))) data = append(data, respBody...) return data, nil } func (h *Headscale) handleMachineLogOut( ctx *gin.Context, machineKey key.MachinePublic, machine Machine, ) { resp := tailcfg.RegisterResponse{} log.Info(). Str("machine", machine.Name). Msg("Client requested logout") h.ExpireMachine(&machine) resp.AuthURL = "" resp.MachineAuthorized = false resp.User = *machine.Namespace.toUser() respBody, err := encode(resp, &machineKey, h.privateKey) if err != nil { log.Error(). Caller(). Err(err). Msg("Cannot encode message") ctx.String(http.StatusInternalServerError, "") return } ctx.Data(http.StatusOK, "application/json; charset=utf-8", respBody) } func (h *Headscale) handleMachineValidRegistration( ctx *gin.Context, machineKey key.MachinePublic, machine Machine, ) { resp := tailcfg.RegisterResponse{} // The machine registration is valid, respond with redirect to /map log.Debug(). Str("machine", machine.Name). Msg("Client is registered and we have the current NodeKey. All clear to /map") resp.AuthURL = "" resp.MachineAuthorized = true resp.User = *machine.Namespace.toUser() resp.Login = *machine.Namespace.toLogin() respBody, err := encode(resp, &machineKey, h.privateKey) if err != nil { log.Error(). Caller(). Err(err). Msg("Cannot encode message") machineRegistrations.WithLabelValues("update", "web", "error", machine.Namespace.Name). Inc() ctx.String(http.StatusInternalServerError, "") return } machineRegistrations.WithLabelValues("update", "web", "success", machine.Namespace.Name). Inc() ctx.Data(http.StatusOK, "application/json; charset=utf-8", respBody) } func (h *Headscale) handleMachineExpired( ctx *gin.Context, machineKey key.MachinePublic, registerRequest tailcfg.RegisterRequest, machine Machine, ) { resp := tailcfg.RegisterResponse{} // The client has registered before, but has expired log.Debug(). Str("machine", machine.Name). Msg("Machine registration has expired. Sending a authurl to register") if registerRequest.Auth.AuthKey != "" { h.handleAuthKey(ctx, machineKey, registerRequest, machine) return } if h.cfg.OIDC.Issuer != "" { resp.AuthURL = fmt.Sprintf("%s/oidc/register/%s", strings.TrimSuffix(h.cfg.ServerURL, "/"), machineKey.String()) } else { resp.AuthURL = fmt.Sprintf("%s/register?key=%s", strings.TrimSuffix(h.cfg.ServerURL, "/"), machineKey.String()) } respBody, err := encode(resp, &machineKey, h.privateKey) if err != nil { log.Error(). Caller(). Err(err). Msg("Cannot encode message") machineRegistrations.WithLabelValues("reauth", "web", "error", machine.Namespace.Name). Inc() ctx.String(http.StatusInternalServerError, "") return } machineRegistrations.WithLabelValues("reauth", "web", "success", machine.Namespace.Name). Inc() ctx.Data(http.StatusOK, "application/json; charset=utf-8", respBody) } func (h *Headscale) handleMachineRefreshKey( ctx *gin.Context, machineKey key.MachinePublic, registerRequest tailcfg.RegisterRequest, machine Machine, ) { resp := tailcfg.RegisterResponse{} log.Debug(). Str("machine", machine.Name). Msg("We have the OldNodeKey in the database. This is a key refresh") machine.NodeKey = NodePublicKeyStripPrefix(registerRequest.NodeKey) h.db.Save(&machine) resp.AuthURL = "" resp.User = *machine.Namespace.toUser() respBody, err := encode(resp, &machineKey, h.privateKey) if err != nil { log.Error(). Caller(). Err(err). Msg("Cannot encode message") ctx.String(http.StatusInternalServerError, "Extremely sad!") return } ctx.Data(http.StatusOK, "application/json; charset=utf-8", respBody) } func (h *Headscale) handleMachineRegistrationNew( ctx *gin.Context, machineKey key.MachinePublic, registerRequest tailcfg.RegisterRequest, machine Machine, ) { resp := tailcfg.RegisterResponse{} // The machine registration is new, redirect the client to the registration URL log.Debug(). Str("machine", machine.Name). Msg("The node is sending us a new NodeKey, sending auth url") if h.cfg.OIDC.Issuer != "" { resp.AuthURL = fmt.Sprintf( "%s/oidc/register/%s", strings.TrimSuffix(h.cfg.ServerURL, "/"), machineKey.String(), ) } else { resp.AuthURL = fmt.Sprintf("%s/register?key=%s", strings.TrimSuffix(h.cfg.ServerURL, "/"), MachinePublicKeyStripPrefix(machineKey)) } if !registerRequest.Expiry.IsZero() { log.Trace(). Caller(). Str("machine", machine.Name). Time("expiry", registerRequest.Expiry). Msg("Non-zero expiry time requested, adding to cache") h.requestedExpiryCache.Set( machineKey.String(), registerRequest.Expiry, requestedExpiryCacheExpiration, ) } machine.NodeKey = NodePublicKeyStripPrefix(registerRequest.NodeKey) // save the NodeKey h.db.Save(&machine) respBody, err := encode(resp, &machineKey, h.privateKey) if err != nil { log.Error(). Caller(). Err(err). Msg("Cannot encode message") ctx.String(http.StatusInternalServerError, "") return } ctx.Data(http.StatusOK, "application/json; charset=utf-8", respBody) } func (h *Headscale) handleAuthKey( ctx *gin.Context, machineKey key.MachinePublic, registerRequest tailcfg.RegisterRequest, machine Machine, ) { log.Debug(). Str("func", "handleAuthKey"). Str("machine", registerRequest.Hostinfo.Hostname). Msgf("Processing auth key for %s", registerRequest.Hostinfo.Hostname) resp := tailcfg.RegisterResponse{} pak, err := h.checkKeyValidity(registerRequest.Auth.AuthKey) if err != nil { log.Error(). Caller(). Str("func", "handleAuthKey"). Str("machine", machine.Name). Err(err). Msg("Failed authentication via AuthKey") resp.MachineAuthorized = false respBody, err := encode(resp, &machineKey, h.privateKey) if err != nil { log.Error(). Caller(). Str("func", "handleAuthKey"). Str("machine", machine.Name). Err(err). Msg("Cannot encode message") ctx.String(http.StatusInternalServerError, "") machineRegistrations.WithLabelValues("new", "authkey", "error", machine.Namespace.Name). Inc() return } ctx.Data(http.StatusUnauthorized, "application/json; charset=utf-8", respBody) log.Error(). Caller(). Str("func", "handleAuthKey"). Str("machine", machine.Name). Msg("Failed authentication via AuthKey") machineRegistrations.WithLabelValues("new", "authkey", "error", machine.Namespace.Name). Inc() return } if machine.isRegistered() { log.Trace(). Caller(). Str("machine", machine.Name). Msg("machine already registered, reauthenticating") h.RefreshMachine(&machine, registerRequest.Expiry) } else { log.Debug(). Str("func", "handleAuthKey"). Str("machine", machine.Name). Msg("Authentication key was valid, proceeding to acquire an IP address") ip, err := h.getAvailableIP() if err != nil { log.Error(). Caller(). Str("func", "handleAuthKey"). Str("machine", machine.Name). Msg("Failed to find an available IP") machineRegistrations.WithLabelValues("new", "authkey", "error", machine.Namespace.Name). Inc() return } log.Info(). Str("func", "handleAuthKey"). Str("machine", machine.Name). Str("ip", ip.String()). Msgf("Assigning %s to %s", ip, machine.Name) machine.Expiry = ®isterRequest.Expiry machine.AuthKeyID = uint(pak.ID) machine.IPAddress = ip.String() machine.NamespaceID = pak.NamespaceID machine.NodeKey = NodePublicKeyStripPrefix(registerRequest.NodeKey) // we update it just in case machine.Registered = true machine.RegisterMethod = RegisterMethodAuthKey h.db.Save(&machine) } pak.Used = true h.db.Save(&pak) resp.MachineAuthorized = true resp.User = *pak.Namespace.toUser() respBody, err := encode(resp, &machineKey, h.privateKey) if err != nil { log.Error(). Caller(). Str("func", "handleAuthKey"). Str("machine", machine.Name). Err(err). Msg("Cannot encode message") machineRegistrations.WithLabelValues("new", "authkey", "error", machine.Namespace.Name). Inc() ctx.String(http.StatusInternalServerError, "Extremely sad!") return } machineRegistrations.WithLabelValues("new", "authkey", "success", machine.Namespace.Name). Inc() ctx.Data(http.StatusOK, "application/json; charset=utf-8", respBody) log.Info(). Str("func", "handleAuthKey"). Str("machine", machine.Name). Str("ip", machine.IPAddress). Msg("Successfully authenticated via AuthKey") }