diff --git a/cmd/yggdrasil/main.go b/cmd/yggdrasil/main.go index 55e6b261..84b0b11e 100644 --- a/cmd/yggdrasil/main.go +++ b/cmd/yggdrasil/main.go @@ -2,6 +2,7 @@ package main import ( "bytes" + "context" "crypto/ed25519" "encoding/hex" "encoding/json" @@ -41,15 +42,15 @@ type node struct { admin *admin.AdminSocket } -func readConfig(log *log.Logger, useconf *bool, useconffile *string, normaliseconf *bool) *config.NodeConfig { +func readConfig(log *log.Logger, useconf bool, useconffile string, normaliseconf bool) *config.NodeConfig { // Use a configuration file. If -useconf, the configuration will be read // from stdin. If -useconffile, the configuration will be read from the // filesystem. var conf []byte var err error - if *useconffile != "" { + if useconffile != "" { // Read the file from the filesystem - conf, err = ioutil.ReadFile(*useconffile) + conf, err = ioutil.ReadFile(useconffile) } else { // Read the file from stdin. conf, err = ioutil.ReadAll(os.Stdin) @@ -180,9 +181,21 @@ func setLogLevel(loglevel string, logger *log.Logger) { } } -// The main function is responsible for configuring and starting Yggdrasil. -func main() { - // Configure the command line parameters. +type yggArgs struct { + genconf bool + useconf bool + useconffile string + normaliseconf bool + confjson bool + autoconf bool + ver bool + logto string + getaddr bool + getsnet bool + loglevel string +} + +func getArgs() yggArgs { genconf := flag.Bool("genconf", false, "print a new config to stdout") useconf := flag.Bool("useconf", false, "read HJSON/JSON config from stdin") useconffile := flag.String("useconffile", "", "read HJSON/JSON config from specified file path") @@ -195,10 +208,43 @@ func main() { getsnet := flag.Bool("subnet", false, "returns the IPv6 subnet as derived from the supplied configuration") loglevel := flag.String("loglevel", "info", "loglevel to enable") flag.Parse() + return yggArgs{ + genconf: *genconf, + useconf: *useconf, + useconffile: *useconffile, + normaliseconf: *normaliseconf, + confjson: *confjson, + autoconf: *autoconf, + ver: *ver, + logto: *logto, + getaddr: *getaddr, + getsnet: *getsnet, + loglevel: *loglevel, + } +} + +// The main function is responsible for configuring and starting Yggdrasil. +func run(args yggArgs, ctx context.Context, done chan struct{}) { + defer close(done) + // Configure the command line parameters. + /* + genconf := flag.Bool("genconf", false, "print a new config to stdout") + useconf := flag.Bool("useconf", false, "read HJSON/JSON config from stdin") + useconffile := flag.String("useconffile", "", "read HJSON/JSON config from specified file path") + normaliseconf := flag.Bool("normaliseconf", false, "use in combination with either -useconf or -useconffile, outputs your configuration normalised") + confjson := flag.Bool("json", false, "print configuration from -genconf or -normaliseconf as JSON instead of HJSON") + autoconf := flag.Bool("autoconf", false, "automatic mode (dynamic IP, peer with IPv6 neighbors)") + ver := flag.Bool("version", false, "prints the version of this build") + logto := flag.String("logto", "stdout", "file path to log to, \"syslog\" or \"stdout\"") + getaddr := flag.Bool("address", false, "returns the IPv6 address as derived from the supplied configuration") + getsnet := flag.Bool("subnet", false, "returns the IPv6 subnet as derived from the supplied configuration") + loglevel := flag.String("loglevel", "info", "loglevel to enable") + flag.Parse() + */ // Create a new logger that logs output to stdout. var logger *log.Logger - switch *logto { + switch args.logto { case "stdout": logger = log.New(os.Stdout, "", log.Flags()) case "syslog": @@ -206,7 +252,7 @@ func main() { logger = log.New(syslogger, "", log.Flags()) } default: - if logfd, err := os.OpenFile(*logto, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644); err == nil { + if logfd, err := os.OpenFile(args.logto, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644); err == nil { logger = log.New(logfd, "", log.Flags()) } } @@ -215,33 +261,33 @@ func main() { logger.Warnln("Logging defaulting to stdout") } - if *normaliseconf { + if args.normaliseconf { setLogLevel("error", logger) } else { - setLogLevel(*loglevel, logger) + setLogLevel(args.loglevel, logger) } var cfg *config.NodeConfig var err error switch { - case *ver: + case args.ver: fmt.Println("Build name:", version.BuildName()) fmt.Println("Build version:", version.BuildVersion()) return - case *autoconf: + case args.autoconf: // Use an autoconf-generated config, this will give us random keys and // port numbers, and will use an automatically selected TUN/TAP interface. cfg = defaults.GenerateConfig() - case *useconffile != "" || *useconf: + case args.useconffile != "" || args.useconf: // Read the configuration from either stdin or from the filesystem - cfg = readConfig(logger, useconf, useconffile, normaliseconf) + cfg = readConfig(logger, args.useconf, args.useconffile, args.normaliseconf) // If the -normaliseconf option was specified then remarshal the above // configuration and print it back to stdout. This lets the user update // their configuration file with newly mapped names (like above) or to // convert from plain JSON to commented HJSON. - if *normaliseconf { + if args.normaliseconf { var bs []byte - if *confjson { + if args.confjson { bs, err = json.MarshalIndent(cfg, "", " ") } else { bs, err = hjson.Marshal(cfg) @@ -252,9 +298,9 @@ func main() { fmt.Println(string(bs)) return } - case *genconf: + case args.genconf: // Generate a new configuration and print it to stdout. - fmt.Println(doGenconf(*confjson)) + fmt.Println(doGenconf(args.confjson)) default: // No flags were provided, therefore print the list of flags to stdout. flag.PrintDefaults() @@ -273,14 +319,14 @@ func main() { return nil } switch { - case *getaddr: + case args.getaddr: if key := getNodeKey(); key != nil { addr := address.AddrForKey(key) ip := net.IP(addr[:]) fmt.Println(ip.String()) } return - case *getsnet: + case args.getsnet: if key := getNodeKey(); key != nil { snet := address.SubnetForKey(key) ipnet := net.IPNet{ @@ -337,10 +383,8 @@ func main() { logger.Infof("Your IPv6 address is %s", address.String()) logger.Infof("Your IPv6 subnet is %s", subnet.String()) // Catch interrupts from the operating system to exit gracefully. - c := make(chan os.Signal, 1) - signal.Notify(c, os.Interrupt, syscall.SIGTERM) + <-ctx.Done() // Capture the service being stopped on Windows. - <-c minwinsvc.SetOnExit(n.shutdown) n.shutdown() } @@ -351,3 +395,25 @@ func (n *node) shutdown() { _ = n.tuntap.Stop() n.core.Stop() } + +func main() { + args := getArgs() + hup := make(chan os.Signal, 1) + signal.Notify(hup, os.Interrupt, syscall.SIGHUP) + term := make(chan os.Signal, 1) + signal.Notify(term, os.Interrupt, syscall.SIGTERM) + for { + done := make(chan struct{}) + ctx, cancel := context.WithCancel(context.Background()) + go run(args, ctx, done) + select { + case <-hup: + cancel() + <-done + case <-term: + cancel() + <-done + return + } + } +}