From 8d6cf14456c5393ba743620c27b9189381ba4736 Mon Sep 17 00:00:00 2001 From: Brad Fitzpatrick Date: Mon, 14 Feb 2022 13:25:19 -0800 Subject: [PATCH] net/dnscache: don't do bootstrap DNS lookup after most failed dials If we've already connected to a certain name's IP in the past, don't assume the problem was DNS related. That just puts unnecessary load on our bootstrap DNS servers during regular restarts of Tailscale infrastructure components. Also, if we do do a bootstrap DNS lookup and it gives the same IP(s) that we already tried, don't try them again. Change-Id: I743e8991a7f957381b8e4c1508b8e9d0df1782fe Signed-off-by: Brad Fitzpatrick --- net/dnscache/dnscache.go | 119 +++++++++++++++++++++++++++++----- net/dnscache/dnscache_test.go | 79 ++++++++++++++++++++++ 2 files changed, 181 insertions(+), 17 deletions(-) diff --git a/net/dnscache/dnscache.go b/net/dnscache/dnscache.go index 835158de7..cf1ef5817 100644 --- a/net/dnscache/dnscache.go +++ b/net/dnscache/dnscache.go @@ -279,8 +279,9 @@ func (r *Resolver) addIPCache(host string, ip, ip6 net.IP, allIPs []net.IPAddr, // Dialer returns a wrapped DialContext func that uses the provided dnsCache. func Dialer(fwd DialContextFunc, dnsCache *Resolver) DialContextFunc { d := &dialer{ - fwd: fwd, - dnsCache: dnsCache, + fwd: fwd, + dnsCache: dnsCache, + pastConnect: map[netaddr.IP]time.Time{}, } return d.DialContext } @@ -289,6 +290,9 @@ func Dialer(fwd DialContextFunc, dnsCache *Resolver) DialContextFunc { type dialer struct { fwd DialContextFunc dnsCache *Resolver + + mu sync.Mutex + pastConnect map[netaddr.IP]time.Time } func (d *dialer) DialContext(ctx context.Context, network, address string) (retConn net.Conn, ret error) { @@ -306,8 +310,9 @@ func (d *dialer) DialContext(ctx context.Context, network, address string) (retC port: port, } defer func() { - // On any failure, assume our DNS is wrong and try our fallback, if any. 
- if ret == nil || d.dnsCache.LookupIPFallback == nil { + // On failure, consider that our DNS might be wrong and ask the DNS fallback mechanism for + // some other IPs to try. + if ret == nil || d.dnsCache.LookupIPFallback == nil || dc.dnsWasTrustworthy() { return } ips, err := d.dnsCache.LookupIPFallback(ctx, host) @@ -328,17 +333,23 @@ func (d *dialer) DialContext(ctx context.Context, network, address string) (retC } i4s := v4addrs(allIPs) if len(i4s) < 2 { - dst := net.JoinHostPort(ip.String(), port) if debug { - log.Printf("dnscache: dialing %s, %s for %s", network, dst, address) + log.Printf("dnscache: dialing %s, %s for %s", network, ip, address) } - c, err := d.fwd(ctx, network, dst) - if err == nil || ctx.Err() != nil || ip6 == nil { + ipNA, ok := netaddr.FromStdIP(ip) + if !ok { + return nil, fmt.Errorf("invalid IP %q", ip) + } + c, err := dc.dialOne(ctx, ipNA) + if err == nil || ctx.Err() != nil { return c, err } - // Fall back to trying IPv6. - dst = net.JoinHostPort(ip6.String(), port) - return d.fwd(ctx, network, dst) + // Fall back to trying IPv6, if any. + ip6NA, ok := netaddr.FromStdIP(ip6) + if !ok { + return nil, err + } + return dc.dialOne(ctx, ip6NA) } // Multiple IPv4 candidates, and 0+ IPv6. @@ -350,6 +361,77 @@ func (d *dialer) DialContext(ctx context.Context, network, address string) (retC type dialCall struct { d *dialer network, address, host, port string + + mu sync.Mutex // lock ordering: dialer.mu, then dialCall.mu + fails map[netaddr.IP]error // set of IPs that failed to dial thus far +} + +// dnsWasTrustworthy reports whether we think the IP address(es) we +// tried (and failed) to dial were probably the correct IPs. Currently +// the heuristic is whether they ever worked previously. +func (dc *dialCall) dnsWasTrustworthy() bool { + dc.d.mu.Lock() + defer dc.d.mu.Unlock() + dc.mu.Lock() + defer dc.mu.Unlock() + + if len(dc.fails) == 0 { + // No information. 
+ return false + } + + // If any of the IPs we failed to dial worked previously in + // this dialer, assume the DNS is fine. + for ip := range dc.fails { + if _, ok := dc.d.pastConnect[ip]; ok { + return true + } + } + return false +} + +func (dc *dialCall) dialOne(ctx context.Context, ip netaddr.IP) (net.Conn, error) { + c, err := dc.d.fwd(ctx, dc.network, net.JoinHostPort(ip.String(), dc.port)) + dc.noteDialResult(ip, err) + return c, err +} + +// noteDialResult records that a dial to ip either succeeded or +// failed. +func (dc *dialCall) noteDialResult(ip netaddr.IP, err error) { + if err == nil { + d := dc.d + d.mu.Lock() + defer d.mu.Unlock() + d.pastConnect[ip] = time.Now() + return + } + dc.mu.Lock() + defer dc.mu.Unlock() + if dc.fails == nil { + dc.fails = map[netaddr.IP]error{} + } + dc.fails[ip] = err +} + +// uniqueIPs returns a possibly-mutated subslice of ips, filtering out +// dups and ones that have already failed previously. +func (dc *dialCall) uniqueIPs(ips []netaddr.IP) (ret []netaddr.IP) { + dc.mu.Lock() + defer dc.mu.Unlock() + seen := map[netaddr.IP]bool{} + ret = ips[:0] + for _, ip := range ips { + if seen[ip] { + continue + } + seen[ip] = true + if dc.fails[ip] != nil { + continue + } + ret = append(ret, ip) + } + return ret } // fallbackDelay is how long to wait between trying subsequent @@ -360,11 +442,6 @@ type dialCall struct { // raceDial tries to dial port on each ip in ips, starting a new race // dial every fallbackDelay apart, returning whichever completes first. 
func (dc *dialCall) raceDial(ctx context.Context, ips []netaddr.IP) (net.Conn, error) { - var ( - fwd = dc.d.fwd - network = dc.network - port = dc.port - ) ctx, cancel := context.WithCancel(ctx) defer cancel() @@ -375,6 +452,14 @@ type res struct { resc := make(chan res) // must be unbuffered failBoost := make(chan struct{}) // best effort send on dial failure + // Remove IPs that we tried & failed to dial previously + // (such as when we're being called after a dnsfallback lookup and get + // the same results) + ips = dc.uniqueIPs(ips) + if len(ips) == 0 { + return nil, errors.New("no IPs") + } + go func() { for i, ip := range ips { if i != 0 { @@ -389,7 +474,7 @@ type res struct { } } go func(ip netaddr.IP) { - c, err := fwd(ctx, network, net.JoinHostPort(ip.String(), port)) + c, err := dc.dialOne(ctx, ip) if err != nil { // Best effort wake-up a pending dial. // e.g. IPv4 dials failing quickly on an IPv6-only system. diff --git a/net/dnscache/dnscache_test.go b/net/dnscache/dnscache_test.go index 096049ccf..10cfd5398 100644 --- a/net/dnscache/dnscache_test.go +++ b/net/dnscache/dnscache_test.go @@ -6,10 +6,14 @@ import ( "context" + "errors" "flag" "net" + "reflect" "testing" "time" + + "inet.af/netaddr" ) var dialTest = flag.String("dial-test", "", "if non-empty, addr:port to test dial") @@ -31,3 +35,78 @@ func TestDialer(t *testing.T) { t.Logf("dialed in %v", time.Since(t0)) c.Close() } + +func TestDialCall_DNSWasTrustworthy(t *testing.T) { + type step struct { + ip netaddr.IP // IP we pretended to dial + err error // the dial error or nil for success + } + mustIP := netaddr.MustParseIP + errFail := errors.New("some connect failure") + tests := []struct { + name string + steps []step + want bool + }{ + { + name: "no-info", + want: false, + }, + { + name: "previous-dial", + steps: []step{ + {mustIP("2003::1"), nil}, + {mustIP("2003::1"), errFail}, + }, + want: true, + }, + { + name: "no-previous-dial", + steps: []step{ + {mustIP("2003::1"), errFail}, + }, + 
want: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + d := &dialer{ + pastConnect: map[netaddr.IP]time.Time{}, + } + dc := &dialCall{ + d: d, + } + for _, st := range tt.steps { + dc.noteDialResult(st.ip, st.err) + } + got := dc.dnsWasTrustworthy() + if got != tt.want { + t.Errorf("got %v; want %v", got, tt.want) + } + }) + } +} + +func TestDialCall_uniqueIPs(t *testing.T) { + dc := &dialCall{} + mustIP := netaddr.MustParseIP + errFail := errors.New("some connect failure") + dc.noteDialResult(mustIP("2003::1"), errFail) + dc.noteDialResult(mustIP("2003::2"), errFail) + got := dc.uniqueIPs([]netaddr.IP{ + mustIP("2003::1"), + mustIP("2003::2"), + mustIP("2003::2"), + mustIP("2003::3"), + mustIP("2003::3"), + mustIP("2003::4"), + mustIP("2003::4"), + }) + want := []netaddr.IP{ + mustIP("2003::3"), + mustIP("2003::4"), + } + if !reflect.DeepEqual(got, want) { + t.Errorf("got %v; want %v", got, want) + } +}