mirror of
https://github.com/tailscale/tailscale.git
synced 2025-04-24 09:51:41 +00:00
net/dnscache: make Dialer try all resolved IPs
Tested manually with: $ go test -v ./net/dnscache/ -dial-test=bogusplane.dev.tailscale.com:80 Where bogusplane has three A records, only one of which works. Signed-off-by: Brad Fitzpatrick <bradfitz@tailscale.com>
This commit is contained in:
parent
dfa5e38fad
commit
281d503626
@ -314,20 +314,19 @@ func Dialer(fwd DialContextFunc, dnsCache *Resolver) DialContextFunc {
|
|||||||
// Return with original error
|
// Return with original error
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
for _, ip := range ips {
|
if c, err := raceDial(ctx, fwd, network, ips, port); err == nil {
|
||||||
dst := net.JoinHostPort(ip.String(), port)
|
|
||||||
if c, err := fwd(ctx, network, dst); err == nil {
|
|
||||||
retConn = c
|
retConn = c
|
||||||
ret = nil
|
ret = nil
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
|
||||||
}()
|
}()
|
||||||
|
|
||||||
ip, ip6, _, err := dnsCache.LookupIP(ctx, host)
|
ip, ip6, allIPs, err := dnsCache.LookupIP(ctx, host)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to resolve %q: %w", host, err)
|
return nil, fmt.Errorf("failed to resolve %q: %w", host, err)
|
||||||
}
|
}
|
||||||
|
i4s := v4addrs(allIPs)
|
||||||
|
if len(i4s) < 2 {
|
||||||
dst := net.JoinHostPort(ip.String(), port)
|
dst := net.JoinHostPort(ip.String(), port)
|
||||||
if debug {
|
if debug {
|
||||||
log.Printf("dnscache: dialing %s, %s for %s", network, dst, address)
|
log.Printf("dnscache: dialing %s, %s for %s", network, dst, address)
|
||||||
@ -337,16 +336,107 @@ func Dialer(fwd DialContextFunc, dnsCache *Resolver) DialContextFunc {
|
|||||||
return c, err
|
return c, err
|
||||||
}
|
}
|
||||||
// Fall back to trying IPv6.
|
// Fall back to trying IPv6.
|
||||||
// TODO(bradfitz): this is a primarily for IPv6-only
|
|
||||||
// hosts; it's not supposed to be a real Happy
|
|
||||||
// Eyeballs implementation. We should use the net
|
|
||||||
// package's implementation of that by plumbing this
|
|
||||||
// dnscache impl into net.Dialer.Resolver.Dial and
|
|
||||||
// unmarshal/marshal DNS queries/responses to the net
|
|
||||||
// package. This works for v6-only hosts for now.
|
|
||||||
dst = net.JoinHostPort(ip6.String(), port)
|
dst = net.JoinHostPort(ip6.String(), port)
|
||||||
return fwd(ctx, network, dst)
|
return fwd(ctx, network, dst)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Multiple IPv4 candidates, and 0+ IPv6.
|
||||||
|
ipsToTry := append(i4s, v6addrs(allIPs)...)
|
||||||
|
return raceDial(ctx, fwd, network, ipsToTry, port)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// fallbackDelay is how long to wait between trying subsequent
|
||||||
|
// addresses when multiple options are available.
|
||||||
|
// 300ms is the same as Go's Happy Eyeballs fallbackDelay value.
|
||||||
|
const fallbackDelay = 300 * time.Millisecond
|
||||||
|
|
||||||
|
// raceDial tries to dial port on each ip in ips, starting a new race
|
||||||
|
// dial every 300ms apart, returning whichever completes first.
|
||||||
|
func raceDial(ctx context.Context, fwd DialContextFunc, network string, ips []netaddr.IP, port string) (net.Conn, error) {
|
||||||
|
ctx, cancel := context.WithCancel(ctx)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
type res struct {
|
||||||
|
c net.Conn
|
||||||
|
err error
|
||||||
|
}
|
||||||
|
resc := make(chan res) // must be unbuffered
|
||||||
|
failBoost := make(chan struct{}) // best effort send on dial failure
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
for i, ip := range ips {
|
||||||
|
if i != 0 {
|
||||||
|
timer := time.NewTimer(fallbackDelay)
|
||||||
|
select {
|
||||||
|
case <-timer.C:
|
||||||
|
case <-failBoost:
|
||||||
|
timer.Stop()
|
||||||
|
case <-ctx.Done():
|
||||||
|
timer.Stop()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
go func(ip netaddr.IP) {
|
||||||
|
c, err := fwd(ctx, network, net.JoinHostPort(ip.String(), port))
|
||||||
|
if err != nil {
|
||||||
|
// Best effort wake-up a pending dial.
|
||||||
|
// e.g. IPv4 dials failing quickly on an IPv6-only system.
|
||||||
|
// In that case we don't want to wait 300ms per IPv4 before
|
||||||
|
// we get to the IPv6 addresses.
|
||||||
|
select {
|
||||||
|
case failBoost <- struct{}{}:
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
}
|
||||||
|
select {
|
||||||
|
case resc <- res{c, err}:
|
||||||
|
case <-ctx.Done():
|
||||||
|
if c != nil {
|
||||||
|
c.Close()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}(ip)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
var firstErr error
|
||||||
|
var fails int
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case r := <-resc:
|
||||||
|
if r.c != nil {
|
||||||
|
return r.c, nil
|
||||||
|
}
|
||||||
|
fails++
|
||||||
|
if firstErr == nil {
|
||||||
|
firstErr = r.err
|
||||||
|
}
|
||||||
|
if fails == len(ips) {
|
||||||
|
return nil, firstErr
|
||||||
|
}
|
||||||
|
case <-ctx.Done():
|
||||||
|
return nil, ctx.Err()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func v4addrs(aa []net.IPAddr) (ret []netaddr.IP) {
|
||||||
|
for _, a := range aa {
|
||||||
|
if ip, ok := netaddr.FromStdIP(a.IP); ok && ip.Is4() {
|
||||||
|
ret = append(ret, ip)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return ret
|
||||||
|
}
|
||||||
|
|
||||||
|
func v6addrs(aa []net.IPAddr) (ret []netaddr.IP) {
|
||||||
|
for _, a := range aa {
|
||||||
|
if ip, ok := netaddr.FromStdIP(a.IP); ok && ip.Is6() {
|
||||||
|
ret = append(ret, ip)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return ret
|
||||||
}
|
}
|
||||||
|
|
||||||
var errTLSHandshakeTimeout = errors.New("timeout doing TLS handshake")
|
var errTLSHandshakeTimeout = errors.New("timeout doing TLS handshake")
|
||||||
|
@ -5,10 +5,15 @@
|
|||||||
package dnscache
|
package dnscache
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
|
"flag"
|
||||||
"net"
|
"net"
|
||||||
"testing"
|
"testing"
|
||||||
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
var dialTest = flag.String("dial-test", "", "if non-empty, addr:port to test dial")
|
||||||
|
|
||||||
func TestIsPrivateIP(t *testing.T) {
|
func TestIsPrivateIP(t *testing.T) {
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
ip string
|
ip string
|
||||||
@ -26,3 +31,21 @@ func TestIsPrivateIP(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestDialer(t *testing.T) {
|
||||||
|
if *dialTest == "" {
|
||||||
|
t.Skip("skipping; --dial-test is blank")
|
||||||
|
}
|
||||||
|
r := new(Resolver)
|
||||||
|
var std net.Dialer
|
||||||
|
dialer := Dialer(std.DialContext, r)
|
||||||
|
t0 := time.Now()
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
c, err := dialer(ctx, "tcp", *dialTest)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
t.Logf("dialed in %v", time.Since(t0))
|
||||||
|
c.Close()
|
||||||
|
}
|
||||||
|
Loading…
x
Reference in New Issue
Block a user