Terminate tls immediatly, mux after

This commit is contained in:
Kristoffer Dalby 2022-02-12 13:25:27 +00:00
parent 537cd35cb2
commit 8853ccd5b4

176
app.go
View File

@ -35,7 +35,7 @@ import (
"golang.org/x/sync/errgroup"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/credentials"
// "google.golang.org/grpc/credentials"
"google.golang.org/grpc/metadata"
"google.golang.org/grpc/peer"
"google.golang.org/grpc/reflection"
@ -418,14 +418,65 @@ func (h *Headscale) ensureUnixSocketIsAbsent() error {
return os.Remove(h.cfg.UnixSocket)
}
func (h *Headscale) createRouter(grpcMux *runtime.ServeMux) *gin.Engine {
router := gin.Default()
prometheus := ginprometheus.NewPrometheus("gin")
prometheus.Use(router)
router.GET(
"/health",
func(c *gin.Context) { c.JSON(http.StatusOK, gin.H{"healthy": "ok"}) },
)
router.GET("/key", h.KeyHandler)
router.GET("/register", h.RegisterWebAPI)
router.POST("/machine/:id/map", h.PollNetMapHandler)
router.POST("/machine/:id", h.RegistrationHandler)
router.GET("/oidc/register/:mkey", h.RegisterOIDC)
router.GET("/oidc/callback", h.OIDCCallback)
router.GET("/apple", h.AppleMobileConfig)
router.GET("/apple/:platform", h.ApplePlatformConfig)
router.GET("/swagger", SwaggerUI)
router.GET("/swagger/v1/openapiv2.json", SwaggerAPIv1)
api := router.Group("/api")
api.Use(h.httpAuthenticationMiddleware)
{
api.Any("/v1/*any", gin.WrapF(grpcMux.ServeHTTP))
}
router.NoRoute(stdoutHandler)
return router
}
// Serve launches a GIN server with the Headscale API.
func (h *Headscale) Serve() error {
var err error
ctx := context.Background()
ctx, cancel := context.WithCancel(ctx)
// Fetch an initial DERP Map before we start serving
h.DERPMap = GetDERPMap(h.cfg.DERP)
defer cancel()
if h.cfg.DERP.AutoUpdate {
derpMapCancelChannel := make(chan struct{})
defer func() { derpMapCancelChannel <- struct{}{} }()
go h.scheduledDERPMapUpdateWorker(derpMapCancelChannel)
}
// I HATE THIS
go h.watchForKVUpdates(updateInterval)
go h.expireEphemeralNodes(updateInterval)
if zl.GlobalLevel() == zl.TraceLevel {
zerolog.RespLog = true
} else {
zerolog.RespLog = false
}
//
//
// Set up LOCAL listeners
//
err = h.ensureUnixSocketIsAbsent()
if err != nil {
@ -455,7 +506,35 @@ func (h *Headscale) Serve() error {
os.Exit(0)
}(sigc)
networkListener, err := net.Listen("tcp", h.cfg.Addr)
//
//
// Set up REMOTE listeners
//
ctx := context.Background()
ctx, cancel := context.WithCancel(ctx)
defer cancel()
var networkListener net.Listener
tlsConfig, err := h.getTLSSettings()
if err != nil {
log.Error().Err(err).Msg("Failed to set up TLS configuration")
return err
}
// if tlsConfig != nil {
// httpServer.TLSConfig = tlsConfig
//
// grpcOptions = append(grpcOptions, grpc.Creds(credentials.NewTLS(tlsConfig)))
// }
if tlsConfig != nil {
networkListener, err = tls.Listen("tcp", h.cfg.Addr, tlsConfig)
} else {
networkListener, err = net.Listen("tcp", h.cfg.Addr)
}
if err != nil {
return fmt.Errorf("failed to bind to TCP address: %w", err)
}
@ -463,6 +542,7 @@ func (h *Headscale) Serve() error {
// Create the cmux object that will multiplex 2 protocols on the same port.
// The two following listeners will be served on the same port below gracefully.
networkMutex := cmux.New(networkListener)
// Match gRPC requests here
grpcListener := networkMutex.MatchWithWriters(
cmux.HTTP2MatchHeaderFieldSendSettings("content-type", "application/grpc"),
@ -495,46 +575,7 @@ func (h *Headscale) Serve() error {
return err
}
router := gin.Default()
prometheus := ginprometheus.NewPrometheus("gin")
prometheus.Use(router)
router.GET(
"/health",
func(c *gin.Context) { c.JSON(http.StatusOK, gin.H{"healthy": "ok"}) },
)
router.GET("/key", h.KeyHandler)
router.GET("/register", h.RegisterWebAPI)
router.POST("/machine/:id/map", h.PollNetMapHandler)
router.POST("/machine/:id", h.RegistrationHandler)
router.GET("/oidc/register/:mkey", h.RegisterOIDC)
router.GET("/oidc/callback", h.OIDCCallback)
router.GET("/apple", h.AppleMobileConfig)
router.GET("/apple/:platform", h.ApplePlatformConfig)
router.GET("/swagger", SwaggerUI)
router.GET("/swagger/v1/openapiv2.json", SwaggerAPIv1)
api := router.Group("/api")
api.Use(h.httpAuthenticationMiddleware)
{
api.Any("/v1/*any", gin.WrapF(grpcGatewayMux.ServeHTTP))
}
router.NoRoute(stdoutHandler)
// Fetch an initial DERP Map before we start serving
h.DERPMap = GetDERPMap(h.cfg.DERP)
if h.cfg.DERP.AutoUpdate {
derpMapCancelChannel := make(chan struct{})
defer func() { derpMapCancelChannel <- struct{}{} }()
go h.scheduledDERPMapUpdateWorker(derpMapCancelChannel)
}
// I HATE THIS
go h.watchForKVUpdates(updateInterval)
go h.expireEphemeralNodes(updateInterval)
router := h.createRouter(grpcGatewayMux)
httpServer := &http.Server{
Addr: h.cfg.Addr,
@ -547,12 +588,6 @@ func (h *Headscale) Serve() error {
WriteTimeout: 0,
}
if zl.GlobalLevel() == zl.TraceLevel {
zerolog.RespLog = true
} else {
zerolog.RespLog = false
}
grpcOptions := []grpc.ServerOption{
grpc.UnaryInterceptor(
grpc_middleware.ChainUnaryServer(
@ -562,23 +597,6 @@ func (h *Headscale) Serve() error {
),
}
tlsConfig, err := h.getTLSSettings()
if err != nil {
log.Error().Err(err).Msg("Failed to set up TLS configuration")
return err
}
if tlsConfig != nil {
httpServer.TLSConfig = tlsConfig
// grpcOptions = append(grpcOptions, grpc.Creds(credentials.NewTLS(tlsConfig)))
grpcOptions = append(
grpcOptions,
grpc.Creds(credentials.NewServerTLSFromCert(&tlsConfig.Certificates[0])),
)
}
grpcServer := grpc.NewServer(grpcOptions...)
// Start the local gRPC server without TLS and without authentication
@ -592,22 +610,20 @@ func (h *Headscale) Serve() error {
errorGroup := new(errgroup.Group)
errorGroup.Go(func() error { return grpcSocket.Serve(socketListener) })
// TODO(kradalby): Verify if we need the same TLS setup for gRPC as HTTP
errorGroup.Go(func() error { return grpcServer.Serve(grpcListener) })
if tlsConfig != nil {
errorGroup.Go(func() error {
tlsl := tls.NewListener(httpListener, tlsConfig)
return httpServer.Serve(tlsl)
})
} else {
errorGroup.Go(func() error { return httpServer.Serve(httpListener) })
}
errorGroup.Go(func() error { return networkMutex.Serve() })
// if tlsConfig != nil {
// errorGroup.Go(func() error {
// tlsl := tls.NewListener(httpListener, tlsConfig)
//
// return httpServer.Serve(tlsl)
// })
// } else {
// errorGroup.Go(func() error { return httpServer.Serve(httpListener) })
// }
log.Info().
Msgf("listening and serving (multiplexed HTTP and gRPC) on: %s", h.cfg.Addr)