From 30a2ccd9758c96efa42e2f9691f151be480ec5a6 Mon Sep 17 00:00:00 2001
From: Kristoffer Dalby <kradalby@kradalby.no>
Date: Sat, 12 Feb 2022 17:05:30 +0000
Subject: [PATCH] Add tls certs as creds for grpc

---
 app.go | 116 +++++++++++++++++++++++++++++----------------------------
 1 file changed, 59 insertions(+), 57 deletions(-)

diff --git a/app.go b/app.go
index 987e64e4..8d228b48 100644
--- a/app.go
+++ b/app.go
@@ -34,6 +34,8 @@ import (
 	"golang.org/x/sync/errgroup"
 	"google.golang.org/grpc"
 	"google.golang.org/grpc/codes"
+	"google.golang.org/grpc/credentials"
+	"google.golang.org/grpc/internal/credentials"
 
 	// "google.golang.org/grpc/credentials"
 	"google.golang.org/grpc/metadata"
@@ -474,6 +476,13 @@ func (h *Headscale) Serve() error {
 		zerolog.RespLog = false
 	}
 
+	// Prepare group for running listeners
+	errorGroup := new(errgroup.Group)
+
+	ctx := context.Background()
+	ctx, cancel := context.WithCancel(ctx)
+	defer cancel()
+
 	//
 	//
 	// Set up LOCAL listeners
@@ -507,39 +516,6 @@ func (h *Headscale) Serve() error {
 		os.Exit(0)
 	}(sigc)
 
-	//
-	//
-	// Set up REMOTE listeners
-	//
-	ctx := context.Background()
-	ctx, cancel := context.WithCancel(ctx)
-
-	defer cancel()
-
-	tlsConfig, err := h.getTLSSettings()
-	if err != nil {
-		log.Error().Err(err).Msg("Failed to set up TLS configuration")
-
-		return err
-	}
-
-	// var httpListener net.Listener
-	//
-	// if tlsConfig != nil {
-	// 	httpListener, err = tls.Listen("tcp", h.cfg.Addr, tlsConfig)
-	// } else {
-	// 	httpListener, err = net.Listen("tcp", h.cfg.Addr)
-	// }
-	// if err != nil {
-	// 	return fmt.Errorf("failed to bind to TCP address: %w", err)
-	// }
-	//
-
-	//
-	//
-	// gRPC setup
-	//
-
 	grpcGatewayMux := runtime.NewServeMux()
 
 	// Make the grpc-gateway connect to grpc over socket
@@ -561,33 +537,63 @@ func (h *Headscale) Serve() error {
 		return err
 	}
 
-	grpcOptions := []grpc.ServerOption{
-		grpc.UnaryInterceptor(
-			grpc_middleware.ChainUnaryServer(
-				h.grpcAuthenticationInterceptor,
-				zerolog.NewUnaryServerInterceptor(),
-			),
-		),
-	}
-
-	grpcServer := grpc.NewServer(grpcOptions...)
-
 	// Start the local gRPC server without TLS and without authentication
 	grpcSocket := grpc.NewServer(zerolog.UnaryInterceptor())
 
-	v1.RegisterHeadscaleServiceServer(grpcServer, newHeadscaleV1APIServer(h))
 	v1.RegisterHeadscaleServiceServer(grpcSocket, newHeadscaleV1APIServer(h))
-	reflection.Register(grpcServer)
 	reflection.Register(grpcSocket)
 
-	var grpcListener net.Listener
-	if tlsConfig != nil {
-		grpcListener, err = tls.Listen("tcp", h.cfg.GRPCAddr, tlsConfig)
-	} else {
-		grpcListener, err = net.Listen("tcp", h.cfg.GRPCAddr)
-	}
+	errorGroup.Go(func() error { return grpcSocket.Serve(socketListener) })
+
+	//
+	//
+	// Set up REMOTE listeners
+	//
+
+	tlsConfig, err := h.getTLSSettings()
 	if err != nil {
-		return fmt.Errorf("failed to bind to TCP address: %w", err)
+		log.Error().Err(err).Msg("Failed to set up TLS configuration")
+
+		return err
+	}
+
+	//
+	//
+	// gRPC setup
+	//
+
+	// If TLS has been enabled, set up the remote gRPC server
+	if tlsConfig != nil {
+		log.Info().Msgf("Enabling remote gRPC at %s", h.cfg.GRPCAddr)
+
+		grpcOptions := []grpc.ServerOption{
+			grpc.UnaryInterceptor(
+				grpc_middleware.ChainUnaryServer(
+					h.grpcAuthenticationInterceptor,
+					zerolog.NewUnaryServerInterceptor(),
+				),
+			),
+			grpc.Creds(credentials.NewTLS(tlsConfig)),
+		}
+
+		grpcServer := grpc.NewServer(grpcOptions...)
+
+		v1.RegisterHeadscaleServiceServer(grpcServer, newHeadscaleV1APIServer(h))
+		reflection.Register(grpcServer)
+
+		var grpcListener net.Listener
+		// if tlsConfig != nil {
+		// 	grpcListener, err = tls.Listen("tcp", h.cfg.GRPCAddr, tlsConfig)
+		// } else {
+		grpcListener, err = net.Listen("tcp", h.cfg.GRPCAddr)
+		// }
+		if err != nil {
+			return fmt.Errorf("failed to bind to TCP address: %w", err)
+		}
+
+		errorGroup.Go(func() error { return grpcServer.Serve(grpcListener) })
+	} else {
+		log.Info().Msg("TLS is not configured, not enabling remote gRPC")
 	}
 
 	//
@@ -619,10 +625,6 @@ func (h *Headscale) Serve() error {
 		return fmt.Errorf("failed to bind to TCP address: %w", err)
 	}
 
-	errorGroup := new(errgroup.Group)
-
-	errorGroup.Go(func() error { return grpcSocket.Serve(socketListener) })
-	errorGroup.Go(func() error { return grpcServer.Serve(grpcListener) })
 	errorGroup.Go(func() error { return httpServer.Serve(httpListener) })
 
 	log.Info().