Rename awdl.go to link.go, add stream.go, update tcp.go

This commit is contained in:
Neil Alexander 2019-01-19 00:14:10 +00:00
parent f6cb194d5c
commit 6fe3b01e90
No known key found for this signature in database
GPG Key ID: A02A2019A2BB0944
4 changed files with 177 additions and 129 deletions

View File

@ -44,7 +44,7 @@ type Core struct {
searches searches searches searches
multicast multicast multicast multicast
tcp tcpInterface tcp tcpInterface
awdl awdl link link
log *log.Logger log *log.Logger
} }
@ -198,10 +198,12 @@ func (c *Core) Start(nc *config.NodeConfig, log *log.Logger) error {
return err return err
} }
if err := c.awdl.init(c); err != nil { /*
c.log.Println("Failed to start AWDL interface") if err := c.awdl.init(c); err != nil {
return err c.log.Println("Failed to start AWDL interface")
} return err
}
*/
if nc.SwitchOptions.MaxTotalQueueSize >= SwitchQueueTotalMinSize { if nc.SwitchOptions.MaxTotalQueueSize >= SwitchQueueTotalMinSize {
c.switchTable.queueTotalMaxSize = nc.SwitchOptions.MaxTotalQueueSize c.switchTable.queueTotalMaxSize = nc.SwitchOptions.MaxTotalQueueSize

View File

@ -11,34 +11,35 @@ import (
"github.com/yggdrasil-network/yggdrasil-go/src/util" "github.com/yggdrasil-network/yggdrasil-go/src/util"
) )
type awdl struct { type link struct {
core *Core core *Core
mutex sync.RWMutex // protects interfaces below mutex sync.RWMutex // protects interfaces below
interfaces map[string]*awdlInterface interfaces map[string]*linkInterface
} }
type awdlInterface struct { type linkInterface struct {
awdl *awdl link *link
fromAWDL chan []byte fromlink chan []byte
toAWDL chan []byte tolink chan []byte
shutdown chan bool shutdown chan bool
peer *peer peer *peer
stream stream
} }
func (l *awdl) init(c *Core) error { func (l *link) init(c *Core) error {
l.core = c l.core = c
l.mutex.Lock() l.mutex.Lock()
l.interfaces = make(map[string]*awdlInterface) l.interfaces = make(map[string]*linkInterface)
l.mutex.Unlock() l.mutex.Unlock()
return nil return nil
} }
func (l *awdl) create(fromAWDL chan []byte, toAWDL chan []byte /*boxPubKey *crypto.BoxPubKey, sigPubKey *crypto.SigPubKey*/, name string) (*awdlInterface, error) { func (l *link) create(fromlink chan []byte, tolink chan []byte /*boxPubKey *crypto.BoxPubKey, sigPubKey *crypto.SigPubKey*/, name string) (*linkInterface, error) {
intf := awdlInterface{ intf := linkInterface{
awdl: l, link: l,
fromAWDL: fromAWDL, fromlink: fromlink,
toAWDL: toAWDL, tolink: tolink,
shutdown: make(chan bool), shutdown: make(chan bool),
} }
l.mutex.Lock() l.mutex.Lock()
@ -50,35 +51,29 @@ func (l *awdl) create(fromAWDL chan []byte, toAWDL chan []byte /*boxPubKey *cryp
meta.sig = l.core.sigPub meta.sig = l.core.sigPub
meta.link = *myLinkPub meta.link = *myLinkPub
metaBytes := meta.encode() metaBytes := meta.encode()
l.core.log.Println("toAWDL <- metaBytes") tolink <- metaBytes
toAWDL <- metaBytes metaBytes = <-fromlink
l.core.log.Println("metaBytes = <-fromAWDL")
metaBytes = <-fromAWDL
l.core.log.Println("version_metadata{}")
meta = version_metadata{} meta = version_metadata{}
if !meta.decode(metaBytes) || !meta.check() { if !meta.decode(metaBytes) || !meta.check() {
return nil, errors.New("Metadata decode failure") return nil, errors.New("Metadata decode failure")
} }
l.core.log.Println("version_getBaseMetadata{}")
base := version_getBaseMetadata() base := version_getBaseMetadata()
if meta.ver > base.ver || meta.ver == base.ver && meta.minorVer > base.minorVer { if meta.ver > base.ver || meta.ver == base.ver && meta.minorVer > base.minorVer {
return nil, errors.New("Failed to connect to node: " + name + " version: " + fmt.Sprintf("%d.%d", meta.ver, meta.minorVer)) return nil, errors.New("Failed to connect to node: " + name + " version: " + fmt.Sprintf("%d.%d", meta.ver, meta.minorVer))
} }
l.core.log.Println("crypto.GetSharedKey")
shared := crypto.GetSharedKey(myLinkPriv, &meta.link) shared := crypto.GetSharedKey(myLinkPriv, &meta.link)
//shared := crypto.GetSharedKey(&l.core.boxPriv, boxPubKey) //shared := crypto.GetSharedKey(&l.core.boxPriv, boxPubKey)
l.core.log.Println("l.core.peers.newPeer")
intf.peer = l.core.peers.newPeer(&meta.box, &meta.sig, shared, name) intf.peer = l.core.peers.newPeer(&meta.box, &meta.sig, shared, name)
if intf.peer != nil { if intf.peer != nil {
intf.peer.linkOut = make(chan []byte, 1) // protocol traffic intf.peer.linkOut = make(chan []byte, 1) // protocol traffic
intf.peer.out = func(msg []byte) { intf.peer.out = func(msg []byte) {
defer func() { recover() }() defer func() { recover() }()
intf.toAWDL <- msg intf.tolink <- msg
} // called by peer.sendPacket() } // called by peer.sendPacket()
l.core.switchTable.idleIn <- intf.peer.port // notify switch that we're idle l.core.switchTable.idleIn <- intf.peer.port // notify switch that we're idle
intf.peer.close = func() { intf.peer.close = func() {
close(intf.fromAWDL) close(intf.fromlink)
close(intf.toAWDL) close(intf.tolink)
} }
go intf.handler() go intf.handler()
go intf.peer.linkLoop() go intf.peer.linkLoop()
@ -88,7 +83,7 @@ func (l *awdl) create(fromAWDL chan []byte, toAWDL chan []byte /*boxPubKey *cryp
return nil, errors.New("l.core.peers.newPeer failed") return nil, errors.New("l.core.peers.newPeer failed")
} }
func (l *awdl) getInterface(identity string) *awdlInterface { func (l *link) getInterface(identity string) *linkInterface {
l.mutex.RLock() l.mutex.RLock()
defer l.mutex.RUnlock() defer l.mutex.RUnlock()
if intf, ok := l.interfaces[identity]; ok { if intf, ok := l.interfaces[identity]; ok {
@ -97,7 +92,7 @@ func (l *awdl) getInterface(identity string) *awdlInterface {
return nil return nil
} }
func (l *awdl) shutdown(identity string) error { func (l *link) shutdown(identity string) error {
if intf, ok := l.interfaces[identity]; ok { if intf, ok := l.interfaces[identity]; ok {
intf.shutdown <- true intf.shutdown <- true
l.core.peers.removePeer(intf.peer.port) l.core.peers.removePeer(intf.peer.port)
@ -110,9 +105,9 @@ func (l *awdl) shutdown(identity string) error {
} }
} }
func (ai *awdlInterface) handler() { func (ai *linkInterface) handler() {
send := func(msg []byte) { send := func(msg []byte) {
ai.toAWDL <- msg ai.tolink <- msg
atomic.AddUint64(&ai.peer.bytesSent, uint64(len(msg))) atomic.AddUint64(&ai.peer.bytesSent, uint64(len(msg)))
util.PutBytes(msg) util.PutBytes(msg)
} }
@ -138,9 +133,9 @@ func (ai *awdlInterface) handler() {
case p := <-ai.peer.linkOut: case p := <-ai.peer.linkOut:
send(p) send(p)
continue continue
case r := <-ai.fromAWDL: case r := <-ai.fromlink:
ai.peer.handlePacket(r) ai.peer.handlePacket(r)
ai.awdl.core.switchTable.idleIn <- ai.peer.port ai.link.core.switchTable.idleIn <- ai.peer.port
case <-ai.shutdown: case <-ai.shutdown:
return return
} }

111
src/yggdrasil/stream.go Normal file
View File

@ -0,0 +1,111 @@
package yggdrasil
import (
"errors"
"fmt"
"github.com/yggdrasil-network/yggdrasil-go/src/util"
)
type stream struct {
buffer []byte
cursor int
}
const streamMsgSize = 2048 + 65535
var streamMsg = [...]byte{0xde, 0xad, 0xb1, 0x75} // "dead bits"
func (s *stream) init() {
s.buffer = make([]byte, 2*streamMsgSize)
s.cursor = 0
}
// This reads from the channel into a []byte buffer for incoming messages. It
// copies completed messages out of the cache into a new slice, and passes them
// to the peer struct via the provided `in func([]byte)` argument. Then it
// shifts the incomplete fragments of data forward so future reads won't
// overwrite it.
func (s *stream) write(bs []byte, in func([]byte)) error {
frag := s.buffer[:0]
if n := len(bs); n > 0 {
frag = append(frag, bs[:n]...)
msg, ok, err2 := stream_chopMsg(&frag)
if err2 != nil {
return fmt.Errorf("message error: %v", err2)
}
if !ok {
// We didn't get the whole message yet
return nil
}
newMsg := append(util.GetBytes(), msg...)
in(newMsg)
util.Yield()
}
return nil
}
// This takes a pointer to a slice as an argument. It checks if there's a
// complete message and, if so, slices out those parts and returns the message,
// true, and nil. If there's no error, but also no complete message, it returns
// nil, false, and nil. If there's an error, it returns nil, false, and the
// error, which the reader then handles (currently, by returning from the
// reader, which causes the connection to close).
func stream_chopMsg(bs *[]byte) ([]byte, bool, error) {
// Returns msg, ok, err
if len(*bs) < len(streamMsg) {
return nil, false, nil
}
for idx := range streamMsg {
if (*bs)[idx] != streamMsg[idx] {
return nil, false, errors.New("bad message")
}
}
msgLen, msgLenLen := wire_decode_uint64((*bs)[len(streamMsg):])
if msgLen > streamMsgSize {
return nil, false, errors.New("oversized message")
}
msgBegin := len(streamMsg) + msgLenLen
msgEnd := msgBegin + int(msgLen)
if msgLenLen == 0 || len(*bs) < msgEnd {
// We don't have the full message
// Need to buffer this and wait for the rest to come in
return nil, false, nil
}
msg := (*bs)[msgBegin:msgEnd]
(*bs) = (*bs)[msgEnd:]
return msg, true, nil
}
/*
func (s *stream) chopMsg() ([]byte, bool, error) {
// Returns msg, ok, err
if len(s.buffer) < len(streamMsg) {
fmt.Println("*** too short")
return nil, false, nil
}
for idx := range streamMsg {
if s.buffer[idx] != streamMsg[idx] {
fmt.Println("*** bad message")
return nil, false, errors.New("bad message")
}
}
msgLen, msgLenLen := wire_decode_uint64((s.buffer)[len(streamMsg):])
if msgLen > streamMsgSize {
fmt.Println("*** oversized message")
return nil, false, errors.New("oversized message")
}
msgBegin := len(streamMsg) + msgLenLen
msgEnd := msgBegin + int(msgLen)
if msgLenLen == 0 || len(s.buffer) < msgEnd {
// We don't have the full message
// Need to buffer this and wait for the rest to come in
fmt.Println("*** still waiting")
return nil, false, nil
}
msg := s.buffer[msgBegin:msgEnd]
s.buffer = s.buffer[msgEnd:]
fmt.Println("*** done")
return msg, true, nil
}
*/

View File

@ -16,9 +16,7 @@ package yggdrasil
import ( import (
"context" "context"
"errors"
"fmt" "fmt"
"io"
"math/rand" "math/rand"
"net" "net"
"sync" "sync"
@ -32,21 +30,21 @@ import (
"github.com/yggdrasil-network/yggdrasil-go/src/util" "github.com/yggdrasil-network/yggdrasil-go/src/util"
) )
const tcp_msgSize = 2048 + 65535 // TODO figure out what makes sense const default_timeout = 6 * time.Second
const default_tcp_timeout = 6 * time.Second const tcp_ping_interval = (default_timeout * 2 / 3)
const tcp_ping_interval = (default_tcp_timeout * 2 / 3)
// The TCP listener and information about active TCP connections, to avoid duplication. // The TCP listener and information about active TCP connections, to avoid duplication.
type tcpInterface struct { type tcpInterface struct {
core *Core core *Core
reconfigure chan chan error reconfigure chan chan error
serv net.Listener serv net.Listener
serv_stop chan bool stop chan bool
tcp_timeout time.Duration timeout time.Duration
tcp_addr string addr string
mutex sync.Mutex // Protecting the below mutex sync.Mutex // Protecting the below
calls map[string]struct{} calls map[string]struct{}
conns map[tcpInfo](chan struct{}) conns map[tcpInfo](chan struct{})
stream stream
} }
// This is used as the key to a map that tracks existing connections, to prevent multiple connections to the same keys and local/remote address pair from occuring. // This is used as the key to a map that tracks existing connections, to prevent multiple connections to the same keys and local/remote address pair from occuring.
@ -86,7 +84,7 @@ func (iface *tcpInterface) connectSOCKS(socksaddr, peeraddr string) {
// Initializes the struct. // Initializes the struct.
func (iface *tcpInterface) init(core *Core) (err error) { func (iface *tcpInterface) init(core *Core) (err error) {
iface.core = core iface.core = core
iface.serv_stop = make(chan bool, 1) iface.stop = make(chan bool, 1)
iface.reconfigure = make(chan chan error, 1) iface.reconfigure = make(chan chan error, 1)
go func() { go func() {
for { for {
@ -95,7 +93,7 @@ func (iface *tcpInterface) init(core *Core) (err error) {
updated := iface.core.config.Listen != iface.core.configOld.Listen updated := iface.core.config.Listen != iface.core.configOld.Listen
iface.core.configMutex.RUnlock() iface.core.configMutex.RUnlock()
if updated { if updated {
iface.serv_stop <- true iface.stop <- true
iface.serv.Close() iface.serv.Close()
e <- iface.listen() e <- iface.listen()
} else { } else {
@ -111,19 +109,19 @@ func (iface *tcpInterface) listen() error {
var err error var err error
iface.core.configMutex.RLock() iface.core.configMutex.RLock()
iface.tcp_addr = iface.core.config.Listen iface.addr = iface.core.config.Listen
iface.tcp_timeout = time.Duration(iface.core.config.ReadTimeout) * time.Millisecond iface.timeout = time.Duration(iface.core.config.ReadTimeout) * time.Millisecond
iface.core.configMutex.RUnlock() iface.core.configMutex.RUnlock()
if iface.tcp_timeout >= 0 && iface.tcp_timeout < default_tcp_timeout { if iface.timeout >= 0 && iface.timeout < default_timeout {
iface.tcp_timeout = default_tcp_timeout iface.timeout = default_timeout
} }
ctx := context.Background() ctx := context.Background()
lc := net.ListenConfig{ lc := net.ListenConfig{
Control: iface.tcpContext, Control: iface.tcpContext,
} }
iface.serv, err = lc.Listen(ctx, "tcp", iface.tcp_addr) iface.serv, err = lc.Listen(ctx, "tcp", iface.addr)
if err == nil { if err == nil {
iface.mutex.Lock() iface.mutex.Lock()
iface.calls = make(map[string]struct{}) iface.calls = make(map[string]struct{})
@ -147,7 +145,7 @@ func (iface *tcpInterface) listener() {
return return
} }
select { select {
case <-iface.serv_stop: case <-iface.stop:
iface.core.log.Println("Stopping listener") iface.core.log.Println("Stopping listener")
return return
default: default:
@ -194,7 +192,7 @@ func (iface *tcpInterface) call(saddr string, socksaddr *string, sintf string) {
iface.mutex.Unlock() iface.mutex.Unlock()
defer func() { defer func() {
// Block new calls for a little while, to mitigate livelock scenarios // Block new calls for a little while, to mitigate livelock scenarios
time.Sleep(default_tcp_timeout) time.Sleep(default_timeout)
time.Sleep(time.Duration(rand.Intn(1000)) * time.Millisecond) time.Sleep(time.Duration(rand.Intn(1000)) * time.Millisecond)
iface.mutex.Lock() iface.mutex.Lock()
delete(iface.calls, callname) delete(iface.calls, callname)
@ -299,8 +297,8 @@ func (iface *tcpInterface) handler(sock net.Conn, incoming bool) {
if err != nil { if err != nil {
return return
} }
if iface.tcp_timeout > 0 { if iface.timeout > 0 {
sock.SetReadDeadline(time.Now().Add(iface.tcp_timeout)) sock.SetReadDeadline(time.Now().Add(iface.timeout))
} }
_, err = sock.Read(metaBytes) _, err = sock.Read(metaBytes)
if err != nil { if err != nil {
@ -389,9 +387,9 @@ func (iface *tcpInterface) handler(sock net.Conn, incoming bool) {
// This goroutine waits for outgoing packets, link protocol traffic, or sends idle keep-alive traffic // This goroutine waits for outgoing packets, link protocol traffic, or sends idle keep-alive traffic
send := func(msg []byte) { send := func(msg []byte) {
msgLen := wire_encode_uint64(uint64(len(msg))) msgLen := wire_encode_uint64(uint64(len(msg)))
buf := net.Buffers{tcp_msg[:], msgLen, msg} buf := net.Buffers{streamMsg[:], msgLen, msg}
buf.WriteTo(sock) buf.WriteTo(sock)
atomic.AddUint64(&p.bytesSent, uint64(len(tcp_msg)+len(msgLen)+len(msg))) atomic.AddUint64(&p.bytesSent, uint64(len(streamMsg)+len(msgLen)+len(msg)))
util.PutBytes(msg) util.PutBytes(msg)
} }
timerInterval := tcp_ping_interval timerInterval := tcp_ping_interval
@ -445,7 +443,21 @@ func (iface *tcpInterface) handler(sock net.Conn, incoming bool) {
themAddrString := net.IP(themAddr[:]).String() themAddrString := net.IP(themAddr[:]).String()
themString := fmt.Sprintf("%s@%s", themAddrString, them) themString := fmt.Sprintf("%s@%s", themAddrString, them)
iface.core.log.Printf("Connected: %s, source: %s", themString, us) iface.core.log.Printf("Connected: %s, source: %s", themString, us)
err = iface.reader(sock, in) // In this goroutine, because of defers iface.stream.init()
bs := make([]byte, 2*streamMsgSize)
var n int
for {
if iface.timeout > 0 {
sock.SetReadDeadline(time.Now().Add(iface.timeout))
}
n, err = sock.Read(bs)
if err != nil {
break
}
if n > 0 {
iface.stream.write(bs[:n], in)
}
}
if err == nil { if err == nil {
iface.core.log.Printf("Disconnected: %s, source: %s", themString, us) iface.core.log.Printf("Disconnected: %s, source: %s", themString, us)
} else { } else {
@ -453,75 +465,3 @@ func (iface *tcpInterface) handler(sock net.Conn, incoming bool) {
} }
return return
} }
// This reads from the socket into a []byte buffer for incomping messages.
// It copies completed messages out of the cache into a new slice, and passes them to the peer struct via the provided `in func([]byte)` argument.
// Then it shifts the incomplete fragments of data forward so future reads won't overwrite it.
func (iface *tcpInterface) reader(sock net.Conn, in func([]byte)) error {
bs := make([]byte, 2*tcp_msgSize)
frag := bs[:0]
for {
if iface.tcp_timeout > 0 {
sock.SetReadDeadline(time.Now().Add(iface.tcp_timeout))
}
n, err := sock.Read(bs[len(frag):])
if n > 0 {
frag = bs[:len(frag)+n]
for {
msg, ok, err2 := tcp_chop_msg(&frag)
if err2 != nil {
return fmt.Errorf("Message error: %v", err2)
}
if !ok {
// We didn't get the whole message yet
break
}
newMsg := append(util.GetBytes(), msg...)
in(newMsg)
util.Yield()
}
frag = append(bs[:0], frag...)
}
if err != nil || n == 0 {
if err != io.EOF {
return err
}
return nil
}
}
}
////////////////////////////////////////////////////////////////////////////////
// These are 4 bytes of padding used to catch if something went horribly wrong with the tcp connection.
var tcp_msg = [...]byte{0xde, 0xad, 0xb1, 0x75} // "dead bits"
// This takes a pointer to a slice as an argument.
// It checks if there's a complete message and, if so, slices out those parts and returns the message, true, and nil.
// If there's no error, but also no complete message, it returns nil, false, and nil.
// If there's an error, it returns nil, false, and the error, which the reader then handles (currently, by returning from the reader, which causes the connection to close).
func tcp_chop_msg(bs *[]byte) ([]byte, bool, error) {
// Returns msg, ok, err
if len(*bs) < len(tcp_msg) {
return nil, false, nil
}
for idx := range tcp_msg {
if (*bs)[idx] != tcp_msg[idx] {
return nil, false, errors.New("Bad message!")
}
}
msgLen, msgLenLen := wire_decode_uint64((*bs)[len(tcp_msg):])
if msgLen > tcp_msgSize {
return nil, false, errors.New("Oversized message!")
}
msgBegin := len(tcp_msg) + msgLenLen
msgEnd := msgBegin + int(msgLen)
if msgLenLen == 0 || len(*bs) < msgEnd {
// We don't have the full message
// Need to buffer this and wait for the rest to come in
return nil, false, nil
}
msg := (*bs)[msgBegin:msgEnd]
(*bs) = (*bs)[msgEnd:]
return msg, true, nil
}