diff --git a/src/yggdrasil/conn.go b/src/yggdrasil/conn.go index b936d4f9..874a7a9c 100644 --- a/src/yggdrasil/conn.go +++ b/src/yggdrasil/conn.go @@ -3,6 +3,7 @@ package yggdrasil import ( "encoding/hex" "errors" + "sync" "sync/atomic" "time" @@ -15,6 +16,7 @@ type Conn struct { nodeID *crypto.NodeID nodeMask *crypto.NodeID session *sessionInfo + sessionMutex *sync.RWMutex readDeadline time.Time writeDeadline time.Time expired bool @@ -28,7 +30,9 @@ func (c *Conn) startSearch() { return } if sinfo != nil { + c.sessionMutex.Lock() c.session = sinfo + c.sessionMutex.Unlock() } } doSearch := func() { @@ -61,15 +65,20 @@ func (c *Conn) startSearch() { } func (c *Conn) Read(b []byte) (int, error) { + c.sessionMutex.RLock() + defer c.sessionMutex.RUnlock() if c.expired { return 0, errors.New("session is closed") } if c.session == nil { return 0, errors.New("searching for remote side") } + c.session.initMutex.RLock() if !c.session.init { + c.session.initMutex.RUnlock() return 0, errors.New("waiting for remote side to accept") } + c.session.initMutex.RUnlock() select { case p, ok := <-c.session.recv: if !ok { @@ -93,7 +102,9 @@ func (c *Conn) Read(b []byte) (int, error) { b = b[:len(bs)] } c.session.updateNonce(&p.Nonce) + c.session.timeMutex.Lock() c.session.time = time.Now() + c.session.timeMutex.Unlock() return nil }() if err != nil { @@ -108,6 +119,8 @@ func (c *Conn) Read(b []byte) (int, error) { } func (c *Conn) Write(b []byte) (bytesWritten int, err error) { + c.sessionMutex.RLock() + defer c.sessionMutex.RUnlock() if c.expired { return 0, errors.New("session is closed") } @@ -118,12 +131,16 @@ func (c *Conn) Write(b []byte) (bytesWritten int, err error) { return 0, errors.New("searching for remote side") } defer util.PutBytes(b) + c.session.initMutex.RLock() if !c.session.init { - // To prevent using empty session keys + c.session.initMutex.RUnlock() return 0, errors.New("waiting for remote side to accept") } + c.session.initMutex.RUnlock() // code isn't multithreaded so appending to this is safe + c.session.coordsMutex.RLock() coords := c.session.coords + c.session.coordsMutex.RUnlock() // Prepare the payload c.session.myNonceMutex.Lock() payload, nonce := crypto.BoxSeal(&c.session.sharedSesKey, b, &c.session.myNonce) diff --git a/src/yggdrasil/core.go b/src/yggdrasil/core.go index b5d74e8c..cfba833b 100644 --- a/src/yggdrasil/core.go +++ b/src/yggdrasil/core.go @@ -5,6 +5,7 @@ import ( "errors" "io/ioutil" "net" + "sync" "time" "github.com/gologme/log" @@ -273,7 +274,9 @@ func (c *Core) ListenConn() (*Listener, error) { // and the second parameter should contain a hexadecimal representation of the // target node ID. func (c *Core) Dial(network, address string) (Conn, error) { - conn := Conn{} + conn := Conn{ + sessionMutex: &sync.RWMutex{}, + } nodeID := crypto.NodeID{} nodeMask := crypto.NodeID{} // Process @@ -298,6 +301,8 @@ func (c *Core) Dial(network, address string) (Conn, error) { conn.core.router.doAdmin(func() { conn.startSearch() }) + conn.sessionMutex.Lock() + defer conn.sessionMutex.Unlock() return conn, nil } diff --git a/src/yggdrasil/router.go b/src/yggdrasil/router.go index d7923f51..693fba44 100644 --- a/src/yggdrasil/router.go +++ b/src/yggdrasil/router.go @@ -291,6 +291,10 @@ func (r *router) sendPacket(bs []byte) { if destSnet.IsValid() { sinfo, isIn = r.core.sessions.getByTheirSubnet(&destSnet) } + sinfo.timeMutex.Lock() + sinfo.initMutex.RLock() + defer sinfo.timeMutex.Unlock() + defer sinfo.initMutex.RUnlock() switch { case !isIn || !sinfo.init: // No or unintiialized session, so we need to search first @@ -306,6 +310,7 @@ func (r *router) sendPacket(bs []byte) { } else { // We haven't heard about the dest in a while now := time.Now() + if !sinfo.time.Before(sinfo.pingTime) { // Update pingTime to start the clock for searches (above) sinfo.pingTime = now @@ -315,6 +320,7 @@ func (r *router) sendPacket(bs []byte) { sinfo.pingSend = now r.core.sessions.sendPingPong(sinfo, false) } + sinfo.timeMutex.Unlock() } fallthrough // Also send the packet default: diff --git a/src/yggdrasil/session.go b/src/yggdrasil/session.go index 40259d9b..9c09532b 100644 --- a/src/yggdrasil/session.go +++ b/src/yggdrasil/session.go @@ -8,6 +8,7 @@ import ( "bytes" "encoding/hex" "sync" + "sync/atomic" "time" "github.com/yggdrasil-network/yggdrasil-go/src/address" @@ -35,21 +36,23 @@ type sessionInfo struct { myNonceMutex sync.Mutex // protects the above theirMTU uint16 myMTU uint16 - wasMTUFixed bool // Was the MTU fixed by a receive error? - time time.Time // Time we last received a packet - coords []byte // coords of destination - packet []byte // a buffered packet, sent immediately on ping/pong - init bool // Reset if coords change + wasMTUFixed bool // Was the MTU fixed by a receive error? + time time.Time // Time we last received a packet + mtuTime time.Time // time myMTU was last changed + pingTime time.Time // time the first ping was sent since the last received packet + pingSend time.Time // time the last ping was sent + timeMutex sync.RWMutex // protects all time fields above + coords []byte // coords of destination + coordsMutex sync.RWMutex // protects the above + packet []byte // a buffered packet, sent immediately on ping/pong + init bool // Reset if coords change + initMutex sync.RWMutex send chan []byte recv chan *wire_trafficPacket closed chan interface{} - tstamp int64 // tstamp from their last session ping, replay attack mitigation - tstampMutex int64 // protects the above - mtuTime time.Time // time myMTU was last changed - pingTime time.Time // time the first ping was sent since the last received packet - pingSend time.Time // time the last ping was sent - bytesSent uint64 // Bytes of real traffic sent in this session - bytesRecvd uint64 // Bytes of real traffic received in this session + tstamp int64 // ATOMIC - tstamp from their last session ping, replay attack mitigation + bytesSent uint64 // Bytes of real traffic sent in this session + bytesRecvd uint64 // Bytes of real traffic received in this session } // Represents a session ping/pong packet, andincludes information like public keys, a session handle, coords, a timestamp to prevent replays, and the tun/tap MTU. @@ -66,7 +69,7 @@ type sessionPing struct { // Updates session info in response to a ping, after checking that the ping is OK. // Returns true if the session was updated, or false otherwise. func (s *sessionInfo) update(p *sessionPing) bool { - if !(p.Tstamp > s.tstamp) { + if !(p.Tstamp > atomic.LoadInt64(&s.tstamp)) { // To protect against replay attacks return false } @@ -90,14 +93,20 @@ func (s *sessionInfo) update(p *sessionPing) bool { s.coords = append(make([]byte, 0, len(p.Coords)+11), p.Coords...) } now := time.Now() + s.timeMutex.Lock() s.time = now - s.tstamp = p.Tstamp + s.timeMutex.Unlock() + atomic.StoreInt64(&s.tstamp, p.Tstamp) + s.initMutex.Lock() s.init = true + s.initMutex.Unlock() return true } // Returns true if the session has been idle for longer than the allowed timeout. func (s *sessionInfo) timedout() bool { + s.timeMutex.RLock() + defer s.timeMutex.RUnlock() return time.Since(s.time) > time.Minute } @@ -284,10 +293,12 @@ func (ss *sessions) createSession(theirPermKey *crypto.BoxPubKey) *sessionInfo { sinfo.myMTU = uint16(ss.core.router.adapter.MTU()) } now := time.Now() + sinfo.timeMutex.Lock() sinfo.time = now sinfo.mtuTime = now sinfo.pingTime = now sinfo.pingSend = now + sinfo.timeMutex.Unlock() higher := false for idx := range ss.core.boxPub { if ss.core.boxPub[idx] > sinfo.theirPermPub[idx] { @@ -428,6 +439,7 @@ func (ss *sessions) sendPingPong(sinfo *sessionInfo, isPong bool) { bs := ping.encode() shared := ss.getSharedKey(&ss.core.boxPriv, &sinfo.theirPermPub) payload, nonce := crypto.BoxSeal(shared, bs, nil) + sinfo.coordsMutex.RLock() p := wire_protoTrafficPacket{ Coords: sinfo.coords, ToKey: sinfo.theirPermPub, @@ -435,10 +447,13 @@ func (ss *sessions) sendPingPong(sinfo *sessionInfo, isPong bool) { Nonce: *nonce, Payload: payload, } + sinfo.coordsMutex.RUnlock() packet := p.encode() ss.core.router.out(packet) if !isPong { + sinfo.timeMutex.Lock() sinfo.pingSend = time.Now() + sinfo.timeMutex.Unlock() } } @@ -465,10 +480,11 @@ func (ss *sessions) handlePing(ping *sessionPing) { ss.listenerMutex.Lock() if ss.listener != nil { conn := &Conn{ - core: ss.core, - session: sinfo, - nodeID: crypto.GetNodeID(&sinfo.theirPermPub), - nodeMask: &crypto.NodeID{}, + core: ss.core, + session: sinfo, + sessionMutex: &sync.RWMutex{}, + nodeID: crypto.GetNodeID(&sinfo.theirPermPub), + nodeMask: &crypto.NodeID{}, } for i := range conn.nodeMask { conn.nodeMask[i] = 0xFF @@ -537,6 +553,8 @@ func (sinfo *sessionInfo) updateNonce(theirNonce *crypto.BoxNonce) { // Called after coord changes, so attemtps to use a session will trigger a new ping and notify the remote end of the coord change. func (ss *sessions) resetInits() { for _, sinfo := range ss.sinfos { + sinfo.initMutex.Lock() sinfo.init = false + sinfo.initMutex.Unlock() } }