diff --git a/net/dns/resolver/forwarder.go b/net/dns/resolver/forwarder.go
index 531c2d130..04f22d8a2 100644
--- a/net/dns/resolver/forwarder.go
+++ b/net/dns/resolver/forwarder.go
@@ -17,6 +17,7 @@ import (
 	"net"
 	"net/http"
 	"runtime"
+	"sort"
 	"strings"
 	"sync"
 	"time"
@@ -40,6 +41,11 @@ const (
 	// connections open to DNS-over-HTTPs servers. This is pretty
 	// arbitrary.
 	dohTransportTimeout = 30 * time.Second
+
+	// wellKnownHostBackupDelay is how long to artificially delay upstream
+	// DNS queries to the "fallback" DNS server IP for a known provider
+	// (e.g. how long to wait to query Google's 8.8.4.4 after 8.8.8.8).
+	wellKnownHostBackupDelay = 200 * time.Millisecond
 )
 
 var errNoUpstreams = errors.New("upstream nameservers not set")
@@ -152,7 +158,20 @@ func clampEDNSSize(packet []byte, maxSize uint16) {
 
 type route struct {
 	Suffix    dnsname.FQDN
-	Resolvers []netaddr.IPPort
+	Resolvers []resolverAndDelay
+}
+
+// resolverAndDelay is an upstream DNS resolver and a delay for how
+// long to wait before querying it.
+type resolverAndDelay struct {
+	// ipp is the upstream resolver.
+	ipp netaddr.IPPort
+
+	// startDelay is an amount to delay this resolver at
+	// start. It's used when, say, there are four Google or
+	// Cloudflare DNS IPs (two IPv4 + two IPv6) and we don't want
+	// to race all four at once.
+	startDelay time.Duration
 }
 
 // forwarder forwards DNS packets to a number of upstream nameservers.
@@ -204,7 +223,84 @@ func (f *forwarder) Close() error {
 	return nil
 }
 
-func (f *forwarder) setRoutes(routes []route) {
+// resolversWithDelays maps from a set of DNS server ip:ports (currently
+// the port is always 53) to a slice of a type that includes a
+// startDelay. So if ipps contains e.g. four Google DNS IPs (two IPv4
+// + two IPv6), this function adds delays to some of them.
+func resolversWithDelays(ipps []netaddr.IPPort) []resolverAndDelay {
+	type hostAndFam struct {
+		host string // some arbitrary string representing DNS host (currently the DoH base)
+		bits uint8  // either 32 or 128 for the IPv4 vs IPv6 address family
+	}
+
+	// Track how many of each known resolver host are in the list,
+	// per address family.
+	total := map[hostAndFam]int{}
+
+	rr := make([]resolverAndDelay, len(ipps))
+	for _, ipp := range ipps {
+		ip := ipp.IP()
+		if host, ok := knownDoH[ip]; ok {
+			total[hostAndFam{host, ip.BitLen()}]++
+		}
+	}
+
+	done := map[hostAndFam]int{}
+	for i, ipp := range ipps {
+		ip := ipp.IP()
+		var startDelay time.Duration
+		if host, ok := knownDoH[ip]; ok {
+			key4 := hostAndFam{host, 32}
+			key6 := hostAndFam{host, 128}
+			switch {
+			case ip.Is4():
+				if done[key4] > 0 {
+					startDelay += wellKnownHostBackupDelay
+				}
+			case ip.Is6():
+				total4 := total[key4]
+				if total4 >= 2 {
+					// If we have two IPv4 IPs of the same provider
+					// already in the set, delay the IPv6 queries
+					// until halfway through the timeout (so wait
+					// 2.5 seconds). Even if the network is IPv6-only,
+					// the DoH dialer will fall back to IPv6
+					// immediately anyway.
+					startDelay = responseTimeout / 2
+				} else if total4 == 1 {
+					startDelay += wellKnownHostBackupDelay
+				}
+				if done[key6] > 0 {
+					startDelay += wellKnownHostBackupDelay
+				}
+			}
+			done[hostAndFam{host, ip.BitLen()}]++
+		}
+		rr[i] = resolverAndDelay{
+			ipp:        ipp,
+			startDelay: startDelay,
+		}
+	}
+	return rr
+}
+
+// setRoutes sets the routes to use for DNS forwarding. It's called by
+// Resolver.SetConfig on reconfig.
+//
+// The memory referenced by routesBySuffix should not be modified.
+func (f *forwarder) setRoutes(routesBySuffix map[dnsname.FQDN][]netaddr.IPPort) {
+	routes := make([]route, 0, len(routesBySuffix))
+	for suffix, ipps := range routesBySuffix {
+		routes = append(routes, route{
+			Suffix:    suffix,
+			Resolvers: resolversWithDelays(ipps),
+		})
+	}
+	// Sort from longest prefix to shortest.
+	sort.Slice(routes, func(i, j int) bool {
+		return routes[i].Suffix.NumLabels() > routes[j].Suffix.NumLabels()
+	})
+
 	f.mu.Lock()
 	defer f.mu.Unlock()
 	f.routes = routes
@@ -394,7 +490,7 @@ func (f *forwarder) send(ctx context.Context, fq *forwardQuery, dst netaddr.IPPo
 }
 
 // resolvers returns the resolvers to use for domain.
-func (f *forwarder) resolvers(domain dnsname.FQDN) []netaddr.IPPort {
+func (f *forwarder) resolvers(domain dnsname.FQDN) []resolverAndDelay {
 	f.mu.Lock()
 	routes := f.routes
 	f.mu.Unlock()
@@ -460,9 +556,18 @@ func (f *forwarder) forward(query packet) error {
 		firstErr error
 	)
 
-	for _, ipp := range resolvers {
-		go func(ipp netaddr.IPPort) {
-			resb, err := f.send(ctx, fq, ipp)
+	for _, rr := range resolvers {
+		go func(rr resolverAndDelay) {
+			if rr.startDelay > 0 {
+				timer := time.NewTimer(rr.startDelay)
+				select {
+				case <-timer.C:
+				case <-ctx.Done():
+					timer.Stop()
+					return
+				}
+			}
+			resb, err := f.send(ctx, fq, rr.ipp)
 			if err != nil {
 				mu.Lock()
 				defer mu.Unlock()
@@ -475,7 +580,7 @@ func (f *forwarder) forward(query packet) error {
 			case resc <- resb:
 			default:
 			}
-		}(ipp)
+		}(rr)
 	}
 
 	select {
diff --git a/net/dns/resolver/tsdns.go b/net/dns/resolver/tsdns.go
index fbd8abda0..bdd7b6318 100644
--- a/net/dns/resolver/tsdns.go
+++ b/net/dns/resolver/tsdns.go
@@ -208,7 +208,6 @@ func (r *Resolver) SetConfig(cfg Config) error {
 		r.saveConfigForTests(cfg)
 	}
 
-	routes := make([]route, 0, len(cfg.Routes))
 	reverse := make(map[netaddr.IP]dnsname.FQDN, len(cfg.Hosts))
 
 	for host, ips := range cfg.Hosts {
@@ -217,18 +216,7 @@ func (r *Resolver) SetConfig(cfg Config) error {
 		}
 	}
 
-	for suffix, ips := range cfg.Routes {
-		routes = append(routes, route{
-			Suffix:    suffix,
-			Resolvers: ips,
-		})
-	}
-	// Sort from longest prefix to shortest.
-	sort.Slice(routes, func(i, j int) bool {
-		return routes[i].Suffix.NumLabels() > routes[j].Suffix.NumLabels()
-	})
-
-	r.forwarder.setRoutes(routes)
+	r.forwarder.setRoutes(cfg.Routes)
 
 	r.mu.Lock()
 	defer r.mu.Unlock()