Working on common codebase for poll, starting with legacy

This commit is contained in:
Juan Font Alonso 2022-08-14 22:57:03 +02:00
parent f4bab6b290
commit df8ecdb603
2 changed files with 121 additions and 101 deletions

View File

@ -2,17 +2,12 @@ package headscale
import ( import (
"context" "context"
"errors"
"fmt" "fmt"
"io"
"net/http" "net/http"
"time" "time"
"github.com/gorilla/mux"
"github.com/rs/zerolog/log" "github.com/rs/zerolog/log"
"gorm.io/gorm"
"tailscale.com/tailcfg" "tailscale.com/tailcfg"
"tailscale.com/types/key"
) )
const ( const (
@ -23,83 +18,13 @@ type contextKey string
const machineNameContextKey = contextKey("machineName") const machineNameContextKey = contextKey("machineName")
// PollNetMapHandler takes care of /machine/:id/map func (h *Headscale) handlePollCommon(
//
// This is the busiest endpoint, as it keeps the HTTP long poll that updates
// the clients when something in the network changes.
//
// The clients POST stuff like HostInfo and their Endpoints here, but
// only after their first request (marked with the ReadOnly field).
//
// At this moment the updates are sent in a quite horrendous way, but they kinda work.
func (h *Headscale) PollNetMapHandler(
writer http.ResponseWriter, writer http.ResponseWriter,
req *http.Request, req *http.Request,
machine *Machine,
mapRequest tailcfg.MapRequest,
isNoise bool,
) { ) {
vars := mux.Vars(req)
machineKeyStr, ok := vars["mkey"]
if !ok || machineKeyStr == "" {
log.Error().
Str("handler", "PollNetMap").
Msg("No machine key in request")
http.Error(writer, "No machine key in request", http.StatusBadRequest)
return
}
log.Trace().
Str("handler", "PollNetMap").
Str("id", machineKeyStr).
Msg("PollNetMapHandler called")
body, _ := io.ReadAll(req.Body)
var machineKey key.MachinePublic
err := machineKey.UnmarshalText([]byte(MachinePublicKeyEnsurePrefix(machineKeyStr)))
if err != nil {
log.Error().
Str("handler", "PollNetMap").
Err(err).
Msg("Cannot parse client key")
http.Error(writer, "Cannot parse client key", http.StatusBadRequest)
return
}
mapRequest := tailcfg.MapRequest{}
err = decode(body, &mapRequest, &machineKey, h.privateKey)
if err != nil {
log.Error().
Str("handler", "PollNetMap").
Err(err).
Msg("Cannot decode message")
http.Error(writer, "Cannot decode message", http.StatusBadRequest)
return
}
machine, err := h.GetMachineByMachineKey(machineKey)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
log.Warn().
Str("handler", "PollNetMap").
Msgf("Ignoring request, cannot find machine with key %s", machineKey.String())
http.Error(writer, "", http.StatusUnauthorized)
return
}
log.Error().
Str("handler", "PollNetMap").
Msgf("Failed to fetch machine from the database with Machine key: %s", machineKey.String())
http.Error(writer, "", http.StatusInternalServerError)
return
}
log.Trace().
Str("handler", "PollNetMap").
Str("id", machineKeyStr).
Str("machine", machine.Hostname).
Msg("Found machine in database")
machine.Hostname = mapRequest.Hostinfo.Hostname machine.Hostname = mapRequest.Hostinfo.Hostname
machine.HostInfo = HostInfo(*mapRequest.Hostinfo) machine.HostInfo = HostInfo(*mapRequest.Hostinfo)
machine.DiscoKey = DiscoPublicKeyStripPrefix(mapRequest.DiscoKey) machine.DiscoKey = DiscoPublicKeyStripPrefix(mapRequest.DiscoKey)
@ -107,7 +32,7 @@ func (h *Headscale) PollNetMapHandler(
// update ACLRules with peer informations (to update server tags if necessary) // update ACLRules with peer informations (to update server tags if necessary)
if h.aclPolicy != nil { if h.aclPolicy != nil {
err = h.UpdateACLRules() err := h.UpdateACLRules()
if err != nil { if err != nil {
log.Error(). log.Error().
Caller(). Caller().
@ -133,7 +58,7 @@ func (h *Headscale) PollNetMapHandler(
if err != nil { if err != nil {
log.Error(). log.Error().
Str("handler", "PollNetMap"). Str("handler", "PollNetMap").
Str("id", machineKeyStr). Str("node_key", machine.NodeKey).
Str("machine", machine.Hostname). Str("machine", machine.Hostname).
Err(err). Err(err).
Msg("Failed to persist/update machine in the database") Msg("Failed to persist/update machine in the database")
@ -143,11 +68,11 @@ func (h *Headscale) PollNetMapHandler(
} }
} }
data, err := h.getLegacyMapResponseData(machineKey, mapRequest, machine) mapResp, err := h.getMapResponseData(mapRequest, machine, isNoise)
if err != nil { if err != nil {
log.Error(). log.Error().
Str("handler", "PollNetMap"). Str("handler", "PollNetMap").
Str("id", machineKeyStr). Str("node_key", machine.NodeKey).
Str("machine", machine.Hostname). Str("machine", machine.Hostname).
Err(err). Err(err).
Msg("Failed to get Map response") Msg("Failed to get Map response")
@ -163,7 +88,6 @@ func (h *Headscale) PollNetMapHandler(
// Details on the protocol can be found in https://github.com/tailscale/tailscale/blob/main/tailcfg/tailcfg.go#L696 // Details on the protocol can be found in https://github.com/tailscale/tailscale/blob/main/tailcfg/tailcfg.go#L696
log.Debug(). log.Debug().
Str("handler", "PollNetMap"). Str("handler", "PollNetMap").
Str("id", machineKeyStr).
Str("machine", machine.Hostname). Str("machine", machine.Hostname).
Bool("readOnly", mapRequest.ReadOnly). Bool("readOnly", mapRequest.ReadOnly).
Bool("omitPeers", mapRequest.OmitPeers). Bool("omitPeers", mapRequest.OmitPeers).
@ -178,7 +102,7 @@ func (h *Headscale) PollNetMapHandler(
writer.Header().Set("Content-Type", "application/json; charset=utf-8") writer.Header().Set("Content-Type", "application/json; charset=utf-8")
writer.WriteHeader(http.StatusOK) writer.WriteHeader(http.StatusOK)
_, err := writer.Write(data) _, err := writer.Write(mapResp)
if err != nil { if err != nil {
log.Error(). log.Error().
Caller(). Caller().
@ -186,6 +110,10 @@ func (h *Headscale) PollNetMapHandler(
Msg("Failed to write response") Msg("Failed to write response")
} }
if f, ok := writer.(http.Flusher); ok {
f.Flush()
}
return return
} }
@ -198,8 +126,7 @@ func (h *Headscale) PollNetMapHandler(
// Only create update channel if it has not been created // Only create update channel if it has not been created
log.Trace(). log.Trace().
Str("handler", "PollNetMap"). Caller().
Str("id", machineKeyStr).
Str("machine", machine.Hostname). Str("machine", machine.Hostname).
Msg("Loading or creating update channel") Msg("Loading or creating update channel")
@ -218,7 +145,7 @@ func (h *Headscale) PollNetMapHandler(
Msg("Client sent endpoint update and is ok with a response without peer list") Msg("Client sent endpoint update and is ok with a response without peer list")
writer.Header().Set("Content-Type", "application/json; charset=utf-8") writer.Header().Set("Content-Type", "application/json; charset=utf-8")
writer.WriteHeader(http.StatusOK) writer.WriteHeader(http.StatusOK)
_, err := writer.Write(data) _, err := writer.Write(mapResp)
if err != nil { if err != nil {
log.Error(). log.Error().
Caller(). Caller().
@ -250,7 +177,7 @@ func (h *Headscale) PollNetMapHandler(
Str("handler", "PollNetMap"). Str("handler", "PollNetMap").
Str("machine", machine.Hostname). Str("machine", machine.Hostname).
Msg("Sending initial map") Msg("Sending initial map")
pollDataChan <- data pollDataChan <- mapResp
log.Info(). log.Info().
Str("handler", "PollNetMap"). Str("handler", "PollNetMap").
@ -260,35 +187,34 @@ func (h *Headscale) PollNetMapHandler(
Inc() Inc()
updateChan <- struct{}{} updateChan <- struct{}{}
h.PollNetMapStream( h.pollNetMapStream(
writer, writer,
req, req,
machine, machine,
mapRequest, mapRequest,
machineKey,
pollDataChan, pollDataChan,
keepAliveChan, keepAliveChan,
updateChan, updateChan,
isNoise,
) )
log.Trace(). log.Trace().
Str("handler", "PollNetMap"). Str("handler", "PollNetMap").
Str("id", machineKeyStr).
Str("machine", machine.Hostname). Str("machine", machine.Hostname).
Msg("Finished stream, closing PollNetMap session") Msg("Finished stream, closing PollNetMap session")
} }
// PollNetMapStream takes care of /machine/:id/map // pollNetMapStream stream logic for /machine/map,
// stream logic, ensuring we communicate updates and data // ensuring we communicate updates and data to the connected clients.
// to the connected clients. func (h *Headscale) pollNetMapStream(
func (h *Headscale) PollNetMapStream(
writer http.ResponseWriter, writer http.ResponseWriter,
req *http.Request, req *http.Request,
machine *Machine, machine *Machine,
mapRequest tailcfg.MapRequest, mapRequest tailcfg.MapRequest,
machineKey key.MachinePublic,
pollDataChan chan []byte, pollDataChan chan []byte,
keepAliveChan chan []byte, keepAliveChan chan []byte,
updateChan chan struct{}, updateChan chan struct{},
isNoise bool,
) { ) {
h.pollNetMapStreamWG.Add(1) h.pollNetMapStreamWG.Add(1)
defer h.pollNetMapStreamWG.Done() defer h.pollNetMapStreamWG.Done()
@ -302,9 +228,9 @@ func (h *Headscale) PollNetMapStream(
ctx, ctx,
updateChan, updateChan,
keepAliveChan, keepAliveChan,
machineKey,
mapRequest, mapRequest,
machine, machine,
isNoise,
) )
log.Trace(). log.Trace().
@ -491,7 +417,7 @@ func (h *Headscale) PollNetMapStream(
Time("last_successful_update", lastUpdate). Time("last_successful_update", lastUpdate).
Time("last_state_change", h.getLastStateChange(machine.Namespace.Name)). Time("last_state_change", h.getLastStateChange(machine.Namespace.Name)).
Msgf("There has been updates since the last successful update to %s", machine.Hostname) Msgf("There has been updates since the last successful update to %s", machine.Hostname)
data, err := h.getLegacyMapResponseData(machineKey, mapRequest, machine) data, err := h.getMapResponseData(mapRequest, machine, false)
if err != nil { if err != nil {
log.Error(). log.Error().
Str("handler", "PollNetMapStream"). Str("handler", "PollNetMapStream").
@ -637,9 +563,9 @@ func (h *Headscale) scheduledPollWorker(
ctx context.Context, ctx context.Context,
updateChan chan struct{}, updateChan chan struct{},
keepAliveChan chan []byte, keepAliveChan chan []byte,
machineKey key.MachinePublic,
mapRequest tailcfg.MapRequest, mapRequest tailcfg.MapRequest,
machine *Machine, machine *Machine,
isNoise bool,
) { ) {
keepAliveTicker := time.NewTicker(keepAliveInterval) keepAliveTicker := time.NewTicker(keepAliveInterval)
updateCheckerTicker := time.NewTicker(h.cfg.NodeUpdateCheckInterval) updateCheckerTicker := time.NewTicker(h.cfg.NodeUpdateCheckInterval)
@ -661,7 +587,7 @@ func (h *Headscale) scheduledPollWorker(
return return
case <-keepAliveTicker.C: case <-keepAliveTicker.C:
data, err := h.getMapKeepAliveResponse(machineKey, mapRequest) data, err := h.getMapKeepAliveResponseData(mapRequest, machine, isNoise)
if err != nil { if err != nil {
log.Error(). log.Error().
Str("func", "keepAlive"). Str("func", "keepAlive").

94
protocol_legacy_poll.go Normal file
View File

@ -0,0 +1,94 @@
package headscale
import (
"errors"
"io"
"net/http"
"github.com/gorilla/mux"
"github.com/rs/zerolog/log"
"gorm.io/gorm"
"tailscale.com/tailcfg"
"tailscale.com/types/key"
)
// PollNetMapHandler takes care of /machine/:id/map
//
// This is the busiest endpoint, as it keeps the HTTP long poll that updates
// the clients when something in the network changes.
//
// The clients POST stuff like HostInfo and their Endpoints here, but
// only after their first request (marked with the ReadOnly field).
//
// At this moment the updates are sent in a quite horrendous way, but they kinda work.
func (h *Headscale) PollNetMapHandler(
writer http.ResponseWriter,
req *http.Request,
) {
vars := mux.Vars(req)
machineKeyStr, ok := vars["mkey"]
if !ok || machineKeyStr == "" {
log.Error().
Str("handler", "PollNetMap").
Msg("No machine key in request")
http.Error(writer, "No machine key in request", http.StatusBadRequest)
return
}
log.Trace().
Str("handler", "PollNetMap").
Str("id", machineKeyStr).
Msg("PollNetMapHandler called")
body, _ := io.ReadAll(req.Body)
var machineKey key.MachinePublic
err := machineKey.UnmarshalText([]byte(MachinePublicKeyEnsurePrefix(machineKeyStr)))
if err != nil {
log.Error().
Str("handler", "PollNetMap").
Err(err).
Msg("Cannot parse client key")
http.Error(writer, "Cannot parse client key", http.StatusBadRequest)
return
}
mapRequest := tailcfg.MapRequest{}
err = decode(body, &mapRequest, &machineKey, h.privateKey)
if err != nil {
log.Error().
Str("handler", "PollNetMap").
Err(err).
Msg("Cannot decode message")
http.Error(writer, "Cannot decode message", http.StatusBadRequest)
return
}
machine, err := h.GetMachineByMachineKey(machineKey)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
log.Warn().
Str("handler", "PollNetMap").
Msgf("Ignoring request, cannot find machine with key %s", machineKey.String())
http.Error(writer, "", http.StatusUnauthorized)
return
}
log.Error().
Str("handler", "PollNetMap").
Msgf("Failed to fetch machine from the database with Machine key: %s", machineKey.String())
http.Error(writer, "", http.StatusInternalServerError)
return
}
log.Trace().
Str("handler", "PollNetMap").
Str("id", machineKeyStr).
Str("machine", machine.Hostname).
Msg("Found machine in database")
h.handlePollCommon(writer, req, machine, mapRequest, false)
}