if the link handler exits early due to an existing connection, then have it return a channel to that connection which closes when the connection is closed, so we can choose to block on that to avoid spamming connection attempts with dial

This commit is contained in:
Arceliar 2020-12-13 16:29:03 -06:00
parent 1daf3e7bd7
commit a8810c7ee9
2 changed files with 25 additions and 20 deletions

View File

@ -187,7 +187,7 @@ func (l *links) stop() error {
return nil return nil
} }
func (intf *link) handler() error { func (intf *link) handler() (chan struct{}, error) {
// TODO split some of this into shorter functions, so it's easier to read, and for the FIXME duplicate peer issue mentioned later // TODO split some of this into shorter functions, so it's easier to read, and for the FIXME duplicate peer issue mentioned later
go func() { go func() {
for bss := range intf.writer.worker { for bss := range intf.writer.worker {
@ -207,38 +207,38 @@ func (intf *link) handler() error {
// TODO timeouts on send/recv (goroutine for send/recv, channel select w/ timer) // TODO timeouts on send/recv (goroutine for send/recv, channel select w/ timer)
var err error var err error
if !util.FuncTimeout(func() { err = intf.msgIO._sendMetaBytes(metaBytes) }, 30*time.Second) { if !util.FuncTimeout(func() { err = intf.msgIO._sendMetaBytes(metaBytes) }, 30*time.Second) {
return errors.New("timeout on metadata send") return nil, errors.New("timeout on metadata send")
} }
if err != nil { if err != nil {
return err return nil, err
} }
if !util.FuncTimeout(func() { metaBytes, err = intf.msgIO._recvMetaBytes() }, 30*time.Second) { if !util.FuncTimeout(func() { metaBytes, err = intf.msgIO._recvMetaBytes() }, 30*time.Second) {
return errors.New("timeout on metadata recv") return nil, errors.New("timeout on metadata recv")
} }
if err != nil { if err != nil {
return err return nil, err
} }
meta = version_metadata{} meta = version_metadata{}
if !meta.decode(metaBytes) || !meta.check() { if !meta.decode(metaBytes) || !meta.check() {
return errors.New("failed to decode metadata") return nil, errors.New("failed to decode metadata")
} }
base := version_getBaseMetadata() base := version_getBaseMetadata()
if meta.ver > base.ver || meta.ver == base.ver && meta.minorVer > base.minorVer { if meta.ver > base.ver || meta.ver == base.ver && meta.minorVer > base.minorVer {
intf.links.core.log.Errorln("Failed to connect to node: " + intf.lname + " version: " + fmt.Sprintf("%d.%d", meta.ver, meta.minorVer)) intf.links.core.log.Errorln("Failed to connect to node: " + intf.lname + " version: " + fmt.Sprintf("%d.%d", meta.ver, meta.minorVer))
return errors.New("failed to connect: wrong version") return nil, errors.New("failed to connect: wrong version")
} }
// Check if the remote side matches the keys we expected. This is a bit of a weak // Check if the remote side matches the keys we expected. This is a bit of a weak
// check - in future versions we really should check a signature or something like that. // check - in future versions we really should check a signature or something like that.
if pinned := intf.options.pinnedCurve25519Keys; pinned != nil { if pinned := intf.options.pinnedCurve25519Keys; pinned != nil {
if _, allowed := pinned[meta.box]; !allowed { if _, allowed := pinned[meta.box]; !allowed {
intf.links.core.log.Errorf("Failed to connect to node: %q sent curve25519 key that does not match pinned keys", intf.name) intf.links.core.log.Errorf("Failed to connect to node: %q sent curve25519 key that does not match pinned keys", intf.name)
return fmt.Errorf("failed to connect: host sent curve25519 key that does not match pinned keys") return nil, fmt.Errorf("failed to connect: host sent curve25519 key that does not match pinned keys")
} }
} }
if pinned := intf.options.pinnedEd25519Keys; pinned != nil { if pinned := intf.options.pinnedEd25519Keys; pinned != nil {
if _, allowed := pinned[meta.sig]; !allowed { if _, allowed := pinned[meta.sig]; !allowed {
intf.links.core.log.Errorf("Failed to connect to node: %q sent ed25519 key that does not match pinned keys", intf.name) intf.links.core.log.Errorf("Failed to connect to node: %q sent ed25519 key that does not match pinned keys", intf.name)
return fmt.Errorf("failed to connect: host sent ed25519 key that does not match pinned keys") return nil, fmt.Errorf("failed to connect: host sent ed25519 key that does not match pinned keys")
} }
} }
// Check if we're authorized to connect to this key / IP // Check if we're authorized to connect to this key / IP
@ -246,19 +246,19 @@ func (intf *link) handler() error {
intf.links.core.log.Warnf("%s connection from %s forbidden: AllowedEncryptionPublicKeys does not contain key %s", intf.links.core.log.Warnf("%s connection from %s forbidden: AllowedEncryptionPublicKeys does not contain key %s",
strings.ToUpper(intf.info.linkType), intf.info.remote, hex.EncodeToString(meta.box[:])) strings.ToUpper(intf.info.linkType), intf.info.remote, hex.EncodeToString(meta.box[:]))
intf.msgIO.close() intf.msgIO.close()
return nil return nil, nil
} }
// Check if we already have a link to this node // Check if we already have a link to this node
intf.info.box = meta.box intf.info.box = meta.box
intf.info.sig = meta.sig intf.info.sig = meta.sig
intf.links.mutex.Lock() intf.links.mutex.Lock()
if _, isIn := intf.links.links[intf.info]; isIn { if oldIntf, isIn := intf.links.links[intf.info]; isIn {
intf.links.mutex.Unlock() intf.links.mutex.Unlock()
// FIXME we should really return an error and let the caller block instead // FIXME we should really return an error and let the caller block instead
// That lets them do things like close connections on its own, avoid printing a connection message in the first place, etc. // That lets them do things like close connections on its own, avoid printing a connection message in the first place, etc.
intf.links.core.log.Debugln("DEBUG: found existing interface for", intf.name) intf.links.core.log.Debugln("DEBUG: found existing interface for", intf.name)
intf.msgIO.close() intf.msgIO.close()
return nil return oldIntf.closed, nil
} else { } else {
intf.closed = make(chan struct{}) intf.closed = make(chan struct{})
intf.links.links[intf.info] = intf intf.links.links[intf.info] = intf
@ -278,7 +278,7 @@ func (intf *link) handler() error {
intf.peer = intf.links.core.peers._newPeer(&meta.box, &meta.sig, shared, intf) intf.peer = intf.links.core.peers._newPeer(&meta.box, &meta.sig, shared, intf)
}) })
if intf.peer == nil { if intf.peer == nil {
return errors.New("failed to create peer") return nil, errors.New("failed to create peer")
} }
defer func() { defer func() {
// More cleanup can go here // More cleanup can go here
@ -316,7 +316,7 @@ func (intf *link) handler() error {
intf.links.core.log.Infof("Disconnected %s: %s, source %s", intf.links.core.log.Infof("Disconnected %s: %s, source %s",
strings.ToUpper(intf.info.linkType), themString, intf.info.local) strings.ToUpper(intf.info.linkType), themString, intf.info.local)
} }
return err return nil, err
} }
//////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////

View File

@ -299,7 +299,9 @@ func (t *tcp) call(saddr string, options tcpOptions, sintf string) {
} }
t.waitgroup.Add(1) t.waitgroup.Add(1)
options.socksPeerAddr = conn.RemoteAddr().String() options.socksPeerAddr = conn.RemoteAddr().String()
t.handler(conn, false, options) if ch := t.handler(conn, false, options); ch != nil {
<-ch
}
} else { } else {
dst, err := net.ResolveTCPAddr("tcp", saddr) dst, err := net.ResolveTCPAddr("tcp", saddr)
if err != nil { if err != nil {
@ -365,12 +367,14 @@ func (t *tcp) call(saddr string, options tcpOptions, sintf string) {
return return
} }
t.waitgroup.Add(1) t.waitgroup.Add(1)
t.handler(conn, false, options) if ch := t.handler(conn, false, options); ch != nil {
<-ch
}
} }
}() }()
} }
func (t *tcp) handler(sock net.Conn, incoming bool, options tcpOptions) { func (t *tcp) handler(sock net.Conn, incoming bool, options tcpOptions) chan struct{} {
defer t.waitgroup.Done() // Happens after sock.close defer t.waitgroup.Done() // Happens after sock.close
defer sock.Close() defer sock.Close()
t.setExtraOptions(sock) t.setExtraOptions(sock)
@ -379,7 +383,7 @@ func (t *tcp) handler(sock net.Conn, incoming bool, options tcpOptions) {
var err error var err error
if sock, err = options.upgrade.upgrade(sock); err != nil { if sock, err = options.upgrade.upgrade(sock); err != nil {
t.links.core.log.Errorln("TCP handler upgrade failed:", err) t.links.core.log.Errorln("TCP handler upgrade failed:", err)
return return nil
} }
upgraded = true upgraded = true
} }
@ -415,7 +419,7 @@ func (t *tcp) handler(sock net.Conn, incoming bool, options tcpOptions) {
// Maybe dial/listen at the application level // Maybe dial/listen at the application level
// Then pass a net.Conn to the core library (after these kinds of checks are done) // Then pass a net.Conn to the core library (after these kinds of checks are done)
t.links.core.log.Debugln("Dropping ygg-tunneled connection", local, remote) t.links.core.log.Debugln("Dropping ygg-tunneled connection", local, remote)
return return nil
} }
} }
force := net.ParseIP(strings.Split(remote, "%")[0]).IsLinkLocalUnicast() force := net.ParseIP(strings.Split(remote, "%")[0]).IsLinkLocalUnicast()
@ -425,6 +429,7 @@ func (t *tcp) handler(sock net.Conn, incoming bool, options tcpOptions) {
panic(err) panic(err)
} }
t.links.core.log.Debugln("DEBUG: starting handler for", name) t.links.core.log.Debugln("DEBUG: starting handler for", name)
err = link.handler() ch, err := link.handler()
t.links.core.log.Debugln("DEBUG: stopped handler for", name, err) t.links.core.log.Debugln("DEBUG: stopped handler for", name, err)
return ch
} }