diff --git a/src/tuntap/conn.go b/src/tuntap/conn.go index ce4645f3..66c10119 100644 --- a/src/tuntap/conn.go +++ b/src/tuntap/conn.go @@ -2,6 +2,7 @@ package tuntap import ( "errors" + "time" "github.com/yggdrasil-network/yggdrasil-go/src/address" "github.com/yggdrasil-network/yggdrasil-go/src/util" @@ -9,24 +10,33 @@ import ( ) type tunConn struct { - tun *TunAdapter - conn *yggdrasil.Conn - addr address.Address - snet address.Subnet - send chan []byte - stop chan interface{} + tun *TunAdapter + conn *yggdrasil.Conn + addr address.Address + snet address.Subnet + send chan []byte + stop chan struct{} + alive chan struct{} } func (s *tunConn) close() { s.tun.mutex.Lock() + defer s.tun.mutex.Unlock() s._close_nomutex() - s.tun.mutex.Unlock() } func (s *tunConn) _close_nomutex() { + s.conn.Close() delete(s.tun.addrToConn, s.addr) delete(s.tun.subnetToConn, s.snet) - close(s.stop) + func() { + defer func() { recover() }() + close(s.stop) // Closes reader/writer goroutines + }() + func() { + defer func() { recover() }() + close(s.alive) // Closes timeout goroutine + }() } func (s *tunConn) reader() error { @@ -43,7 +53,7 @@ func (s *tunConn) reader() error { b := make([]byte, 65535) for { go func() { - // TODO read timeout and close + // TODO don't start a new goroutine for every packet read, this is probably a big part of the slowdowns we saw when refactoring if n, err = s.conn.Read(b); err != nil { s.tun.log.Errorln(s.conn.String(), "TUN/TAP conn read error:", err) return @@ -60,6 +70,7 @@ func (s *tunConn) reader() error { util.PutBytes(bs) } } + s.stillAlive() // TODO? Only stay alive if we read >0 bytes? case <-s.stop: s.tun.log.Debugln("Stopping conn reader for", s) return nil @@ -89,6 +100,33 @@ func (s *tunConn) writer() error { s.tun.log.Errorln(s.conn.String(), "TUN/TAP conn write error:", err) } util.PutBytes(b) + s.stillAlive() + } + } +} + +func (s *tunConn) stillAlive() { + select { + case s.alive <- struct{}{}: + default: + } +} + +func (s *tunConn) checkForTimeouts() error { + const timeout = 2 * time.Minute + timer := time.NewTimer(timeout) + defer util.TimerStop(timer) + defer s.close() + for { + select { + case _, ok := <-s.alive: + if !ok { + return errors.New("connection closed") + } + util.TimerStop(timer) + timer.Reset(timeout) + case <-timer.C: + return errors.New("timed out") } } } diff --git a/src/tuntap/tun.go b/src/tuntap/tun.go index b9ff04ec..683b83ac 100644 --- a/src/tuntap/tun.go +++ b/src/tuntap/tun.go @@ -229,10 +229,11 @@ func (tun *TunAdapter) handler() error { func (tun *TunAdapter) wrap(conn *yggdrasil.Conn) (c *tunConn, err error) { // Prepare a session wrapper for the given connection s := tunConn{ - tun: tun, - conn: conn, - send: make(chan []byte, 32), // TODO: is this a sensible value? - stop: make(chan interface{}), + tun: tun, + conn: conn, + send: make(chan []byte, 32), // TODO: is this a sensible value? + stop: make(chan struct{}), + alive: make(chan struct{}, 1), } // Get the remote address and subnet of the other side remoteNodeID := conn.RemoteAddr() @@ -259,6 +260,7 @@ func (tun *TunAdapter) wrap(conn *yggdrasil.Conn) (c *tunConn, err error) { // Start the connection goroutines go s.reader() go s.writer() + go s.checkForTimeouts() // Return return c, err }