net/dnscache: don't do bootstrap DNS lookup after most failed dials

If we've already connected to a certain name's IP in the past, don't
assume the problem was DNS related. That just puts unnecessarily load
on our bootstrap DNS servers during regular restarts of Tailscale
infrastructure components.

Also, if we do do a bootstrap DNS lookup and it gives the same IP(s)
that we already tried, don't try them again.

Change-Id: I743e8991a7f957381b8e4c1508b8e9d0df1782fe
Signed-off-by: Brad Fitzpatrick <bradfitz@tailscale.com>
This commit is contained in:
Brad Fitzpatrick 2022-02-14 13:25:19 -08:00 committed by Brad Fitzpatrick
parent b4947be0c8
commit 8d6cf14456
2 changed files with 181 additions and 17 deletions

View File

@ -279,8 +279,9 @@ func (r *Resolver) addIPCache(host string, ip, ip6 net.IP, allIPs []net.IPAddr,
// Dialer returns a wrapped DialContext func that uses the provided dnsCache.
func Dialer(fwd DialContextFunc, dnsCache *Resolver) DialContextFunc {
d := &dialer{
fwd: fwd,
dnsCache: dnsCache,
fwd: fwd,
dnsCache: dnsCache,
pastConnect: map[netaddr.IP]time.Time{},
}
return d.DialContext
}
@ -289,6 +290,9 @@ func Dialer(fwd DialContextFunc, dnsCache *Resolver) DialContextFunc {
type dialer struct {
fwd DialContextFunc
dnsCache *Resolver
mu sync.Mutex
pastConnect map[netaddr.IP]time.Time
}
func (d *dialer) DialContext(ctx context.Context, network, address string) (retConn net.Conn, ret error) {
@ -306,8 +310,9 @@ func (d *dialer) DialContext(ctx context.Context, network, address string) (retC
port: port,
}
defer func() {
// On any failure, assume our DNS is wrong and try our fallback, if any.
if ret == nil || d.dnsCache.LookupIPFallback == nil {
// On failure, consider that our DNS might be wrong and ask the DNS fallback mechanism for
// some other IPs to try.
if ret == nil || d.dnsCache.LookupIPFallback == nil || dc.dnsWasTrustworthy() {
return
}
ips, err := d.dnsCache.LookupIPFallback(ctx, host)
@ -328,17 +333,23 @@ func (d *dialer) DialContext(ctx context.Context, network, address string) (retC
}
i4s := v4addrs(allIPs)
if len(i4s) < 2 {
dst := net.JoinHostPort(ip.String(), port)
if debug {
log.Printf("dnscache: dialing %s, %s for %s", network, dst, address)
log.Printf("dnscache: dialing %s, %s for %s", network, ip, address)
}
c, err := d.fwd(ctx, network, dst)
if err == nil || ctx.Err() != nil || ip6 == nil {
ipNA, ok := netaddr.FromStdIP(ip)
if !ok {
return nil, fmt.Errorf("invalid IP %q", ip)
}
c, err := dc.dialOne(ctx, ipNA)
if err == nil || ctx.Err() != nil {
return c, err
}
// Fall back to trying IPv6.
dst = net.JoinHostPort(ip6.String(), port)
return d.fwd(ctx, network, dst)
// Fall back to trying IPv6, if any.
ip6NA, ok := netaddr.FromStdIP(ip6)
if !ok {
return nil, err
}
return dc.dialOne(ctx, ip6NA)
}
// Multiple IPv4 candidates, and 0+ IPv6.
@ -350,6 +361,77 @@ func (d *dialer) DialContext(ctx context.Context, network, address string) (retC
type dialCall struct {
d *dialer
network, address, host, port string
mu sync.Mutex // lock ordering: dialer.mu, then dialCall.mu
fails map[netaddr.IP]error // set of IPs that failed to dial thus far
}
// dnsWasTrustworthy reports whether we think the IP address(es) we
// tried (and failed) to dial were probably the correct IPs. Currently
// the heuristic is whether they ever worked previously.
func (dc *dialCall) dnsWasTrustworthy() bool {
dc.d.mu.Lock()
defer dc.d.mu.Unlock()
dc.mu.Lock()
defer dc.mu.Unlock()
if len(dc.fails) == 0 {
// No information.
return false
}
// If any of the IPs we failed to dial worked previously in
// this dialer, assume the DNS is fine.
for ip := range dc.fails {
if _, ok := dc.d.pastConnect[ip]; ok {
return true
}
}
return false
}
func (dc *dialCall) dialOne(ctx context.Context, ip netaddr.IP) (net.Conn, error) {
c, err := dc.d.fwd(ctx, dc.network, net.JoinHostPort(ip.String(), dc.port))
dc.noteDialResult(ip, err)
return c, err
}
// noteDialResult records that a dial to ip either succeeded or
// failed.
func (dc *dialCall) noteDialResult(ip netaddr.IP, err error) {
if err == nil {
d := dc.d
d.mu.Lock()
defer d.mu.Unlock()
d.pastConnect[ip] = time.Now()
return
}
dc.mu.Lock()
defer dc.mu.Unlock()
if dc.fails == nil {
dc.fails = map[netaddr.IP]error{}
}
dc.fails[ip] = err
}
// uniqueIPs returns a possibly-mutated subslice of ips, filtering out
// dups and ones that have already failed previously.
func (dc *dialCall) uniqueIPs(ips []netaddr.IP) (ret []netaddr.IP) {
dc.mu.Lock()
defer dc.mu.Unlock()
seen := map[netaddr.IP]bool{}
ret = ips[:0]
for _, ip := range ips {
if seen[ip] {
continue
}
seen[ip] = true
if dc.fails[ip] != nil {
continue
}
ret = append(ret, ip)
}
return ret
}
// fallbackDelay is how long to wait between trying subsequent
@ -360,11 +442,6 @@ type dialCall struct {
// raceDial tries to dial port on each ip in ips, starting a new race
// dial every fallbackDelay apart, returning whichever completes first.
func (dc *dialCall) raceDial(ctx context.Context, ips []netaddr.IP) (net.Conn, error) {
var (
fwd = dc.d.fwd
network = dc.network
port = dc.port
)
ctx, cancel := context.WithCancel(ctx)
defer cancel()
@ -375,6 +452,14 @@ type res struct {
resc := make(chan res) // must be unbuffered
failBoost := make(chan struct{}) // best effort send on dial failure
// Remove IPs that we tried & failed to dial previously
// (such as when we're being called after a dnsfallback lookup and get
// the same results)
ips = dc.uniqueIPs(ips)
if len(ips) == 0 {
return nil, errors.New("no IPs")
}
go func() {
for i, ip := range ips {
if i != 0 {
@ -389,7 +474,7 @@ type res struct {
}
}
go func(ip netaddr.IP) {
c, err := fwd(ctx, network, net.JoinHostPort(ip.String(), port))
c, err := dc.dialOne(ctx, ip)
if err != nil {
// Best effort wake-up a pending dial.
// e.g. IPv4 dials failing quickly on an IPv6-only system.

View File

@ -6,10 +6,14 @@
import (
"context"
"errors"
"flag"
"net"
"reflect"
"testing"
"time"
"inet.af/netaddr"
)
var dialTest = flag.String("dial-test", "", "if non-empty, addr:port to test dial")
@ -31,3 +35,78 @@ func TestDialer(t *testing.T) {
t.Logf("dialed in %v", time.Since(t0))
c.Close()
}
func TestDialCall_DNSWasTrustworthy(t *testing.T) {
type step struct {
ip netaddr.IP // IP we pretended to dial
err error // the dial error or nil for success
}
mustIP := netaddr.MustParseIP
errFail := errors.New("some connect failure")
tests := []struct {
name string
steps []step
want bool
}{
{
name: "no-info",
want: false,
},
{
name: "previous-dial",
steps: []step{
{mustIP("2003::1"), nil},
{mustIP("2003::1"), errFail},
},
want: true,
},
{
name: "no-previous-dial",
steps: []step{
{mustIP("2003::1"), errFail},
},
want: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
d := &dialer{
pastConnect: map[netaddr.IP]time.Time{},
}
dc := &dialCall{
d: d,
}
for _, st := range tt.steps {
dc.noteDialResult(st.ip, st.err)
}
got := dc.dnsWasTrustworthy()
if got != tt.want {
t.Errorf("got %v; want %v", got, tt.want)
}
})
}
}
func TestDialCall_uniqueIPs(t *testing.T) {
dc := &dialCall{}
mustIP := netaddr.MustParseIP
errFail := errors.New("some connect failure")
dc.noteDialResult(mustIP("2003::1"), errFail)
dc.noteDialResult(mustIP("2003::2"), errFail)
got := dc.uniqueIPs([]netaddr.IP{
mustIP("2003::1"),
mustIP("2003::2"),
mustIP("2003::2"),
mustIP("2003::3"),
mustIP("2003::3"),
mustIP("2003::4"),
mustIP("2003::4"),
})
want := []netaddr.IP{
mustIP("2003::3"),
mustIP("2003::4"),
}
if !reflect.DeepEqual(got, want) {
t.Errorf("got %v; want %v", got, want)
}
}