mirror of
https://github.com/tailscale/tailscale.git
synced 2024-11-29 04:55:31 +00:00
net/dnscache: make Dialer try all resolved IPs
Tested manually with: $ go test -v ./net/dnscache/ -dial-test=bogusplane.dev.tailscale.com:80 Where bogusplane has three A records, only one of which works. Signed-off-by: Brad Fitzpatrick <bradfitz@tailscale.com>
This commit is contained in:
parent
dfa5e38fad
commit
281d503626
@ -314,20 +314,19 @@ func Dialer(fwd DialContextFunc, dnsCache *Resolver) DialContextFunc {
|
||||
// Return with original error
|
||||
return
|
||||
}
|
||||
for _, ip := range ips {
|
||||
dst := net.JoinHostPort(ip.String(), port)
|
||||
if c, err := fwd(ctx, network, dst); err == nil {
|
||||
if c, err := raceDial(ctx, fwd, network, ips, port); err == nil {
|
||||
retConn = c
|
||||
ret = nil
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
ip, ip6, _, err := dnsCache.LookupIP(ctx, host)
|
||||
ip, ip6, allIPs, err := dnsCache.LookupIP(ctx, host)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to resolve %q: %w", host, err)
|
||||
}
|
||||
i4s := v4addrs(allIPs)
|
||||
if len(i4s) < 2 {
|
||||
dst := net.JoinHostPort(ip.String(), port)
|
||||
if debug {
|
||||
log.Printf("dnscache: dialing %s, %s for %s", network, dst, address)
|
||||
@ -337,16 +336,107 @@ func Dialer(fwd DialContextFunc, dnsCache *Resolver) DialContextFunc {
|
||||
return c, err
|
||||
}
|
||||
// Fall back to trying IPv6.
|
||||
// TODO(bradfitz): this is a primarily for IPv6-only
|
||||
// hosts; it's not supposed to be a real Happy
|
||||
// Eyeballs implementation. We should use the net
|
||||
// package's implementation of that by plumbing this
|
||||
// dnscache impl into net.Dialer.Resolver.Dial and
|
||||
// unmarshal/marshal DNS queries/responses to the net
|
||||
// package. This works for v6-only hosts for now.
|
||||
dst = net.JoinHostPort(ip6.String(), port)
|
||||
return fwd(ctx, network, dst)
|
||||
}
|
||||
|
||||
// Multiple IPv4 candidates, and 0+ IPv6.
|
||||
ipsToTry := append(i4s, v6addrs(allIPs)...)
|
||||
return raceDial(ctx, fwd, network, ipsToTry, port)
|
||||
}
|
||||
}
|
||||
|
||||
// fallbackDelay is how long to wait between trying subsequent
|
||||
// addresses when multiple options are available.
|
||||
// 300ms is the same as Go's Happy Eyeballs fallbackDelay value.
|
||||
const fallbackDelay = 300 * time.Millisecond
|
||||
|
||||
// raceDial tries to dial port on each ip in ips, starting a new race
|
||||
// dial every 300ms apart, returning whichever completes first.
|
||||
func raceDial(ctx context.Context, fwd DialContextFunc, network string, ips []netaddr.IP, port string) (net.Conn, error) {
|
||||
ctx, cancel := context.WithCancel(ctx)
|
||||
defer cancel()
|
||||
|
||||
type res struct {
|
||||
c net.Conn
|
||||
err error
|
||||
}
|
||||
resc := make(chan res) // must be unbuffered
|
||||
failBoost := make(chan struct{}) // best effort send on dial failure
|
||||
|
||||
go func() {
|
||||
for i, ip := range ips {
|
||||
if i != 0 {
|
||||
timer := time.NewTimer(fallbackDelay)
|
||||
select {
|
||||
case <-timer.C:
|
||||
case <-failBoost:
|
||||
timer.Stop()
|
||||
case <-ctx.Done():
|
||||
timer.Stop()
|
||||
return
|
||||
}
|
||||
}
|
||||
go func(ip netaddr.IP) {
|
||||
c, err := fwd(ctx, network, net.JoinHostPort(ip.String(), port))
|
||||
if err != nil {
|
||||
// Best effort wake-up a pending dial.
|
||||
// e.g. IPv4 dials failing quickly on an IPv6-only system.
|
||||
// In that case we don't want to wait 300ms per IPv4 before
|
||||
// we get to the IPv6 addresses.
|
||||
select {
|
||||
case failBoost <- struct{}{}:
|
||||
default:
|
||||
}
|
||||
}
|
||||
select {
|
||||
case resc <- res{c, err}:
|
||||
case <-ctx.Done():
|
||||
if c != nil {
|
||||
c.Close()
|
||||
}
|
||||
}
|
||||
}(ip)
|
||||
}
|
||||
}()
|
||||
|
||||
var firstErr error
|
||||
var fails int
|
||||
for {
|
||||
select {
|
||||
case r := <-resc:
|
||||
if r.c != nil {
|
||||
return r.c, nil
|
||||
}
|
||||
fails++
|
||||
if firstErr == nil {
|
||||
firstErr = r.err
|
||||
}
|
||||
if fails == len(ips) {
|
||||
return nil, firstErr
|
||||
}
|
||||
case <-ctx.Done():
|
||||
return nil, ctx.Err()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func v4addrs(aa []net.IPAddr) (ret []netaddr.IP) {
|
||||
for _, a := range aa {
|
||||
if ip, ok := netaddr.FromStdIP(a.IP); ok && ip.Is4() {
|
||||
ret = append(ret, ip)
|
||||
}
|
||||
}
|
||||
return ret
|
||||
}
|
||||
|
||||
func v6addrs(aa []net.IPAddr) (ret []netaddr.IP) {
|
||||
for _, a := range aa {
|
||||
if ip, ok := netaddr.FromStdIP(a.IP); ok && ip.Is6() {
|
||||
ret = append(ret, ip)
|
||||
}
|
||||
}
|
||||
return ret
|
||||
}
|
||||
|
||||
var errTLSHandshakeTimeout = errors.New("timeout doing TLS handshake")
|
||||
|
@ -5,10 +5,15 @@
|
||||
package dnscache
|
||||
|
||||
import (
|
||||
"context"
|
||||
"flag"
|
||||
"net"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
var dialTest = flag.String("dial-test", "", "if non-empty, addr:port to test dial")
|
||||
|
||||
func TestIsPrivateIP(t *testing.T) {
|
||||
tests := []struct {
|
||||
ip string
|
||||
@ -26,3 +31,21 @@ func TestIsPrivateIP(t *testing.T) {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestDialer(t *testing.T) {
|
||||
if *dialTest == "" {
|
||||
t.Skip("skipping; --dial-test is blank")
|
||||
}
|
||||
r := new(Resolver)
|
||||
var std net.Dialer
|
||||
dialer := Dialer(std.DialContext, r)
|
||||
t0 := time.Now()
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancel()
|
||||
c, err := dialer(ctx, "tcp", *dialTest)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
t.Logf("dialed in %v", time.Since(t0))
|
||||
c.Close()
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user