Merge pull request #751 from Arceliar/bugfix

Fix goroutine leak in link.go
This commit is contained in:
Arceliar 2020-12-19 11:04:13 -06:00 committed by GitHub
commit 6eb74a40e1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 25 additions and 26 deletions

View File

@ -187,7 +187,7 @@ func (l *links) stop() error {
return nil return nil
} }
func (intf *link) handler() error { func (intf *link) handler() (chan struct{}, error) {
// TODO split some of this into shorter functions, so it's easier to read, and for the FIXME duplicate peer issue mentioned later // TODO split some of this into shorter functions, so it's easier to read, and for the FIXME duplicate peer issue mentioned later
go func() { go func() {
for bss := range intf.writer.worker { for bss := range intf.writer.worker {
@ -207,38 +207,38 @@ func (intf *link) handler() error {
// TODO timeouts on send/recv (goroutine for send/recv, channel select w/ timer) // TODO timeouts on send/recv (goroutine for send/recv, channel select w/ timer)
var err error var err error
if !util.FuncTimeout(func() { err = intf.msgIO._sendMetaBytes(metaBytes) }, 30*time.Second) { if !util.FuncTimeout(func() { err = intf.msgIO._sendMetaBytes(metaBytes) }, 30*time.Second) {
return errors.New("timeout on metadata send") return nil, errors.New("timeout on metadata send")
} }
if err != nil { if err != nil {
return err return nil, err
} }
if !util.FuncTimeout(func() { metaBytes, err = intf.msgIO._recvMetaBytes() }, 30*time.Second) { if !util.FuncTimeout(func() { metaBytes, err = intf.msgIO._recvMetaBytes() }, 30*time.Second) {
return errors.New("timeout on metadata recv") return nil, errors.New("timeout on metadata recv")
} }
if err != nil { if err != nil {
return err return nil, err
} }
meta = version_metadata{} meta = version_metadata{}
if !meta.decode(metaBytes) || !meta.check() { if !meta.decode(metaBytes) || !meta.check() {
return errors.New("failed to decode metadata") return nil, errors.New("failed to decode metadata")
} }
base := version_getBaseMetadata() base := version_getBaseMetadata()
if meta.ver > base.ver || meta.ver == base.ver && meta.minorVer > base.minorVer { if meta.ver > base.ver || meta.ver == base.ver && meta.minorVer > base.minorVer {
intf.links.core.log.Errorln("Failed to connect to node: " + intf.lname + " version: " + fmt.Sprintf("%d.%d", meta.ver, meta.minorVer)) intf.links.core.log.Errorln("Failed to connect to node: " + intf.lname + " version: " + fmt.Sprintf("%d.%d", meta.ver, meta.minorVer))
return errors.New("failed to connect: wrong version") return nil, errors.New("failed to connect: wrong version")
} }
// Check if the remote side matches the keys we expected. This is a bit of a weak // Check if the remote side matches the keys we expected. This is a bit of a weak
// check - in future versions we really should check a signature or something like that. // check - in future versions we really should check a signature or something like that.
if pinned := intf.options.pinnedCurve25519Keys; pinned != nil { if pinned := intf.options.pinnedCurve25519Keys; pinned != nil {
if _, allowed := pinned[meta.box]; !allowed { if _, allowed := pinned[meta.box]; !allowed {
intf.links.core.log.Errorf("Failed to connect to node: %q sent curve25519 key that does not match pinned keys", intf.name) intf.links.core.log.Errorf("Failed to connect to node: %q sent curve25519 key that does not match pinned keys", intf.name)
return fmt.Errorf("failed to connect: host sent curve25519 key that does not match pinned keys") return nil, fmt.Errorf("failed to connect: host sent curve25519 key that does not match pinned keys")
} }
} }
if pinned := intf.options.pinnedEd25519Keys; pinned != nil { if pinned := intf.options.pinnedEd25519Keys; pinned != nil {
if _, allowed := pinned[meta.sig]; !allowed { if _, allowed := pinned[meta.sig]; !allowed {
intf.links.core.log.Errorf("Failed to connect to node: %q sent ed25519 key that does not match pinned keys", intf.name) intf.links.core.log.Errorf("Failed to connect to node: %q sent ed25519 key that does not match pinned keys", intf.name)
return fmt.Errorf("failed to connect: host sent ed25519 key that does not match pinned keys") return nil, fmt.Errorf("failed to connect: host sent ed25519 key that does not match pinned keys")
} }
} }
// Check if we're authorized to connect to this key / IP // Check if we're authorized to connect to this key / IP
@ -246,7 +246,7 @@ func (intf *link) handler() error {
intf.links.core.log.Warnf("%s connection from %s forbidden: AllowedEncryptionPublicKeys does not contain key %s", intf.links.core.log.Warnf("%s connection from %s forbidden: AllowedEncryptionPublicKeys does not contain key %s",
strings.ToUpper(intf.info.linkType), intf.info.remote, hex.EncodeToString(meta.box[:])) strings.ToUpper(intf.info.linkType), intf.info.remote, hex.EncodeToString(meta.box[:]))
intf.msgIO.close() intf.msgIO.close()
return nil return nil, nil
} }
// Check if we already have a link to this node // Check if we already have a link to this node
intf.info.box = meta.box intf.info.box = meta.box
@ -258,11 +258,7 @@ func (intf *link) handler() error {
// That lets them do things like close connections on its own, avoid printing a connection message in the first place, etc. // That lets them do things like close connections on its own, avoid printing a connection message in the first place, etc.
intf.links.core.log.Debugln("DEBUG: found existing interface for", intf.name) intf.links.core.log.Debugln("DEBUG: found existing interface for", intf.name)
intf.msgIO.close() intf.msgIO.close()
if !intf.incoming { return oldIntf.closed, nil
// Block outgoing connection attempts until the existing connection closes
<-oldIntf.closed
}
return nil
} else { } else {
intf.closed = make(chan struct{}) intf.closed = make(chan struct{})
intf.links.links[intf.info] = intf intf.links.links[intf.info] = intf
@ -282,7 +278,7 @@ func (intf *link) handler() error {
intf.peer = intf.links.core.peers._newPeer(&meta.box, &meta.sig, shared, intf) intf.peer = intf.links.core.peers._newPeer(&meta.box, &meta.sig, shared, intf)
}) })
if intf.peer == nil { if intf.peer == nil {
return errors.New("failed to create peer") return nil, errors.New("failed to create peer")
} }
defer func() { defer func() {
// More cleanup can go here // More cleanup can go here
@ -320,7 +316,7 @@ func (intf *link) handler() error {
intf.links.core.log.Infof("Disconnected %s: %s, source %s", intf.links.core.log.Infof("Disconnected %s: %s, source %s",
strings.ToUpper(intf.info.linkType), themString, intf.info.local) strings.ToUpper(intf.info.linkType), themString, intf.info.local)
} }
return err return nil, err
} }
//////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////
@ -460,9 +456,6 @@ func (intf *link) notifyStalled() {
// reset the close timer // reset the close timer
func (intf *link) notifyReading() { func (intf *link) notifyReading() {
intf.Act(&intf.reader, func() { intf.Act(&intf.reader, func() {
if intf.closeTimer != nil {
intf.closeTimer.Stop()
}
intf.closeTimer = time.AfterFunc(closeTime, func() { intf.msgIO.close() }) intf.closeTimer = time.AfterFunc(closeTime, func() { intf.msgIO.close() })
}) })
} }
@ -470,6 +463,7 @@ func (intf *link) notifyReading() {
// wake up the link if it was stalled, and (if size > 0) prepare to send keep-alive traffic // wake up the link if it was stalled, and (if size > 0) prepare to send keep-alive traffic
func (intf *link) notifyRead(size int) { func (intf *link) notifyRead(size int) {
intf.Act(&intf.reader, func() { intf.Act(&intf.reader, func() {
intf.closeTimer.Stop()
if intf.stallTimer != nil { if intf.stallTimer != nil {
intf.stallTimer.Stop() intf.stallTimer.Stop()
intf.stallTimer = nil intf.stallTimer = nil

View File

@ -299,7 +299,9 @@ func (t *tcp) call(saddr string, options tcpOptions, sintf string) {
} }
t.waitgroup.Add(1) t.waitgroup.Add(1)
options.socksPeerAddr = conn.RemoteAddr().String() options.socksPeerAddr = conn.RemoteAddr().String()
t.handler(conn, false, options) if ch := t.handler(conn, false, options); ch != nil {
<-ch
}
} else { } else {
dst, err := net.ResolveTCPAddr("tcp", saddr) dst, err := net.ResolveTCPAddr("tcp", saddr)
if err != nil { if err != nil {
@ -365,12 +367,14 @@ func (t *tcp) call(saddr string, options tcpOptions, sintf string) {
return return
} }
t.waitgroup.Add(1) t.waitgroup.Add(1)
t.handler(conn, false, options) if ch := t.handler(conn, false, options); ch != nil {
<-ch
}
} }
}() }()
} }
func (t *tcp) handler(sock net.Conn, incoming bool, options tcpOptions) { func (t *tcp) handler(sock net.Conn, incoming bool, options tcpOptions) chan struct{} {
defer t.waitgroup.Done() // Happens after sock.close defer t.waitgroup.Done() // Happens after sock.close
defer sock.Close() defer sock.Close()
t.setExtraOptions(sock) t.setExtraOptions(sock)
@ -379,7 +383,7 @@ func (t *tcp) handler(sock net.Conn, incoming bool, options tcpOptions) {
var err error var err error
if sock, err = options.upgrade.upgrade(sock); err != nil { if sock, err = options.upgrade.upgrade(sock); err != nil {
t.links.core.log.Errorln("TCP handler upgrade failed:", err) t.links.core.log.Errorln("TCP handler upgrade failed:", err)
return return nil
} }
upgraded = true upgraded = true
} }
@ -415,7 +419,7 @@ func (t *tcp) handler(sock net.Conn, incoming bool, options tcpOptions) {
// Maybe dial/listen at the application level // Maybe dial/listen at the application level
// Then pass a net.Conn to the core library (after these kinds of checks are done) // Then pass a net.Conn to the core library (after these kinds of checks are done)
t.links.core.log.Debugln("Dropping ygg-tunneled connection", local, remote) t.links.core.log.Debugln("Dropping ygg-tunneled connection", local, remote)
return return nil
} }
} }
force := net.ParseIP(strings.Split(remote, "%")[0]).IsLinkLocalUnicast() force := net.ParseIP(strings.Split(remote, "%")[0]).IsLinkLocalUnicast()
@ -425,6 +429,7 @@ func (t *tcp) handler(sock net.Conn, incoming bool, options tcpOptions) {
panic(err) panic(err)
} }
t.links.core.log.Debugln("DEBUG: starting handler for", name) t.links.core.log.Debugln("DEBUG: starting handler for", name)
err = link.handler() ch, err := link.handler()
t.links.core.log.Debugln("DEBUG: stopped handler for", name, err) t.links.core.log.Debugln("DEBUG: stopped handler for", name, err)
return ch
} }