Try all addresses when connecting to a DNS name

Fixes #980
This commit is contained in:
Neil Alexander 2022-11-08 21:59:13 +00:00
parent 6112c9cf18
commit 110613b234
2 changed files with 64 additions and 47 deletions

View File

@ -5,6 +5,7 @@ import (
"fmt" "fmt"
"net" "net"
"net/url" "net/url"
"strconv"
"strings" "strings"
"time" "time"
@ -31,24 +32,39 @@ func (l *links) newLinkTCP() *linkTCP {
} }
func (l *linkTCP) dial(url *url.URL, options linkOptions, sintf string) error { func (l *linkTCP) dial(url *url.URL, options linkOptions, sintf string) error {
addr, err := net.ResolveTCPAddr("tcp", url.Host) info := linkInfoFor("tcp", sintf, url.Host)
if err != nil {
return err
}
dialer, err := l.dialerFor(addr, sintf)
if err != nil {
return err
}
info := linkInfoFor("tcp", sintf, tcpIDFor(dialer.LocalAddr, addr))
if l.links.isConnectedTo(info) { if l.links.isConnectedTo(info) {
return nil return nil
} }
conn, err := dialer.DialContext(l.core.ctx, "tcp", addr.String()) host, p, err := net.SplitHostPort(url.Host)
if err != nil { if err != nil {
return err return err
} }
port, err := strconv.Atoi(p)
if err != nil {
return err
}
ips, err := net.LookupIP(host)
if err != nil {
return err
}
for _, ip := range ips {
addr := &net.TCPAddr{
IP: ip,
Port: port,
}
dialer, err := l.dialerFor(addr, sintf)
if err != nil {
continue
}
conn, err := dialer.DialContext(l.core.ctx, "tcp", addr.String())
if err != nil {
continue
}
uri := strings.TrimRight(strings.SplitN(url.String(), "?", 2)[0], "/") uri := strings.TrimRight(strings.SplitN(url.String(), "?", 2)[0], "/")
return l.handler(uri, info, conn, options, false, false) return l.handler(uri, info, conn, options, false, false)
}
return fmt.Errorf("failed to connect via %d addresses", len(ips))
} }
func (l *linkTCP) listen(url *url.URL, sintf string) (*Listener, error) { func (l *linkTCP) listen(url *url.URL, sintf string) (*Listener, error) {
@ -82,10 +98,9 @@ func (l *linkTCP) listen(url *url.URL, sintf string) (*Listener, error) {
cancel() cancel()
break break
} }
laddr := conn.LocalAddr().(*net.TCPAddr)
raddr := conn.RemoteAddr().(*net.TCPAddr) raddr := conn.RemoteAddr().(*net.TCPAddr)
name := fmt.Sprintf("tcp://%s", raddr) name := fmt.Sprintf("tcp://%s", raddr)
info := linkInfoFor("tcp", sintf, tcpIDFor(laddr, raddr)) info := linkInfoFor("tcp", sintf, raddr.String())
if err = l.handler(name, info, conn, linkOptionsForListener(url), true, raddr.IP.IsLinkLocalUnicast()); err != nil { if err = l.handler(name, info, conn, linkOptionsForListener(url), true, raddr.IP.IsLinkLocalUnicast()); err != nil {
l.core.log.Errorln("Failed to create inbound link:", err) l.core.log.Errorln("Failed to create inbound link:", err)
} }
@ -180,16 +195,3 @@ func (l *linkTCP) dialerFor(dst *net.TCPAddr, sintf string) (*net.Dialer, error)
} }
return dialer, nil return dialer, nil
} }
func tcpIDFor(local net.Addr, remoteAddr *net.TCPAddr) string {
if localAddr, ok := local.(*net.TCPAddr); ok && localAddr.IP.Equal(remoteAddr.IP) {
// Nodes running on the same host — include both the IP and port.
return remoteAddr.String()
}
if remoteAddr.IP.IsLinkLocalUnicast() {
// Nodes discovered via multicast — include the IP only.
return remoteAddr.IP.String()
}
// Nodes connected remotely — include both the IP and port.
return remoteAddr.String()
}

View File

@ -13,6 +13,7 @@ import (
"math/big" "math/big"
"net" "net"
"net/url" "net/url"
"strconv"
"strings" "strings"
"time" "time"
@ -47,17 +48,30 @@ func (l *links) newLinkTLS(tcp *linkTCP) *linkTLS {
} }
func (l *linkTLS) dial(url *url.URL, options linkOptions, sintf, sni string) error { func (l *linkTLS) dial(url *url.URL, options linkOptions, sintf, sni string) error {
addr, err := net.ResolveTCPAddr("tcp", url.Host) info := linkInfoFor("tls", sintf, url.Host)
if l.links.isConnectedTo(info) {
return nil
}
host, p, err := net.SplitHostPort(url.Host)
if err != nil { if err != nil {
return err return err
} }
port, err := strconv.Atoi(p)
if err != nil {
return err
}
ips, err := net.LookupIP(host)
if err != nil {
return err
}
for _, ip := range ips {
addr := &net.TCPAddr{
IP: ip,
Port: port,
}
dialer, err := l.tcp.dialerFor(addr, sintf) dialer, err := l.tcp.dialerFor(addr, sintf)
if err != nil { if err != nil {
return err continue
}
info := linkInfoFor("tls", sintf, tcpIDFor(dialer.LocalAddr, addr))
if l.links.isConnectedTo(info) {
return nil
} }
tlsconfig := l.config.Clone() tlsconfig := l.config.Clone()
tlsconfig.ServerName = sni tlsconfig.ServerName = sni
@ -67,10 +81,12 @@ func (l *linkTLS) dial(url *url.URL, options linkOptions, sintf, sni string) err
} }
conn, err := tlsdialer.DialContext(l.core.ctx, "tcp", addr.String()) conn, err := tlsdialer.DialContext(l.core.ctx, "tcp", addr.String())
if err != nil { if err != nil {
return err continue
} }
uri := strings.TrimRight(strings.SplitN(url.String(), "?", 2)[0], "/") uri := strings.TrimRight(strings.SplitN(url.String(), "?", 2)[0], "/")
return l.handler(uri, info, conn, options, false, false) return l.handler(uri, info, conn, options, false, false)
}
return fmt.Errorf("failed to connect via %d addresses", len(ips))
} }
func (l *linkTLS) listen(url *url.URL, sintf string) (*Listener, error) { func (l *linkTLS) listen(url *url.URL, sintf string) (*Listener, error) {
@ -105,10 +121,9 @@ func (l *linkTLS) listen(url *url.URL, sintf string) (*Listener, error) {
cancel() cancel()
break break
} }
laddr := conn.LocalAddr().(*net.TCPAddr)
raddr := conn.RemoteAddr().(*net.TCPAddr) raddr := conn.RemoteAddr().(*net.TCPAddr)
name := fmt.Sprintf("tls://%s", raddr) name := fmt.Sprintf("tls://%s", raddr)
info := linkInfoFor("tls", sintf, tcpIDFor(laddr, raddr)) info := linkInfoFor("tls", sintf, raddr.String())
if err = l.handler(name, info, conn, linkOptionsForListener(url), true, raddr.IP.IsLinkLocalUnicast()); err != nil { if err = l.handler(name, info, conn, linkOptionsForListener(url), true, raddr.IP.IsLinkLocalUnicast()); err != nil {
l.core.log.Errorln("Failed to create inbound link:", err) l.core.log.Errorln("Failed to create inbound link:", err)
} }