Split yggdrasilctl code into separate functions (refactoring) (#815)

* Move yggdrasilctl responses to separate functions

* Move yggdrasilctl request switch to separate function

* Add empty lines

* Create struct CmdLine for yggdrasilctl

* Move yggdrasilctl command line parsing to separate func

* Turn struct CmdLine into CmdLineEnv

* Rename func parseCmdLine to parseFlagsAndArgs

* Move yggdrasilctl endpoint setting logic into separate func

* Function to create yggdrasilctl CmdLineEnv

* Reorder code

* Move struct fields into lines

* Turn yggdrasilctl CmdLineEnv funcs to methods

* Move yggdrasilctl connection code to separate func

* Rename functions

* Move yggdrasilctl command line env to separate mod

* Move yggdrasilctl command line env to main mod

* Run goimports

Co-authored-by: Neil Alexander <neilalexander@users.noreply.github.com>
This commit is contained in:
Alex Kotov 2021-08-03 02:47:38 +05:00 committed by GitHub
parent b333c7d7f3
commit cbb6dc1b7d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 422 additions and 323 deletions

View File

@ -0,0 +1,94 @@
package main
import (
"bytes"
"flag"
"fmt"
"io/ioutil"
"log"
"os"
"github.com/hjson/hjson-go"
"golang.org/x/text/encoding/unicode"
"github.com/yggdrasil-network/yggdrasil-go/src/defaults"
)
type CmdLineEnv struct {
args []string
endpoint, server string
injson, verbose, ver bool
}
func newCmdLineEnv() CmdLineEnv {
var cmdLineEnv CmdLineEnv
cmdLineEnv.endpoint = defaults.GetDefaults().DefaultAdminListen
return cmdLineEnv
}
func (cmdLineEnv *CmdLineEnv) parseFlagsAndArgs() {
flag.Usage = func() {
fmt.Fprintf(flag.CommandLine.Output(), "Usage: %s [options] command [key=value] [key=value] ...\n\n", os.Args[0])
fmt.Println("Options:")
flag.PrintDefaults()
fmt.Println()
fmt.Println("Please note that options must always specified BEFORE the command\non the command line or they will be ignored.")
fmt.Println()
fmt.Println("Commands:\n - Use \"list\" for a list of available commands")
fmt.Println()
fmt.Println("Examples:")
fmt.Println(" - ", os.Args[0], "list")
fmt.Println(" - ", os.Args[0], "getPeers")
fmt.Println(" - ", os.Args[0], "-v getSelf")
fmt.Println(" - ", os.Args[0], "setTunTap name=auto mtu=1500 tap_mode=false")
fmt.Println(" - ", os.Args[0], "-endpoint=tcp://localhost:9001 getDHT")
fmt.Println(" - ", os.Args[0], "-endpoint=unix:///var/run/ygg.sock getDHT")
}
server := flag.String("endpoint", cmdLineEnv.endpoint, "Admin socket endpoint")
injson := flag.Bool("json", false, "Output in JSON format (as opposed to pretty-print)")
verbose := flag.Bool("v", false, "Verbose output (includes public keys)")
ver := flag.Bool("version", false, "Prints the version of this build")
flag.Parse()
cmdLineEnv.args = flag.Args()
cmdLineEnv.server = *server
cmdLineEnv.injson = *injson
cmdLineEnv.verbose = *verbose
cmdLineEnv.ver = *ver
}
func (cmdLineEnv *CmdLineEnv) setEndpoint(logger *log.Logger) {
if cmdLineEnv.server == cmdLineEnv.endpoint {
if config, err := ioutil.ReadFile(defaults.GetDefaults().DefaultConfigFile); err == nil {
if bytes.Equal(config[0:2], []byte{0xFF, 0xFE}) ||
bytes.Equal(config[0:2], []byte{0xFE, 0xFF}) {
utf := unicode.UTF16(unicode.BigEndian, unicode.UseBOM)
decoder := utf.NewDecoder()
config, err = decoder.Bytes(config)
if err != nil {
panic(err)
}
}
var dat map[string]interface{}
if err := hjson.Unmarshal(config, &dat); err != nil {
panic(err)
}
if ep, ok := dat["AdminListen"].(string); ok && (ep != "none" && ep != "") {
cmdLineEnv.endpoint = ep
logger.Println("Found platform default config file", defaults.GetDefaults().DefaultConfigFile)
logger.Println("Using endpoint", cmdLineEnv.endpoint, "from AdminListen")
} else {
logger.Println("Configuration file doesn't contain appropriate AdminListen option")
logger.Println("Falling back to platform default", defaults.GetDefaults().DefaultAdminListen)
}
} else {
logger.Println("Can't open config file from default location", defaults.GetDefaults().DefaultConfigFile)
logger.Println("Falling back to platform default", defaults.GetDefaults().DefaultAdminListen)
}
} else {
cmdLineEnv.endpoint = cmdLineEnv.server
logger.Println("Using endpoint", cmdLineEnv.endpoint, "from command line")
}
}

View File

@ -6,7 +6,6 @@ import (
"errors"
"flag"
"fmt"
"io/ioutil"
"log"
"net"
"net/url"
@ -15,10 +14,6 @@ import (
"strconv"
"strings"
"golang.org/x/text/encoding/unicode"
"github.com/hjson/hjson-go"
"github.com/yggdrasil-network/yggdrasil-go/src/defaults"
"github.com/yggdrasil-network/yggdrasil-go/src/version"
)
@ -32,6 +27,7 @@ func main() {
func run() int {
logbuffer := &bytes.Buffer{}
logger := log.New(logbuffer, "", log.Flags())
defer func() int {
if r := recover(); r != nil {
logger.Println("Fatal error:", r)
@ -41,97 +37,24 @@ func run() int {
return 0
}()
endpoint := defaults.GetDefaults().DefaultAdminListen
cmdLineEnv := newCmdLineEnv()
cmdLineEnv.parseFlagsAndArgs()
flag.Usage = func() {
fmt.Fprintf(flag.CommandLine.Output(), "Usage: %s [options] command [key=value] [key=value] ...\n\n", os.Args[0])
fmt.Println("Options:")
flag.PrintDefaults()
fmt.Println()
fmt.Println("Please note that options must always specified BEFORE the command\non the command line or they will be ignored.")
fmt.Println()
fmt.Println("Commands:\n - Use \"list\" for a list of available commands")
fmt.Println()
fmt.Println("Examples:")
fmt.Println(" - ", os.Args[0], "list")
fmt.Println(" - ", os.Args[0], "getPeers")
fmt.Println(" - ", os.Args[0], "-v getSelf")
fmt.Println(" - ", os.Args[0], "setTunTap name=auto mtu=1500 tap_mode=false")
fmt.Println(" - ", os.Args[0], "-endpoint=tcp://localhost:9001 getDHT")
fmt.Println(" - ", os.Args[0], "-endpoint=unix:///var/run/ygg.sock getDHT")
}
server := flag.String("endpoint", endpoint, "Admin socket endpoint")
injson := flag.Bool("json", false, "Output in JSON format (as opposed to pretty-print)")
verbose := flag.Bool("v", false, "Verbose output (includes public keys)")
ver := flag.Bool("version", false, "Prints the version of this build")
flag.Parse()
args := flag.Args()
if *ver {
if cmdLineEnv.ver {
fmt.Println("Build name:", version.BuildName())
fmt.Println("Build version:", version.BuildVersion())
fmt.Println("To get the version number of the running Yggdrasil node, run", os.Args[0], "getSelf")
return 0
}
if len(args) == 0 {
if len(cmdLineEnv.args) == 0 {
flag.Usage()
return 0
}
if *server == endpoint {
if config, err := ioutil.ReadFile(defaults.GetDefaults().DefaultConfigFile); err == nil {
if bytes.Equal(config[0:2], []byte{0xFF, 0xFE}) ||
bytes.Equal(config[0:2], []byte{0xFE, 0xFF}) {
utf := unicode.UTF16(unicode.BigEndian, unicode.UseBOM)
decoder := utf.NewDecoder()
config, err = decoder.Bytes(config)
if err != nil {
panic(err)
}
}
var dat map[string]interface{}
if err := hjson.Unmarshal(config, &dat); err != nil {
panic(err)
}
if ep, ok := dat["AdminListen"].(string); ok && (ep != "none" && ep != "") {
endpoint = ep
logger.Println("Found platform default config file", defaults.GetDefaults().DefaultConfigFile)
logger.Println("Using endpoint", endpoint, "from AdminListen")
} else {
logger.Println("Configuration file doesn't contain appropriate AdminListen option")
logger.Println("Falling back to platform default", defaults.GetDefaults().DefaultAdminListen)
}
} else {
logger.Println("Can't open config file from default location", defaults.GetDefaults().DefaultConfigFile)
logger.Println("Falling back to platform default", defaults.GetDefaults().DefaultAdminListen)
}
} else {
endpoint = *server
logger.Println("Using endpoint", endpoint, "from command line")
}
cmdLineEnv.setEndpoint(logger)
var conn net.Conn
u, err := url.Parse(endpoint)
if err == nil {
switch strings.ToLower(u.Scheme) {
case "unix":
logger.Println("Connecting to UNIX socket", endpoint[7:])
conn, err = net.Dial("unix", endpoint[7:])
case "tcp":
logger.Println("Connecting to TCP socket", u.Host)
conn, err = net.Dial("tcp", u.Host)
default:
logger.Println("Unknown protocol or malformed address - check your endpoint")
err = errors.New("protocol not supported")
}
} else {
logger.Println("Connecting to TCP socket", u.Host)
conn, err = net.Dial("tcp", endpoint)
}
if err != nil {
panic(err)
}
conn := connect(cmdLineEnv.endpoint, logger)
logger.Println("Connected")
defer conn.Close()
@ -140,7 +63,7 @@ func run() int {
send := make(admin_info)
recv := make(admin_info)
for c, a := range args {
for c, a := range cmdLineEnv.args {
if c == 0 {
if strings.HasPrefix(a, "-") {
logger.Printf("Ignoring flag %s as it should be specified before other parameters\n", a)
@ -176,7 +99,9 @@ func run() int {
if err := encoder.Encode(&send); err != nil {
panic(err)
}
logger.Printf("Request sent")
if err := decoder.Decode(&recv); err == nil {
logger.Printf("Response received")
if recv["status"] == "error" {
@ -195,20 +120,97 @@ func run() int {
fmt.Println("Missing response body (malformed response?)")
return 1
}
req := recv["request"].(map[string]interface{})
res := recv["response"].(map[string]interface{})
if *injson {
if cmdLineEnv.injson {
if json, err := json.MarshalIndent(res, "", " "); err == nil {
fmt.Println(string(json))
}
return 0
}
handleAll(recv, cmdLineEnv.verbose)
} else {
logger.Println("Error receiving response:", err)
}
if v, ok := recv["status"]; ok && v != "success" {
return 1
}
return 0
}
func connect(endpoint string, logger *log.Logger) net.Conn {
var conn net.Conn
u, err := url.Parse(endpoint)
if err == nil {
switch strings.ToLower(u.Scheme) {
case "unix":
logger.Println("Connecting to UNIX socket", endpoint[7:])
conn, err = net.Dial("unix", endpoint[7:])
case "tcp":
logger.Println("Connecting to TCP socket", u.Host)
conn, err = net.Dial("tcp", u.Host)
default:
logger.Println("Unknown protocol or malformed address - check your endpoint")
err = errors.New("protocol not supported")
}
} else {
logger.Println("Connecting to TCP socket", u.Host)
conn, err = net.Dial("tcp", endpoint)
}
if err != nil {
panic(err)
}
return conn
}
func handleAll(recv map[string]interface{}, verbose bool) {
req := recv["request"].(map[string]interface{})
res := recv["response"].(map[string]interface{})
switch strings.ToLower(req["request"].(string)) {
case "dot":
fmt.Println(res["dot"])
handleDot(res)
case "list", "getpeers", "getswitchpeers", "getdht", "getsessions", "dhtping":
handleVariousInfo(res, verbose)
case "gettuntap", "settuntap":
handleGetAndSetTunTap(res)
case "getself":
handleGetSelf(res, verbose)
case "getswitchqueues":
handleGetSwitchQueues(res)
case "addpeer", "removepeer", "addallowedencryptionpublickey", "removeallowedencryptionpublickey", "addsourcesubnet", "addroute", "removesourcesubnet", "removeroute":
handleAddsAndRemoves(res)
case "getallowedencryptionpublickeys":
handleGetAllowedEncryptionPublicKeys(res)
case "getmulticastinterfaces":
handleGetMulticastInterfaces(res)
case "getsourcesubnets":
handleGetSourceSubnets(res)
case "getroutes":
handleGetRoutes(res)
case "settunnelrouting":
fallthrough
case "gettunnelrouting":
handleGetTunnelRouting(res)
default:
if json, err := json.MarshalIndent(recv["response"], "", " "); err == nil {
fmt.Println(string(json))
}
}
}
func handleDot(res map[string]interface{}) {
fmt.Println(res["dot"])
}
func handleVariousInfo(res map[string]interface{}, verbose bool) {
maxWidths := make(map[string]int)
var keyOrder []string
keysOrdered := false
@ -217,7 +219,7 @@ func run() int {
for slk, slv := range tlv.(map[string]interface{}) {
if !keysOrdered {
for k := range slv.(map[string]interface{}) {
if !*verbose {
if !verbose {
if k == "box_pub_key" || k == "box_sig_key" || k == "nodeinfo" || k == "was_mtu_fixed" {
continue
}
@ -269,7 +271,9 @@ func run() int {
fmt.Println()
}
}
case "gettuntap", "settuntap":
}
func handleGetAndSetTunTap(res map[string]interface{}) {
for k, v := range res {
fmt.Println("Interface name:", k)
if mtu, ok := v.(map[string]interface{})["mtu"].(float64); ok {
@ -279,7 +283,9 @@ func run() int {
fmt.Println("TAP mode:", tap_mode)
}
}
case "getself":
}
func handleGetSelf(res map[string]interface{}, verbose bool) {
for k, v := range res["self"].(map[string]interface{}) {
if buildname, ok := v.(map[string]interface{})["build_name"].(string); ok && buildname != "unknown" {
fmt.Println("Build name:", buildname)
@ -297,7 +303,7 @@ func run() int {
if coords, ok := v.(map[string]interface{})["coords"].(string); ok {
fmt.Println("Coords:", coords)
}
if *verbose {
if verbose {
if nodeID, ok := v.(map[string]interface{})["node_id"].(string); ok {
fmt.Println("Node ID:", nodeID)
}
@ -309,7 +315,9 @@ func run() int {
}
}
}
case "getswitchqueues":
}
func handleGetSwitchQueues(res map[string]interface{}) {
maximumqueuesize := float64(4194304)
portqueues := make(map[float64]float64)
portqueuesize := make(map[float64]float64)
@ -357,7 +365,9 @@ func run() int {
uint(k), uint(v), uint(queuesizepercent), uint(portqueuepackets[k]))
}
}
case "addpeer", "removepeer", "addallowedencryptionpublickey", "removeallowedencryptionpublickey", "addsourcesubnet", "addroute", "removesourcesubnet", "removeroute":
}
func handleAddsAndRemoves(res map[string]interface{}) {
if _, ok := res["added"]; ok {
for _, v := range res["added"].([]interface{}) {
fmt.Println("Added:", fmt.Sprint(v))
@ -378,7 +388,9 @@ func run() int {
fmt.Println("Not removed:", fmt.Sprint(v))
}
}
case "getallowedencryptionpublickeys":
}
func handleGetAllowedEncryptionPublicKeys(res map[string]interface{}) {
if _, ok := res["allowed_box_pubs"]; !ok {
fmt.Println("All connections are allowed")
} else if res["allowed_box_pubs"] == nil {
@ -389,7 +401,9 @@ func run() int {
fmt.Println("-", v)
}
}
case "getmulticastinterfaces":
}
func handleGetMulticastInterfaces(res map[string]interface{}) {
if _, ok := res["multicast_interfaces"]; !ok {
fmt.Println("No multicast interfaces found")
} else if res["multicast_interfaces"] == nil {
@ -400,7 +414,9 @@ func run() int {
fmt.Println("-", v)
}
}
case "getsourcesubnets":
}
func handleGetSourceSubnets(res map[string]interface{}) {
if _, ok := res["source_subnets"]; !ok {
fmt.Println("No source subnets found")
} else if res["source_subnets"] == nil {
@ -411,7 +427,9 @@ func run() int {
fmt.Println("-", v)
}
}
case "getroutes":
}
func handleGetRoutes(res map[string]interface{}) {
if routes, ok := res["routes"].(map[string]interface{}); !ok {
fmt.Println("No routes found")
} else {
@ -426,9 +444,9 @@ func run() int {
}
}
}
case "settunnelrouting":
fallthrough
case "gettunnelrouting":
}
func handleGetTunnelRouting(res map[string]interface{}) {
if enabled, ok := res["enabled"].(bool); !ok {
fmt.Println("Tunnel routing is disabled")
} else if !enabled {
@ -436,17 +454,4 @@ func run() int {
} else {
fmt.Println("Tunnel routing is enabled")
}
default:
if json, err := json.MarshalIndent(recv["response"], "", " "); err == nil {
fmt.Println(string(json))
}
}
} else {
logger.Println("Error receiving response:", err)
}
if v, ok := recv["status"]; ok && v != "success" {
return 1
}
return 0
}