From 35ea66d651d3f186d8c33ede4568fd28a18bdc58 Mon Sep 17 00:00:00 2001 From: Neil Alexander Date: Sat, 22 Oct 2022 17:45:09 +0100 Subject: [PATCH] Varying connection check strictness based on scope --- src/core/link_tcp.go | 30 ++++++++++++++++++++++-------- src/core/link_tls.go | 17 +++++++++-------- 2 files changed, 31 insertions(+), 16 deletions(-) diff --git a/src/core/link_tcp.go b/src/core/link_tcp.go index ee0dd001..a8f437e9 100644 --- a/src/core/link_tcp.go +++ b/src/core/link_tcp.go @@ -35,14 +35,14 @@ func (l *linkTCP) dial(url *url.URL, options linkOptions, sintf string) error { if err != nil { return err } - info := linkInfoFor("tcp", sintf, addr.String()) - if l.links.isConnectedTo(info) { - return nil - } dialer, err := l.dialerFor(addr, sintf) if err != nil { return err } + info := linkInfoFor("tcp", sintf, tcpIDFor(dialer.LocalAddr, addr)) + if l.links.isConnectedTo(info) { + return nil + } conn, err := dialer.DialContext(l.core.ctx, "tcp", addr.String()) if err != nil { return err @@ -82,10 +82,11 @@ func (l *linkTCP) listen(url *url.URL, sintf string) (*Listener, error) { cancel() break } - addr := conn.RemoteAddr().(*net.TCPAddr) - name := fmt.Sprintf("tcp://%s", addr) - info := linkInfoFor("tcp", sintf, addr.String()) - if err = l.handler(name, info, conn, linkOptions{}, true, addr.IP.IsLinkLocalUnicast()); err != nil { + laddr := conn.LocalAddr().(*net.TCPAddr) + raddr := conn.RemoteAddr().(*net.TCPAddr) + name := fmt.Sprintf("tcp://%s", raddr) + info := linkInfoFor("tcp", sintf, tcpIDFor(laddr, raddr)) + if err = l.handler(name, info, conn, linkOptions{}, true, raddr.IP.IsLinkLocalUnicast()); err != nil { l.core.log.Errorln("Failed to create inbound link:", err) } } @@ -179,3 +180,16 @@ func (l *linkTCP) dialerFor(dst *net.TCPAddr, sintf string) (*net.Dialer, error) } return dialer, nil } + +func tcpIDFor(local net.Addr, remoteAddr *net.TCPAddr) string { + if localAddr, ok := local.(*net.TCPAddr); ok && localAddr.IP.Equal(remoteAddr.IP) { + // Nodes running on the same host — include both the IP and port. + return remoteAddr.String() + } + if remoteAddr.IP.IsLinkLocalUnicast() { + // Nodes discovered via multicast — include the IP only. + return remoteAddr.IP.String() + } + // Nodes connected remotely — include both the IP and port. + return remoteAddr.String() +} diff --git a/src/core/link_tls.go b/src/core/link_tls.go index ee3363ec..3af8fe2b 100644 --- a/src/core/link_tls.go +++ b/src/core/link_tls.go @@ -51,14 +51,14 @@ func (l *linkTLS) dial(url *url.URL, options linkOptions, sintf, sni string) err if err != nil { return err } - info := linkInfoFor("tls", sintf, addr.String()) - if l.links.isConnectedTo(info) { - return nil - } dialer, err := l.tcp.dialerFor(addr, sintf) if err != nil { return err } + info := linkInfoFor("tls", sintf, tcpIDFor(dialer.LocalAddr, addr)) + if l.links.isConnectedTo(info) { + return nil + } tlsconfig := l.config.Clone() tlsconfig.ServerName = sni tlsdialer := &tls.Dialer{ @@ -105,10 +105,11 @@ func (l *linkTLS) listen(url *url.URL, sintf string) (*Listener, error) { cancel() break } - addr := conn.RemoteAddr().(*net.TCPAddr) - name := fmt.Sprintf("tls://%s", addr) - info := linkInfoFor("tls", sintf, addr.String()) - if err = l.handler(name, info, conn, linkOptions{}, true, addr.IP.IsLinkLocalUnicast()); err != nil { + laddr := conn.LocalAddr().(*net.TCPAddr) + raddr := conn.RemoteAddr().(*net.TCPAddr) + name := fmt.Sprintf("tls://%s", raddr) + info := linkInfoFor("tls", sintf, tcpIDFor(laddr, raddr)) + if err = l.handler(name, info, conn, linkOptions{}, true, raddr.IP.IsLinkLocalUnicast()); err != nil { l.core.log.Errorln("Failed to create inbound link:", err) } }