diff --git a/src/yggdrasil/conn.go b/src/yggdrasil/conn.go index 7aa46554..fd65743c 100644 --- a/src/yggdrasil/conn.go +++ b/src/yggdrasil/conn.go @@ -3,6 +3,7 @@ package yggdrasil import ( "encoding/hex" "errors" + "sync/atomic" "time" "github.com/yggdrasil-network/yggdrasil-go/src/crypto" @@ -84,7 +85,7 @@ func (c *Conn) Read(b []byte) (int, error) { b = append(b, bs...) c.session.updateNonce(&p.Nonce) c.session.time = time.Now() - c.session.bytesRecvd += uint64(len(bs)) + atomic.AddUint64(&c.session.bytesRecvd, uint64(len(b))) return len(b), nil case <-c.session.closed: return len(b), errors.New("session was closed") @@ -106,7 +107,9 @@ func (c *Conn) Write(b []byte) (bytesWritten int, err error) { // code isn't multithreaded so appending to this is safe coords := c.session.coords // Prepare the payload + c.session.myNonceMutex.Lock() payload, nonce := crypto.BoxSeal(&c.session.sharedSesKey, b, &c.session.myNonce) + c.session.myNonceMutex.Unlock() defer util.PutBytes(payload) p := wire_trafficPacket{ Coords: coords, @@ -115,7 +118,7 @@ func (c *Conn) Write(b []byte) (bytesWritten int, err error) { Payload: payload, } packet := p.encode() - c.session.bytesSent += uint64(len(b)) + atomic.AddUint64(&c.session.bytesSent, uint64(len(b))) select { case c.session.send <- packet: case <-c.session.closed: diff --git a/src/yggdrasil/session.go b/src/yggdrasil/session.go index e860dc81..e46761d5 100644 --- a/src/yggdrasil/session.go +++ b/src/yggdrasil/session.go @@ -29,9 +29,10 @@ type sessionInfo struct { theirHandle crypto.Handle myHandle crypto.Handle theirNonce crypto.BoxNonce - theirNonceMutex sync.RWMutex // protects the above + theirNonceMask uint64 + theirNonceMutex sync.Mutex // protects the above myNonce crypto.BoxNonce - myNonceMutex sync.RWMutex // protects the above + myNonceMutex sync.Mutex // protects the above theirMTU uint16 myMTU uint16 wasMTUFixed bool // Was the MTU fixed by a receive error? @@ -42,7 +43,6 @@ type sessionInfo struct { send chan []byte recv chan *wire_trafficPacket closed chan interface{} - nonceMask uint64 tstamp int64 // tstamp from their last session ping, replay attack mitigation tstampMutex int64 // protects the above mtuTime time.Time // time myMTU was last changed @@ -79,8 +79,10 @@ func (s *sessionInfo) update(p *sessionPing) bool { s.theirSesPub = p.SendSesPub s.theirHandle = p.Handle s.sharedSesKey = *crypto.GetSharedKey(&s.mySesPriv, &s.theirSesPub) + s.theirNonceMutex.Lock() s.theirNonce = crypto.BoxNonce{} - s.nonceMask = 0 + s.theirNonceMask = 0 + s.theirNonceMutex.Unlock() } if p.MTU >= 1280 || p.MTU == 0 { s.theirMTU = p.MTU @@ -270,6 +272,10 @@ func (ss *sessions) createSession(theirPermKey *crypto.BoxPubKey) *sessionInfo { return nil } sinfo := sessionInfo{} + sinfo.myNonceMutex.Lock() + sinfo.theirNonceMutex.Lock() + defer sinfo.myNonceMutex.Unlock() + defer sinfo.theirNonceMutex.Unlock() sinfo.core = ss.core sinfo.reconfigure = make(chan chan error, 1) sinfo.theirPermPub = *theirPermKey @@ -389,7 +395,9 @@ func (ss *sessions) getPing(sinfo *sessionInfo) sessionPing { Coords: coords, MTU: sinfo.myMTU, } + sinfo.myNonceMutex.Lock() sinfo.myNonce.Increment() + sinfo.myNonceMutex.Unlock() return ref } @@ -493,26 +501,30 @@ func (sinfo *sessionInfo) getMTU() uint16 { // Checks if a packet's nonce is recent enough to fall within the window of allowed packets, and not already received. func (sinfo *sessionInfo) nonceIsOK(theirNonce *crypto.BoxNonce) bool { // The bitmask is to allow for some non-duplicate out-of-order packets + sinfo.theirNonceMutex.Lock() + defer sinfo.theirNonceMutex.Unlock() diff := theirNonce.Minus(&sinfo.theirNonce) if diff > 0 { return true } - return ^sinfo.nonceMask&(0x01< 0 { // This nonce is newer, so shift the window before setting the bit, and update theirNonce in the session info. - sinfo.nonceMask <<= uint64(diff) - sinfo.nonceMask &= 0x01 + sinfo.theirNonceMask <<= uint64(diff) + sinfo.theirNonceMask &= 0x01 sinfo.theirNonce = *theirNonce } else { // This nonce is older, so set the bit but do not shift the window. - sinfo.nonceMask &= 0x01 << uint64(-diff) + sinfo.theirNonceMask &= 0x01 << uint64(-diff) } }