From d0898ecabce313c09647dd1594fa2be53a546e1b Mon Sep 17 00:00:00 2001 From: Juan Font Alonso Date: Sun, 14 Aug 2022 21:15:58 +0200 Subject: [PATCH] Move common parts of the protocol to dedicated file --- api.go | 281 --------------------------------------- protocol_common.go | 66 +++++++++ protocol_common_utils.go | 47 +++++++ 3 files changed, 113 insertions(+), 281 deletions(-) create mode 100644 protocol_common.go create mode 100644 protocol_common_utils.go diff --git a/api.go b/api.go index 5ac0d3d9..38e8e64b 100644 --- a/api.go +++ b/api.go @@ -4,19 +4,15 @@ import ( "bytes" "encoding/binary" "encoding/json" - "errors" "fmt" "html/template" - "io" "net/http" - "strconv" "strings" "time" "github.com/gorilla/mux" "github.com/klauspost/compress/zstd" "github.com/rs/zerolog/log" - "gorm.io/gorm" "tailscale.com/tailcfg" "tailscale.com/types/key" ) @@ -77,62 +73,6 @@ func (h *Headscale) HealthHandler( respond(nil) } -// KeyHandler provides the Headscale pub key -// Listens in /key. -func (h *Headscale) KeyHandler( - writer http.ResponseWriter, - req *http.Request, -) { - // New Tailscale clients send a 'v' parameter to indicate the CurrentCapabilityVersion - clientCapabilityStr := req.URL.Query().Get("v") - if clientCapabilityStr != "" { - clientCapabilityVersion, err := strconv.Atoi(clientCapabilityStr) - if err != nil { - writer.Header().Set("Content-Type", "text/plain; charset=utf-8") - writer.WriteHeader(http.StatusBadRequest) - _, err := writer.Write([]byte("Wrong params")) - if err != nil { - log.Error(). - Caller(). - Err(err). - Msg("Failed to write response") - } - - return - } - - if clientCapabilityVersion >= NoiseCapabilityVersion { - // Tailscale has a different key for the TS2021 protocol - resp := tailcfg.OverTLSPublicKeyResponse{ - LegacyPublicKey: h.privateKey.Public(), - PublicKey: h.noisePrivateKey.Public(), - } - writer.Header().Set("Content-Type", "application/json") - writer.WriteHeader(http.StatusOK) - err = json.NewEncoder(writer).Encode(resp) - if err != nil { - log.Error(). - Caller(). - Err(err). - Msg("Failed to write response") - } - - return - } - } - - // Old clients don't send a 'v' parameter, so we send the legacy public key - writer.Header().Set("Content-Type", "text/plain; charset=utf-8") - writer.WriteHeader(http.StatusOK) - _, err := writer.Write([]byte(MachinePublicKeyStripPrefix(h.privateKey.Public()))) - if err != nil { - log.Error(). - Caller(). - Err(err). - Msg("Failed to write response") - } -} - type registerWebAPITemplateConfig struct { Key string } @@ -211,191 +151,6 @@ func (h *Headscale) RegisterWebAPI( } } -// RegistrationHandler handles the actual registration process of a machine -// Endpoint /machine/:mkey. -func (h *Headscale) RegistrationHandler( - writer http.ResponseWriter, - req *http.Request, -) { - vars := mux.Vars(req) - machineKeyStr, ok := vars["mkey"] - if !ok || machineKeyStr == "" { - log.Error(). - Str("handler", "RegistrationHandler"). - Msg("No machine ID in request") - http.Error(writer, "No machine ID in request", http.StatusBadRequest) - - return - } - - body, _ := io.ReadAll(req.Body) - - var machineKey key.MachinePublic - err := machineKey.UnmarshalText([]byte(MachinePublicKeyEnsurePrefix(machineKeyStr))) - if err != nil { - log.Error(). - Caller(). - Err(err). - Msg("Cannot parse machine key") - machineRegistrations.WithLabelValues("unknown", "web", "error", "unknown").Inc() - http.Error(writer, "Cannot parse machine key", http.StatusBadRequest) - - return - } - registerRequest := tailcfg.RegisterRequest{} - err = decode(body, ®isterRequest, &machineKey, h.privateKey) - if err != nil { - log.Error(). - Caller(). - Err(err). - Msg("Cannot decode message") - machineRegistrations.WithLabelValues("unknown", "web", "error", "unknown").Inc() - http.Error(writer, "Cannot decode message", http.StatusBadRequest) - - return - } - - now := time.Now().UTC() - machine, err := h.GetMachineByMachineKey(machineKey) - if errors.Is(err, gorm.ErrRecordNotFound) { - machineKeyStr := MachinePublicKeyStripPrefix(machineKey) - - // If the machine has AuthKey set, handle registration via PreAuthKeys - if registerRequest.Auth.AuthKey != "" { - h.handleAuthKey(writer, req, machineKey, registerRequest) - - return - } - - // Check if the node is waiting for interactive login. - // - // TODO(juan): We could use this field to improve our protocol implementation, - // and hold the request until the client closes it, or the interactive - // login is completed (i.e., the user registers the machine). - // This is not implemented yet, as it is no strictly required. The only side-effect - // is that the client will hammer headscale with requests until it gets a - // successful RegisterResponse. - if registerRequest.Followup != "" { - if _, ok := h.registrationCache.Get(NodePublicKeyStripPrefix(registerRequest.NodeKey)); ok { - log.Debug(). - Caller(). - Str("machine", registerRequest.Hostinfo.Hostname). - Str("node_key", registerRequest.NodeKey.ShortString()). - Str("node_key_old", registerRequest.OldNodeKey.ShortString()). - Str("follow_up", registerRequest.Followup). - Msg("Machine is waiting for interactive login") - - ticker := time.NewTicker(registrationHoldoff) - select { - case <-req.Context().Done(): - return - case <-ticker.C: - h.handleMachineRegistrationNew(writer, req, machineKey, registerRequest) - - return - } - } - } - - log.Info(). - Caller(). - Str("machine", registerRequest.Hostinfo.Hostname). - Str("node_key", registerRequest.NodeKey.ShortString()). - Str("node_key_old", registerRequest.OldNodeKey.ShortString()). - Str("follow_up", registerRequest.Followup). - Msg("New machine not yet in the database") - - givenName, err := h.GenerateGivenName(registerRequest.Hostinfo.Hostname) - if err != nil { - log.Error(). - Caller(). - Str("func", "RegistrationHandler"). - Str("hostinfo.name", registerRequest.Hostinfo.Hostname). - Err(err) - - return - } - - // The machine did not have a key to authenticate, which means - // that we rely on a method that calls back some how (OpenID or CLI) - // We create the machine and then keep it around until a callback - // happens - newMachine := Machine{ - MachineKey: machineKeyStr, - Hostname: registerRequest.Hostinfo.Hostname, - GivenName: givenName, - NodeKey: NodePublicKeyStripPrefix(registerRequest.NodeKey), - LastSeen: &now, - Expiry: &time.Time{}, - } - - if !registerRequest.Expiry.IsZero() { - log.Trace(). - Caller(). - Str("machine", registerRequest.Hostinfo.Hostname). - Time("expiry", registerRequest.Expiry). - Msg("Non-zero expiry time requested") - newMachine.Expiry = ®isterRequest.Expiry - } - - h.registrationCache.Set( - newMachine.NodeKey, - newMachine, - registerCacheExpiration, - ) - - h.handleMachineRegistrationNew(writer, req, machineKey, registerRequest) - - return - } - - // The machine is already registered, so we need to pass through reauth or key update. - if machine != nil { - // If the NodeKey stored in headscale is the same as the key presented in a registration - // request, then we have a node that is either: - // - Trying to log out (sending a expiry in the past) - // - A valid, registered machine, looking for the node map - // - Expired machine wanting to reauthenticate - if machine.NodeKey == NodePublicKeyStripPrefix(registerRequest.NodeKey) { - // The client sends an Expiry in the past if the client is requesting to expire the key (aka logout) - // https://github.com/tailscale/tailscale/blob/main/tailcfg/tailcfg.go#L648 - if !registerRequest.Expiry.IsZero() && - registerRequest.Expiry.UTC().Before(now) { - h.handleMachineLogOut(writer, req, machineKey, *machine) - - return - } - - // If machine is not expired, and is register, we have a already accepted this machine, - // let it proceed with a valid registration - if !machine.isExpired() { - h.handleMachineValidRegistration(writer, req, machineKey, *machine) - - return - } - } - - // The NodeKey we have matches OldNodeKey, which means this is a refresh after a key expiration - if machine.NodeKey == NodePublicKeyStripPrefix(registerRequest.OldNodeKey) && - !machine.isExpired() { - h.handleMachineRefreshKey( - writer, - req, - machineKey, - registerRequest, - *machine, - ) - - return - } - - // The machine has expired - h.handleMachineExpired(writer, req, machineKey, registerRequest, *machine) - - return - } -} - func (h *Headscale) getLegacyMapResponseData( machineKey key.MachinePublic, mapRequest tailcfg.MapRequest, @@ -436,42 +191,6 @@ func (h *Headscale) getLegacyMapResponseData( return data, nil } -func (h *Headscale) getMapKeepAliveResponse( - machineKey key.MachinePublic, - mapRequest tailcfg.MapRequest, -) ([]byte, error) { - mapResponse := tailcfg.MapResponse{ - KeepAlive: true, - } - var respBody []byte - var err error - if mapRequest.Compress == ZstdCompression { - src, err := json.Marshal(mapResponse) - if err != nil { - log.Error(). - Caller(). - Str("func", "getMapKeepAliveResponse"). - Err(err). - Msg("Failed to marshal keepalive response for the client") - - return nil, err - } - encoder, _ := zstd.NewWriter(nil) - srcCompressed := encoder.EncodeAll(src, nil) - respBody = h.privateKey.SealTo(machineKey, srcCompressed) - } else { - respBody, err = encode(mapResponse, &machineKey, h.privateKey) - if err != nil { - return nil, err - } - } - data := make([]byte, reservedResponseHeaderSize) - binary.LittleEndian.PutUint32(data, uint32(len(respBody))) - data = append(data, respBody...) - - return data, nil -} - func (h *Headscale) handleMachineLogOut( writer http.ResponseWriter, req *http.Request, diff --git a/protocol_common.go b/protocol_common.go new file mode 100644 index 00000000..c8eab80e --- /dev/null +++ b/protocol_common.go @@ -0,0 +1,66 @@ +package headscale + +import ( + "encoding/json" + "net/http" + "strconv" + + "github.com/rs/zerolog/log" + "tailscale.com/tailcfg" +) + +// KeyHandler provides the Headscale pub key +// Listens in /key. +func (h *Headscale) KeyHandler( + writer http.ResponseWriter, + req *http.Request, +) { + // New Tailscale clients send a 'v' parameter to indicate the CurrentCapabilityVersion + clientCapabilityStr := req.URL.Query().Get("v") + if clientCapabilityStr != "" { + clientCapabilityVersion, err := strconv.Atoi(clientCapabilityStr) + if err != nil { + writer.Header().Set("Content-Type", "text/plain; charset=utf-8") + writer.WriteHeader(http.StatusBadRequest) + _, err := writer.Write([]byte("Wrong params")) + if err != nil { + log.Error(). + Caller(). + Err(err). + Msg("Failed to write response") + } + + return + } + + if clientCapabilityVersion >= NoiseCapabilityVersion { + // Tailscale has a different key for the TS2021 protocol + resp := tailcfg.OverTLSPublicKeyResponse{ + LegacyPublicKey: h.privateKey.Public(), + PublicKey: h.noisePrivateKey.Public(), + } + writer.Header().Set("Content-Type", "application/json") + writer.WriteHeader(http.StatusOK) + err = json.NewEncoder(writer).Encode(resp) + if err != nil { + log.Error(). + Caller(). + Err(err). + Msg("Failed to write response") + } + + return + } + } + + // Old clients don't send a 'v' parameter, so we send the legacy public key + writer.Header().Set("Content-Type", "text/plain; charset=utf-8") + writer.WriteHeader(http.StatusOK) + _, err := writer.Write([]byte(MachinePublicKeyStripPrefix(h.privateKey.Public()))) + if err != nil { + log.Error(). + Caller(). + Err(err). + Msg("Failed to write response") + } +} diff --git a/protocol_common_utils.go b/protocol_common_utils.go new file mode 100644 index 00000000..5939b6a4 --- /dev/null +++ b/protocol_common_utils.go @@ -0,0 +1,47 @@ +package headscale + +import ( + "encoding/binary" + "encoding/json" + + "github.com/klauspost/compress/zstd" + "github.com/rs/zerolog/log" + "tailscale.com/tailcfg" + "tailscale.com/types/key" +) + +func (h *Headscale) getMapKeepAliveResponse( + machineKey key.MachinePublic, + mapRequest tailcfg.MapRequest, +) ([]byte, error) { + mapResponse := tailcfg.MapResponse{ + KeepAlive: true, + } + var respBody []byte + var err error + if mapRequest.Compress == ZstdCompression { + src, err := json.Marshal(mapResponse) + if err != nil { + log.Error(). + Caller(). + Str("func", "getMapKeepAliveResponse"). + Err(err). + Msg("Failed to marshal keepalive response for the client") + + return nil, err + } + encoder, _ := zstd.NewWriter(nil) + srcCompressed := encoder.EncodeAll(src, nil) + respBody = h.privateKey.SealTo(machineKey, srcCompressed) + } else { + respBody, err = encode(mapResponse, &machineKey, h.privateKey) + if err != nil { + return nil, err + } + } + data := make([]byte, reservedResponseHeaderSize) + binary.LittleEndian.PutUint32(data, uint32(len(respBody))) + data = append(data, respBody...) + + return data, nil +}