From d89fb68a7a717483d03ab18a90f1233a36775176 Mon Sep 17 00:00:00 2001 From: Juan Font Alonso Date: Sat, 18 Jun 2022 18:41:42 +0200 Subject: [PATCH] Switch to use gorilla's mux as muxer --- app.go | 137 +++++++++++++++++++++++++++---------------------- derp_server.go | 20 +++++--- go.mod | 1 + go.sum | 1 + 4 files changed, 92 insertions(+), 67 deletions(-) diff --git a/app.go b/app.go index cfed67aa..d8a4a609 100644 --- a/app.go +++ b/app.go @@ -18,6 +18,7 @@ import ( "github.com/coreos/go-oidc/v3/oidc" "github.com/gin-gonic/gin" + "github.com/gorilla/mux" grpc_middleware "github.com/grpc-ecosystem/go-grpc-middleware" "github.com/grpc-ecosystem/grpc-gateway/v2/runtime" v1 "github.com/juanfont/headscale/gen/go/headscale/v1" @@ -326,48 +327,56 @@ func (h *Headscale) grpcAuthenticationInterceptor(ctx context.Context, return handler(ctx, req) } -func (h *Headscale) httpAuthenticationMiddleware(ctx *gin.Context) { - log.Trace(). - Caller(). - Str("client_address", ctx.ClientIP()). - Msg("HTTP authentication invoked") - - authHeader := ctx.GetHeader("authorization") - - if !strings.HasPrefix(authHeader, AuthPrefix) { - log.Error(). +func (h *Headscale) httpAuthenticationMiddleware(next http.Handler) http.Handler { + return http.HandlerFunc(func( + w http.ResponseWriter, + r *http.Request, + ) { + log.Trace(). Caller(). - Str("client_address", ctx.ClientIP()). - Msg(`missing "Bearer " prefix in "Authorization" header`) - ctx.AbortWithStatus(http.StatusUnauthorized) + Str("client_address", r.RemoteAddr). + Msg("HTTP authentication invoked") - return - } + authHeader := r.Header.Get("X-Session-Token") - valid, err := h.ValidateAPIKey(strings.TrimPrefix(authHeader, AuthPrefix)) - if err != nil { - log.Error(). - Caller(). - Err(err). - Str("client_address", ctx.ClientIP()). - Msg("failed to validate token") + if !strings.HasPrefix(authHeader, AuthPrefix) { + log.Error(). + Caller(). + Str("client_address", r.RemoteAddr). + Msg(`missing "Bearer " prefix in "Authorization" header`) + w.WriteHeader(http.StatusUnauthorized) + w.Write([]byte("Unauthorized")) - ctx.AbortWithStatus(http.StatusInternalServerError) + return + } - return - } + valid, err := h.ValidateAPIKey(strings.TrimPrefix(authHeader, AuthPrefix)) + if err != nil { + log.Error(). + Caller(). + Err(err). + Str("client_address", r.RemoteAddr). + Msg("failed to validate token") - if !valid { - log.Info(). - Str("client_address", ctx.ClientIP()). - Msg("invalid token") + w.WriteHeader(http.StatusInternalServerError) + w.Write([]byte("Unauthorized")) - ctx.AbortWithStatus(http.StatusUnauthorized) + return + } - return - } + if !valid { + log.Info(). + Str("client_address", r.RemoteAddr). + Msg("invalid token") - ctx.Next() + w.WriteHeader(http.StatusUnauthorized) + w.Write([]byte("Unauthorized")) + + return + } + + next.ServeHTTP(w, r) + }) } // ensureUnixSocketIsAbsent will check if the given path for headscales unix socket is clear @@ -390,39 +399,42 @@ func (h *Headscale) createPrometheusRouter() *gin.Engine { return promRouter } -func (h *Headscale) createRouter(grpcMux *runtime.ServeMux) *gin.Engine { - router := gin.Default() +func (h *Headscale) createRouter(grpcMux *runtime.ServeMux) *mux.Router { + router := mux.NewRouter() - router.GET( + router.HandleFunc( "/health", - func(c *gin.Context) { c.JSON(http.StatusOK, gin.H{"healthy": "ok"}) }, - ) - router.GET("/key", gin.WrapF(h.KeyHandler)) - router.GET("/register", gin.WrapF(h.RegisterWebAPI)) - router.POST("/machine/:id/map", h.PollNetMapHandler) - router.POST("/machine/:id", h.RegistrationHandler) - router.GET("/oidc/register/:mkey", h.RegisterOIDC) - router.GET("/oidc/callback", gin.WrapF(h.OIDCCallback)) - router.GET("/apple", gin.WrapF(h.AppleConfigMessage)) - router.GET("/apple/:platform", gin.WrapF(h.ApplePlatformConfig)) - router.GET("/windows", gin.WrapF(h.WindowsConfigMessage)) - router.GET("/windows/tailscale.reg", gin.WrapF(h.WindowsRegConfig)) - router.GET("/swagger", gin.WrapF(SwaggerUI)) - router.GET("/swagger/v1/openapiv2.json", gin.WrapF(SwaggerAPIv1)) + func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + w.Write([]byte("{\"healthy\": \"ok\"}")) + }).Methods(http.MethodGet) + + router.HandleFunc("/key", h.KeyHandler).Methods(http.MethodGet) + router.HandleFunc("/register", h.RegisterWebAPI).Methods(http.MethodGet) + router.HandleFunc("/machine/:id/map", h.PollNetMapHandler).Methods(http.MethodPost) + router.HandleFunc("/machine/:id", h.RegistrationHandler).Methods(http.MethodPost) + router.HandleFunc("/oidc/register/:mkey", h.RegisterOIDC).Methods(http.MethodGet) + router.HandleFunc("/oidc/callback", h.OIDCCallback).Methods(http.MethodGet) + router.HandleFunc("/apple", h.AppleConfigMessage).Methods(http.MethodGet) + router.HandleFunc("/apple/:platform", h.ApplePlatformConfig).Methods(http.MethodGet) + router.HandleFunc("/windows", h.WindowsConfigMessage).Methods(http.MethodGet) + router.HandleFunc("/windows/tailscale.reg", h.WindowsRegConfig).Methods(http.MethodGet) + router.HandleFunc("/swagger", SwaggerUI).Methods(http.MethodGet) + router.HandleFunc("/swagger/v1/openapiv2.json", SwaggerAPIv1).Methods(http.MethodGet) if h.cfg.DERP.ServerEnabled { - router.Any("/derp", h.DERPHandler) - router.Any("/derp/probe", h.DERPProbeHandler) - router.Any("/bootstrap-dns", h.DERPBootstrapDNSHandler) + router.HandleFunc("/derp", h.DERPHandler) + router.HandleFunc("/derp/probe", h.DERPProbeHandler) + router.HandleFunc("/bootstrap-dns", h.DERPBootstrapDNSHandler) } - api := router.Group("/api") + api := router.PathPrefix("/api").Subrouter() api.Use(h.httpAuthenticationMiddleware) { - api.Any("/v1/*any", gin.WrapF(grpcMux.ServeHTTP)) + api.HandleFunc("/v1/*any", grpcMux.ServeHTTP) } - router.NoRoute(stdoutHandler) + router.PathPrefix("/").HandlerFunc(stdoutHandler) return router } @@ -811,13 +823,16 @@ func (h *Headscale) getLastStateChange(namespaces ...string) time.Time { } } -func stdoutHandler(ctx *gin.Context) { - body, _ := io.ReadAll(ctx.Request.Body) +func stdoutHandler( + w http.ResponseWriter, + r *http.Request, +) { + body, _ := io.ReadAll(r.Body) log.Trace(). - Interface("header", ctx.Request.Header). - Interface("proto", ctx.Request.Proto). - Interface("url", ctx.Request.URL). + Interface("header", r.Header). + Interface("proto", r.Proto). + Interface("url", r.URL). Bytes("body", body). Msg("Request did not match") } diff --git a/derp_server.go b/derp_server.go index d6fb47de..757dad56 100644 --- a/derp_server.go +++ b/derp_server.go @@ -10,7 +10,6 @@ import ( "strings" "time" - "github.com/gin-gonic/gin" "github.com/rs/zerolog/log" "tailscale.com/derp" "tailscale.com/net/stun" @@ -90,7 +89,10 @@ func (h *Headscale) generateRegionLocalDERP() (tailcfg.DERPRegion, error) { return localDERPregion, nil } -func (h *Headscale) DERPHandler(ctx *gin.Context) { +func (h *Headscale) DERPHandler( + w http.ResponseWriter, + r *http.Request, +) { log.Trace().Caller().Msgf("/derp request from %v", ctx.ClientIP()) up := strings.ToLower(ctx.Request.Header.Get("Upgrade")) if up != "websocket" && up != "derp" { @@ -143,7 +145,10 @@ func (h *Headscale) DERPHandler(ctx *gin.Context) { // DERPProbeHandler is the endpoint that js/wasm clients hit to measure // DERP latency, since they can't do UDP STUN queries. -func (h *Headscale) DERPProbeHandler(ctx *gin.Context) { +func (h *Headscale) DERPProbeHandler( + w http.ResponseWriter, + r *http.Request, +) { switch ctx.Request.Method { case "HEAD", "GET": ctx.Writer.Header().Set("Access-Control-Allow-Origin", "*") @@ -159,15 +164,18 @@ func (h *Headscale) DERPProbeHandler(ctx *gin.Context) { // The initial implementation is here https://github.com/tailscale/tailscale/pull/1406 // They have a cache, but not clear if that is really necessary at Headscale, uh, scale. // An example implementation is found here https://derp.tailscale.com/bootstrap-dns -func (h *Headscale) DERPBootstrapDNSHandler(ctx *gin.Context) { +func (h *Headscale) DERPBootstrapDNSHandler( + w http.ResponseWriter, + r *http.Request, +) { dnsEntries := make(map[string][]net.IP) resolvCtx, cancel := context.WithTimeout(context.Background(), time.Minute) defer cancel() - var r net.Resolver + var resolver net.Resolver for _, region := range h.DERPMap.Regions { for _, node := range region.Nodes { // we don't care if we override some nodes - addrs, err := r.LookupIP(resolvCtx, "ip", node.HostName) + addrs, err := resolver.LookupIP(resolvCtx, "ip", node.HostName) if err != nil { log.Trace(). Caller(). diff --git a/go.mod b/go.mod index 70662579..98cde7ef 100644 --- a/go.mod +++ b/go.mod @@ -73,6 +73,7 @@ require ( github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510 // indirect github.com/google/uuid v1.3.0 // indirect github.com/gookit/color v1.5.0 // indirect + github.com/gorilla/mux v1.8.0 // indirect github.com/hashicorp/go-version v1.4.0 // indirect github.com/hashicorp/hcl v1.0.0 // indirect github.com/imdario/mergo v0.3.12 // indirect diff --git a/go.sum b/go.sum index 9423737e..b4d03b90 100644 --- a/go.sum +++ b/go.sum @@ -403,6 +403,7 @@ github.com/gopherjs/gopherjs v0.0.0-20181017120253-0766667cb4d1/go.mod h1:wJfORR github.com/gordonklaus/ineffassign v0.0.0-20200309095847-7953dde2c7bf/go.mod h1:cuNKsD1zp2v6XfE/orVX2QE1LC+i254ceGcVeDT3pTU= github.com/gordonklaus/ineffassign v0.0.0-20210225214923-2e10b2664254/go.mod h1:M9mZEtGIsR1oDaZagNPNG9iq9n2HrhZ17dsXk73V3Lw= github.com/gorhill/cronexpr v0.0.0-20180427100037-88b0669f7d75/go.mod h1:g2644b03hfBX9Ov0ZBDgXXens4rxSxmqFBbhvKv2yVA= +github.com/gorilla/mux v1.8.0 h1:i40aqfkR1h2SlN9hojwV5ZA91wcXFOvkdNIeFDP5koI= github.com/gorilla/mux v1.8.0/go.mod h1:DVbg23sWSpFRCP0SfiEN6jmj59UnW/n46BH5rLB71So= github.com/gorilla/websocket v0.0.0-20170926233335-4201258b820c/go.mod h1:E7qHFY5m1UJ88s3WnNqhKjPHQ0heANvMoAMk2YaljkQ= github.com/gorilla/websocket v1.4.1/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=