diff --git a/cmd/yggdrasil/main.go b/cmd/yggdrasil/main.go index 11bda9f2..3d5eab97 100644 --- a/cmd/yggdrasil/main.go +++ b/cmd/yggdrasil/main.go @@ -14,6 +14,7 @@ import ( "os/signal" "regexp" "strings" + "sync" "syscall" "golang.org/x/text/encoding/unicode" @@ -183,8 +184,7 @@ func getArgs() yggArgs { } // The main function is responsible for configuring and starting Yggdrasil. -func run(args yggArgs, ctx context.Context, done chan struct{}) { - defer close(done) +func run(args yggArgs, ctx context.Context) { // Create a new logger that logs output to stdout. var logger *log.Logger switch args.logto { @@ -371,14 +371,11 @@ func run(args yggArgs, ctx context.Context, done chan struct{}) { logger.Infof("Your public key is %s", hex.EncodeToString(public[:])) logger.Infof("Your IPv6 address is %s", address.String()) logger.Infof("Your IPv6 subnet is %s", subnet.String()) - // Catch interrupts from the operating system to exit gracefully. - <-ctx.Done() - // Capture the service being stopped on Windows. - minwinsvc.SetOnExit(n.shutdown) - n.shutdown() -} -func (n *node) shutdown() { + // Block until we are told to shut down. + <-ctx.Done() + + // Shut down the node. _ = n.admin.Stop() _ = n.multicast.Stop() _ = n.tun.Stop() @@ -387,24 +384,19 @@ func (n *node) shutdown() { func main() { args := getArgs() - hup := make(chan os.Signal, 1) - //signal.Notify(hup, os.Interrupt, syscall.SIGHUP) - term := make(chan os.Signal, 1) - signal.Notify(term, os.Interrupt, syscall.SIGTERM) - for { - done := make(chan struct{}) - ctx, cancel := context.WithCancel(context.Background()) - go run(args, ctx, done) - select { - case <-hup: - cancel() - <-done - case <-term: - cancel() - <-done - return - case <-done: - return - } - } + + // Catch interrupts from the operating system to exit gracefully. + ctx, cancel := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM) + + // Capture the service being stopped on Windows. + minwinsvc.SetOnExit(cancel) + + // Start the node, block and then wait for it to shut down. + var wg sync.WaitGroup + wg.Add(1) + go func() { + defer wg.Done() + run(args, ctx) + }() + wg.Wait() }