net/dns/resolver: make DoH dialer use existing dnscache happy eyeball dialer

Simplify the ability to reason about the DoH dialing code by reusing the
dnscache's dialer we already have.

Also, reduce the scope of the "ip" variable we don't want to close over.

This necessarily adds a new field to dnscache.Resolver:
SingleHostStaticResult, for when the caller already knows the IPs to be
returned.

Change-Id: I9f2aef7926f649137a5a3e63eebad6a3fffa48c0
Signed-off-by: Brad Fitzpatrick <bradfitz@tailscale.com>
This commit is contained in:
Brad Fitzpatrick
2022-04-18 12:50:26 -07:00
committed by Brad Fitzpatrick
parent e96dd00652
commit ecea6cb994
4 changed files with 113 additions and 15 deletions

View File

@@ -15,6 +15,7 @@ import (
"math/rand"
"net"
"net/http"
"net/url"
"runtime"
"sort"
"strconv"
@@ -26,6 +27,7 @@ import (
"inet.af/netaddr"
"tailscale.com/hostinfo"
"tailscale.com/net/dns/publicdns"
"tailscale.com/net/dnscache"
"tailscale.com/net/neterror"
"tailscale.com/net/netns"
"tailscale.com/net/tsdial"
@@ -332,21 +334,47 @@ func (f *forwarder) packetListener(ip netaddr.IP) (packetListener, error) {
return lc, nil
}
// getKnownDoHClient returns an HTTP client for a DoH provider (such as Google
// or Cloudflare DNS), as a function of one of its (usually four) IPs.
//
// The provided IP is only used to determine the DoH provider; it is not
// prioritized among the set of IPs that are used by the provider.
func (f *forwarder) getKnownDoHClient(ip netaddr.IP) (urlBase string, c *http.Client, ok bool) {
urlBase, ok = publicdns.KnownDoH()[ip]
if !ok {
return
return "", nil, false
}
c, ok = f.getKnownDoHClientForProvider(urlBase)
if !ok {
return "", nil, false
}
return urlBase, c, true
}
// getKnownDoHClientForProvider returns an HTTP client for a specific DoH
// provider named by its DoH base URL (like "https://dns.google/dns-query").
//
// The returned client race/Happy Eyeballs dials all IPs for urlBase (usually
// 4), as statically known by the publicdns package.
func (f *forwarder) getKnownDoHClientForProvider(urlBase string) (c *http.Client, ok bool) {
f.mu.Lock()
defer f.mu.Unlock()
if c, ok := f.dohClient[urlBase]; ok {
return urlBase, c, true
return c, true
}
if f.dohClient == nil {
f.dohClient = map[string]*http.Client{}
allIPs := publicdns.DoHIPsOfBase()[urlBase]
if len(allIPs) == 0 {
return nil, false
}
dohURL, err := url.Parse(urlBase)
if err != nil {
return nil, false
}
nsDialer := netns.NewDialer(f.logf)
dialer := dnscache.Dialer(nsDialer.DialContext, &dnscache.Resolver{
SingleHost: dohURL.Hostname(),
SingleHostStaticResult: allIPs,
})
c = &http.Client{
Transport: &http.Transport{
IdleConnTimeout: dohTransportTimeout,
@@ -354,21 +382,15 @@ func (f *forwarder) getKnownDoHClient(ip netaddr.IP) (urlBase string, c *http.Cl
if !strings.HasPrefix(netw, "tcp") {
return nil, fmt.Errorf("unexpected network %q", netw)
}
c, err := nsDialer.DialContext(ctx, "tcp", net.JoinHostPort(ip.String(), "443"))
// If v4 failed, try an equivalent v6 also in the time remaining.
if err != nil && ctx.Err() == nil {
if ip6, ok := publicdns.DoHV6(urlBase); ok && ip.Is4() {
if c6, err := nsDialer.DialContext(ctx, "tcp", net.JoinHostPort(ip6.String(), "443")); err == nil {
return c6, nil
}
}
}
return c, err
return dialer(ctx, netw, addr)
},
},
}
if f.dohClient == nil {
f.dohClient = map[string]*http.Client{}
}
f.dohClient[urlBase] = c
return urlBase, c, true
return c, true
}
const dohType = "application/dns-message"

View File

@@ -5,6 +5,7 @@
package resolver
import (
"flag"
"fmt"
"net"
"reflect"
@@ -169,6 +170,25 @@ func TestMaxDoHInFlight(t *testing.T) {
}
}
var testDNS = flag.Bool("test-dns", false, "run tests that require a working DNS server")
func TestGetKnownDoHClientForProvider(t *testing.T) {
var fwd forwarder
c, ok := fwd.getKnownDoHClientForProvider("https://dns.google/dns-query")
if !ok {
t.Fatal("not found")
}
if !*testDNS {
t.Skip("skipping without --test-dns")
}
res, err := c.Head("https://dns.google/")
if err != nil {
t.Fatal(err)
}
defer res.Body.Close()
t.Logf("Got: %+v", res)
}
func BenchmarkNameFromQuery(b *testing.B) {
builder := dns.NewBuilder(nil, dns.Header{})
builder.StartQuestions()