From 53e5c05b0a60a9cca9e5805e4c16325e45a6ee8d Mon Sep 17 00:00:00 2001
From: Juan Font Alonso <juanfontalonso@gmail.com>
Date: Mon, 20 Jun 2022 12:30:51 +0200
Subject: [PATCH] Remove gin from the poll handlers

---
 poll.go | 145 +++++++++++++++++++++++++++++---------------------------
 1 file changed, 74 insertions(+), 71 deletions(-)

diff --git a/poll.go b/poll.go
index 239f260b..1d215089 100644
--- a/poll.go
+++ b/poll.go
@@ -2,13 +2,14 @@ package headscale
 
 import (
 	"context"
+	"encoding/json"
 	"errors"
 	"fmt"
 	"io"
 	"net/http"
 	"time"
 
-	"github.com/gin-gonic/gin"
+	"github.com/gorilla/mux"
 	"github.com/rs/zerolog/log"
 	"gorm.io/gorm"
 	"tailscale.com/tailcfg"
@@ -33,13 +34,25 @@ const machineNameContextKey = contextKey("machineName")
 // only after their first request (marked with the ReadOnly field).
 //
 // At this moment the updates are sent in a quite horrendous way, but they kinda work.
-func (h *Headscale) PollNetMapHandler(ctx *gin.Context) {
+func (h *Headscale) PollNetMapHandler(
+	w http.ResponseWriter,
+	r *http.Request,
+) {
+	vars := mux.Vars(r)
+	machineKeyStr, ok := vars["mkey"]
+	if !ok || machineKeyStr == "" {
+		log.Error().
+			Str("handler", "PollNetMap").
+			Msg("No machine key in request")
+		http.Error(w, "No machine key in request", http.StatusBadRequest)
+
+		return
+	}
 	log.Trace().
 		Str("handler", "PollNetMap").
-		Str("id", ctx.Param("id")).
+		Str("id", machineKeyStr).
 		Msg("PollNetMapHandler called")
-	body, _ := io.ReadAll(ctx.Request.Body)
-	machineKeyStr := ctx.Param("id")
+	body, _ := io.ReadAll(r.Body)
 
 	var machineKey key.MachinePublic
 	err := machineKey.UnmarshalText([]byte(MachinePublicKeyEnsurePrefix(machineKeyStr)))
@@ -48,7 +61,8 @@ func (h *Headscale) PollNetMapHandler(ctx *gin.Context) {
 			Str("handler", "PollNetMap").
 			Err(err).
 			Msg("Cannot parse client key")
-		ctx.String(http.StatusBadRequest, "")
+
+		http.Error(w, "Cannot parse client key", http.StatusBadRequest)
 
 		return
 	}
@@ -59,7 +73,7 @@ func (h *Headscale) PollNetMapHandler(ctx *gin.Context) {
 			Str("handler", "PollNetMap").
 			Err(err).
 			Msg("Cannot decode message")
-		ctx.String(http.StatusBadRequest, "")
+		http.Error(w, "Cannot decode message", http.StatusBadRequest)
 
 		return
 	}
@@ -70,20 +84,21 @@ func (h *Headscale) PollNetMapHandler(ctx *gin.Context) {
 			log.Warn().
 				Str("handler", "PollNetMap").
 				Msgf("Ignoring request, cannot find machine with key %s", machineKey.String())
-			ctx.String(http.StatusUnauthorized, "")
+
+			http.Error(w, "", http.StatusUnauthorized)
 
 			return
 		}
 		log.Error().
 			Str("handler", "PollNetMap").
 			Msgf("Failed to fetch machine from the database with Machine key: %s", machineKey.String())
-		ctx.String(http.StatusInternalServerError, "")
+		http.Error(w, "", http.StatusInternalServerError)
 
 		return
 	}
 	log.Trace().
 		Str("handler", "PollNetMap").
-		Str("id", ctx.Param("id")).
+		Str("id", machineKeyStr).
 		Str("machine", machine.Hostname).
 		Msg("Found machine in database")
 
@@ -120,11 +135,11 @@ func (h *Headscale) PollNetMapHandler(ctx *gin.Context) {
 		if err != nil {
 			log.Error().
 				Str("handler", "PollNetMap").
-				Str("id", ctx.Param("id")).
+				Str("id", machineKeyStr).
 				Str("machine", machine.Hostname).
 				Err(err).
 				Msg("Failed to persist/update machine in the database")
-			ctx.String(http.StatusInternalServerError, ":(")
+			http.Error(w, "", http.StatusInternalServerError)
 
 			return
 		}
@@ -134,11 +149,11 @@ func (h *Headscale) PollNetMapHandler(ctx *gin.Context) {
 	if err != nil {
 		log.Error().
 			Str("handler", "PollNetMap").
-			Str("id", ctx.Param("id")).
+			Str("id", machineKeyStr).
 			Str("machine", machine.Hostname).
 			Err(err).
 			Msg("Failed to get Map response")
-		ctx.String(http.StatusInternalServerError, ":(")
+		http.Error(w, "", http.StatusInternalServerError)
 
 		return
 	}
@@ -150,7 +165,7 @@ func (h *Headscale) PollNetMapHandler(ctx *gin.Context) {
 	// Details on the protocol can be found in https://github.com/tailscale/tailscale/blob/main/tailcfg/tailcfg.go#L696
 	log.Debug().
 		Str("handler", "PollNetMap").
-		Str("id", ctx.Param("id")).
+		Str("id", machineKeyStr).
 		Str("machine", machine.Hostname).
 		Bool("readOnly", req.ReadOnly).
 		Bool("omitPeers", req.OmitPeers).
@@ -162,7 +177,10 @@ func (h *Headscale) PollNetMapHandler(ctx *gin.Context) {
 			Str("handler", "PollNetMap").
 			Str("machine", machine.Hostname).
 			Msg("Client is starting up. Probably interested in a DERP map")
-		ctx.Data(http.StatusOK, "application/json; charset=utf-8", data)
+
+		w.Header().Set("Content-Type", "application/json; charset=utf-8")
+		w.WriteHeader(http.StatusOK)
+		json.NewEncoder(w).Encode(data)
 
 		return
 	}
@@ -177,7 +195,7 @@ func (h *Headscale) PollNetMapHandler(ctx *gin.Context) {
 	// Only create update channel if it has not been created
 	log.Trace().
 		Str("handler", "PollNetMap").
-		Str("id", ctx.Param("id")).
+		Str("id", machineKeyStr).
 		Str("machine", machine.Hostname).
 		Msg("Loading or creating update channel")
 
@@ -194,8 +212,9 @@ func (h *Headscale) PollNetMapHandler(ctx *gin.Context) {
 			Str("handler", "PollNetMap").
 			Str("machine", machine.Hostname).
 			Msg("Client sent endpoint update and is ok with a response without peer list")
-		ctx.Data(http.StatusOK, "application/json; charset=utf-8", data)
-
+		w.Header().Set("Content-Type", "application/json; charset=utf-8")
+		w.WriteHeader(http.StatusOK)
+		json.NewEncoder(w).Encode(data)
 		// It sounds like we should update the nodes when we have received a endpoint update
 		// even tho the comments in the tailscale code dont explicitly say so.
 		updateRequestsFromNode.WithLabelValues(machine.Namespace.Name, machine.Hostname, "endpoint-update").
@@ -208,7 +227,7 @@ func (h *Headscale) PollNetMapHandler(ctx *gin.Context) {
 			Str("handler", "PollNetMap").
 			Str("machine", machine.Hostname).
 			Msg("Ignoring request, don't know how to handle it")
-		ctx.String(http.StatusBadRequest, "")
+		http.Error(w, "", http.StatusBadRequest)
 
 		return
 	}
@@ -232,7 +251,8 @@ func (h *Headscale) PollNetMapHandler(ctx *gin.Context) {
 	updateChan <- struct{}{}
 
 	h.PollNetMapStream(
-		ctx,
+		w,
+		r,
 		machine,
 		req,
 		machineKey,
@@ -242,7 +262,7 @@ func (h *Headscale) PollNetMapHandler(ctx *gin.Context) {
 	)
 	log.Trace().
 		Str("handler", "PollNetMap").
-		Str("id", ctx.Param("id")).
+		Str("id", machineKeyStr).
 		Str("machine", machine.Hostname).
 		Msg("Finished stream, closing PollNetMap session")
 }
@@ -251,7 +271,8 @@ func (h *Headscale) PollNetMapHandler(ctx *gin.Context) {
 // stream logic, ensuring we communicate updates and data
 // to the connected clients.
 func (h *Headscale) PollNetMapStream(
-	ctx *gin.Context,
+	w http.ResponseWriter,
+	r *http.Request,
 	machine *Machine,
 	mapRequest tailcfg.MapRequest,
 	machineKey key.MachinePublic,
@@ -259,41 +280,21 @@ func (h *Headscale) PollNetMapStream(
 	keepAliveChan chan []byte,
 	updateChan chan struct{},
 ) {
-	{
-		machine, err := h.GetMachineByMachineKey(machineKey)
-		if err != nil {
-			if errors.Is(err, gorm.ErrRecordNotFound) {
-				log.Warn().
-					Str("handler", "PollNetMap").
-					Msgf("Ignoring request, cannot find machine with key %s", machineKey.String())
-				ctx.String(http.StatusUnauthorized, "")
+	ctx := context.WithValue(context.Background(), machineNameContextKey, machine.Hostname)
 
-				return
-			}
-			log.Error().
-				Str("handler", "PollNetMap").
-				Msgf("Failed to fetch machine from the database with Machine key: %s", machineKey.String())
-			ctx.String(http.StatusInternalServerError, "")
+	ctx, cancel := context.WithCancel(ctx)
+	defer cancel()
 
-			return
-		}
+	go h.scheduledPollWorker(
+		ctx,
+		updateChan,
+		keepAliveChan,
+		machineKey,
+		mapRequest,
+		machine,
+	)
 
-		ctx := context.WithValue(ctx.Request.Context(), machineNameContextKey, machine.Hostname)
-
-		ctx, cancel := context.WithCancel(ctx)
-		defer cancel()
-
-		go h.scheduledPollWorker(
-			ctx,
-			updateChan,
-			keepAliveChan,
-			machineKey,
-			mapRequest,
-			machine,
-		)
-	}
-
-	ctx.Stream(func(writer io.Writer) bool {
+	for {
 		log.Trace().
 			Str("handler", "PollNetMapStream").
 			Str("machine", machine.Hostname).
@@ -312,7 +313,7 @@ func (h *Headscale) PollNetMapStream(
 				Str("channel", "pollData").
 				Int("bytes", len(data)).
 				Msg("Sending data received via pollData channel")
-			_, err := writer.Write(data)
+			_, err := w.Write(data)
 			if err != nil {
 				log.Error().
 					Str("handler", "PollNetMapStream").
@@ -321,7 +322,7 @@ func (h *Headscale) PollNetMapStream(
 					Err(err).
 					Msg("Cannot write data")
 
-				return false
+				break
 			}
 			log.Trace().
 				Str("handler", "PollNetMapStream").
@@ -343,7 +344,7 @@ func (h *Headscale) PollNetMapStream(
 
 				// client has been removed from database
 				// since the stream opened, terminate connection.
-				return false
+				break
 			}
 			now := time.Now().UTC()
 			machine.LastSeen = &now
@@ -369,7 +370,7 @@ func (h *Headscale) PollNetMapStream(
 					Msg("Machine entry in database updated successfully after sending pollData")
 			}
 
-			return true
+			break
 
 		case data := <-keepAliveChan:
 			log.Trace().
@@ -378,7 +379,7 @@ func (h *Headscale) PollNetMapStream(
 				Str("channel", "keepAlive").
 				Int("bytes", len(data)).
 				Msg("Sending keep alive message")
-			_, err := writer.Write(data)
+			_, err := w.Write(data)
 			if err != nil {
 				log.Error().
 					Str("handler", "PollNetMapStream").
@@ -387,7 +388,7 @@ func (h *Headscale) PollNetMapStream(
 					Err(err).
 					Msg("Cannot write keep alive message")
 
-				return false
+				break
 			}
 			log.Trace().
 				Str("handler", "PollNetMapStream").
@@ -409,7 +410,7 @@ func (h *Headscale) PollNetMapStream(
 
 				// client has been removed from database
 				// since the stream opened, terminate connection.
-				return false
+				break
 			}
 			now := time.Now().UTC()
 			machine.LastSeen = &now
@@ -430,7 +431,7 @@ func (h *Headscale) PollNetMapStream(
 					Msg("Machine updated successfully after sending keep alive")
 			}
 
-			return true
+			break
 
 		case <-updateChan:
 			log.Trace().
@@ -460,7 +461,7 @@ func (h *Headscale) PollNetMapStream(
 						Err(err).
 						Msg("Could not get the map update")
 				}
-				_, err = writer.Write(data)
+				_, err = w.Write(data)
 				if err != nil {
 					log.Error().
 						Str("handler", "PollNetMapStream").
@@ -471,7 +472,7 @@ func (h *Headscale) PollNetMapStream(
 					updateRequestsSentToNode.WithLabelValues(machine.Namespace.Name, machine.Hostname, "failed").
 						Inc()
 
-					return false
+					return
 				}
 				log.Trace().
 					Str("handler", "PollNetMapStream").
@@ -499,7 +500,7 @@ func (h *Headscale) PollNetMapStream(
 
 					// client has been removed from database
 					// since the stream opened, terminate connection.
-					return false
+					return
 				}
 				now := time.Now().UTC()
 
@@ -529,9 +530,9 @@ func (h *Headscale) PollNetMapStream(
 					Msgf("%s is up to date", machine.Hostname)
 			}
 
-			return true
+			return
 
-		case <-ctx.Request.Context().Done():
+		case <-ctx.Done():
 			log.Info().
 				Str("handler", "PollNetMapStream").
 				Str("machine", machine.Hostname).
@@ -550,7 +551,7 @@ func (h *Headscale) PollNetMapStream(
 
 				// client has been removed from database
 				// since the stream opened, terminate connection.
-				return false
+				break
 			}
 			now := time.Now().UTC()
 			machine.LastSeen = &now
@@ -564,9 +565,11 @@ func (h *Headscale) PollNetMapStream(
 					Msg("Cannot update machine LastSeen")
 			}
 
-			return false
+			break
 		}
-	})
+	}
+
+	log.Info().Msgf("Closing poll loop to %s", machine.Hostname)
 }
 
 func (h *Headscale) scheduledPollWorker(