diff --git a/src/yggdrasil/admin.go b/src/yggdrasil/admin.go index f486468e..a8ac5019 100644 --- a/src/yggdrasil/admin.go +++ b/src/yggdrasil/admin.go @@ -22,6 +22,7 @@ import ( type admin struct { core *Core listenaddr string + listener net.Listener handlers []admin_handlerInfo } @@ -229,17 +230,36 @@ func (a *admin) start() error { return nil } +// cleans up when stopping +func (a *admin) close() error { + return a.listener.Close() +} + // listen is run by start and manages API connections. func (a *admin) listen() { - l, err := net.Listen("tcp", a.listenaddr) + u, err := url.Parse(a.listenaddr) + if err == nil { + switch strings.ToLower(u.Scheme) { + case "unix": + a.listener, err = net.Listen("unix", a.listenaddr[7:]) + case "tcp": + a.listener, err = net.Listen("tcp", u.Host) + default: + err = errors.New("protocol not supported") + } + } else { + a.listener, err = net.Listen("tcp", a.listenaddr) + } if err != nil { a.core.log.Printf("Admin socket failed to listen: %v", err) os.Exit(1) } - defer l.Close() - a.core.log.Printf("Admin socket listening on %s", l.Addr().String()) + a.core.log.Printf("%s admin socket listening on %s", + strings.ToUpper(a.listener.Addr().Network()), + a.listener.Addr().String()) + defer a.listener.Close() for { - conn, err := l.Accept() + conn, err := a.listener.Accept() if err == nil { a.handleRequest(conn) } diff --git a/src/yggdrasil/core.go b/src/yggdrasil/core.go index 28ca8f30..a0d5a118 100644 --- a/src/yggdrasil/core.go +++ b/src/yggdrasil/core.go @@ -136,6 +136,7 @@ func (c *Core) Start(nc *config.NodeConfig, log *log.Logger) error { func (c *Core) Stop() { c.log.Println("Stopping...") c.tun.close() + c.admin.close() } // Generates a new encryption keypair. The encryption keys are used to diff --git a/yggdrasilctl.go b/yggdrasilctl.go index c4efe773..6281b162 100644 --- a/yggdrasilctl.go +++ b/yggdrasilctl.go @@ -1,9 +1,11 @@ package main +import "errors" import "flag" import "fmt" import "strings" import "net" +import "net/url" import "sort" import "encoding/json" import "strconv" @@ -20,14 +22,28 @@ func main() { args := flag.Args() if len(args) == 0 { - fmt.Println("usage:", os.Args[0], "[-endpoint=localhost:9001] [-json] command [key=value] [...]") + fmt.Println("usage:", os.Args[0], "[-endpoint=proto://server] [-json] command [key=value] [...]") fmt.Println("example:", os.Args[0], "getPeers") fmt.Println("example:", os.Args[0], "setTunTap name=auto mtu=1500 tap_mode=false") - fmt.Println("example:", os.Args[0], "-endpoint=localhost:9001 getDHT") + fmt.Println("example:", os.Args[0], "-endpoint=tcp://localhost:9001 getDHT") + fmt.Println("example:", os.Args[0], "-endpoint=unix:///var/run/ygg.sock getDHT") return } - conn, err := net.Dial("tcp", *server) + var conn net.Conn + u, err := url.Parse(*server) + if err == nil { + switch strings.ToLower(u.Scheme) { + case "unix": + conn, err = net.Dial("unix", (*server)[7:]) + case "tcp": + conn, err = net.Dial("tcp", u.Host) + default: + err = errors.New("protocol not supported") + } + } else { + conn, err = net.Dial("tcp", *server) + } if err != nil { panic(err) }