mirror of
https://github.com/tailscale/tailscale.git
synced 2025-08-12 05:37:32 +00:00
net/dns: replace resolver IPs with type for DoH
We currently plumb full URLs for DNS resolvers from the control server down to the client. But when we pass the values into the net/dns package, we throw away any URL that isn't a bare IP. This commit continues the plumbing, and gets the URL all the way to the built in forwarder. (It stops before plumbing URLs into the OS configurations that can handle them.) For #2596 Signed-off-by: David Crawshaw <crawshaw@tailscale.com>
This commit is contained in:

committed by
David Crawshaw

parent
7bfd4f521d
commit
9502b515f1
@@ -50,7 +50,7 @@ func TestDoH(t *testing.T) {
|
||||
|
||||
for ip := range knownDoH {
|
||||
t.Run(ip.String(), func(t *testing.T) {
|
||||
urlBase, c, ok := f.getDoHClient(ip)
|
||||
urlBase, c, ok := f.getKnownDoHClient(ip)
|
||||
if !ok {
|
||||
t.Fatal("expected DoH")
|
||||
}
|
||||
|
@@ -24,6 +24,7 @@ import (
|
||||
dns "golang.org/x/net/dns/dnsmessage"
|
||||
"inet.af/netaddr"
|
||||
"tailscale.com/net/netns"
|
||||
"tailscale.com/types/dnstype"
|
||||
"tailscale.com/types/logger"
|
||||
"tailscale.com/util/dnsname"
|
||||
"tailscale.com/wgengine/monitor"
|
||||
@@ -133,8 +134,8 @@ type route struct {
|
||||
// resolverAndDelay is an upstream DNS resolver and a delay for how
|
||||
// long to wait before querying it.
|
||||
type resolverAndDelay struct {
|
||||
// ipp is the upstream resolver.
|
||||
ipp netaddr.IPPort
|
||||
// name is the upstream resolver.
|
||||
name dnstype.Resolver
|
||||
|
||||
// startDelay is an amount to delay this resolver at
|
||||
// start. It's used when, say, there are four Google or
|
||||
@@ -158,7 +159,7 @@ type forwarder struct {
|
||||
|
||||
mu sync.Mutex // guards following
|
||||
|
||||
dohClient map[netaddr.IP]*http.Client
|
||||
dohClient map[string]*http.Client // urlBase -> client
|
||||
|
||||
// routes are per-suffix resolvers to use, with
|
||||
// the most specific routes first.
|
||||
@@ -192,11 +193,11 @@ func (f *forwarder) Close() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// resolversWithDelays maps from a set of DNS server ip:ports (currently
|
||||
// the port is always 53) to a slice of a type that included a
|
||||
// startDelay. So if ipps contains e.g. four Google DNS IPs (two IPv4
|
||||
// + twoIPv6), this function partition adds delays to some.
|
||||
func resolversWithDelays(ipps []netaddr.IPPort) []resolverAndDelay {
|
||||
// resolversWithDelays maps from a set of DNS server names to a slice of
|
||||
// a type that included a startDelay. So if resolvers contains e.g. four
|
||||
// Google DNS IPs (two IPv4 + twoIPv6), this function partition adds
|
||||
// delays to some.
|
||||
func resolversWithDelays(resolvers []dnstype.Resolver) []resolverAndDelay {
|
||||
type hostAndFam struct {
|
||||
host string // some arbitrary string representing DNS host (currently the DoH base)
|
||||
bits uint8 // either 32 or 128 for IPv4 vs IPv6s address family
|
||||
@@ -206,47 +207,49 @@ func resolversWithDelays(ipps []netaddr.IPPort) []resolverAndDelay {
|
||||
// per address family.
|
||||
total := map[hostAndFam]int{}
|
||||
|
||||
rr := make([]resolverAndDelay, len(ipps))
|
||||
for _, ipp := range ipps {
|
||||
ip := ipp.IP()
|
||||
if host, ok := knownDoH[ip]; ok {
|
||||
total[hostAndFam{host, ip.BitLen()}]++
|
||||
rr := make([]resolverAndDelay, len(resolvers))
|
||||
for _, r := range resolvers {
|
||||
if ip, err := netaddr.ParseIP(r.Addr); err == nil {
|
||||
if host, ok := knownDoH[ip]; ok {
|
||||
total[hostAndFam{host, ip.BitLen()}]++
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
done := map[hostAndFam]int{}
|
||||
for i, ipp := range ipps {
|
||||
ip := ipp.IP()
|
||||
for i, r := range resolvers {
|
||||
var startDelay time.Duration
|
||||
if host, ok := knownDoH[ip]; ok {
|
||||
key4 := hostAndFam{host, 32}
|
||||
key6 := hostAndFam{host, 128}
|
||||
switch {
|
||||
case ip.Is4():
|
||||
if done[key4] > 0 {
|
||||
startDelay += wellKnownHostBackupDelay
|
||||
}
|
||||
case ip.Is6():
|
||||
total4 := total[key4]
|
||||
if total4 >= 2 {
|
||||
// If we have two IPv4 IPs of the same provider
|
||||
// already in the set, delay the IPv6 queries
|
||||
// until halfway through the timeout (so wait
|
||||
// 2.5 seconds). Even the network is IPv6-only,
|
||||
// the DoH dialer will fallback to IPv6
|
||||
// immediately anyway.
|
||||
startDelay = responseTimeout / 2
|
||||
} else if total4 == 1 {
|
||||
startDelay += wellKnownHostBackupDelay
|
||||
}
|
||||
if done[key6] > 0 {
|
||||
startDelay += wellKnownHostBackupDelay
|
||||
if ip, err := netaddr.ParseIP(r.Addr); err == nil {
|
||||
if host, ok := knownDoH[ip]; ok {
|
||||
key4 := hostAndFam{host, 32}
|
||||
key6 := hostAndFam{host, 128}
|
||||
switch {
|
||||
case ip.Is4():
|
||||
if done[key4] > 0 {
|
||||
startDelay += wellKnownHostBackupDelay
|
||||
}
|
||||
case ip.Is6():
|
||||
total4 := total[key4]
|
||||
if total4 >= 2 {
|
||||
// If we have two IPv4 IPs of the same provider
|
||||
// already in the set, delay the IPv6 queries
|
||||
// until halfway through the timeout (so wait
|
||||
// 2.5 seconds). Even the network is IPv6-only,
|
||||
// the DoH dialer will fallback to IPv6
|
||||
// immediately anyway.
|
||||
startDelay = responseTimeout / 2
|
||||
} else if total4 == 1 {
|
||||
startDelay += wellKnownHostBackupDelay
|
||||
}
|
||||
if done[key6] > 0 {
|
||||
startDelay += wellKnownHostBackupDelay
|
||||
}
|
||||
}
|
||||
done[hostAndFam{host, ip.BitLen()}]++
|
||||
}
|
||||
done[hostAndFam{host, ip.BitLen()}]++
|
||||
}
|
||||
rr[i] = resolverAndDelay{
|
||||
ipp: ipp,
|
||||
name: r,
|
||||
startDelay: startDelay,
|
||||
}
|
||||
}
|
||||
@@ -257,12 +260,12 @@ func resolversWithDelays(ipps []netaddr.IPPort) []resolverAndDelay {
|
||||
// Resolver.SetConfig on reconfig.
|
||||
//
|
||||
// The memory referenced by routesBySuffix should not be modified.
|
||||
func (f *forwarder) setRoutes(routesBySuffix map[dnsname.FQDN][]netaddr.IPPort) {
|
||||
func (f *forwarder) setRoutes(routesBySuffix map[dnsname.FQDN][]dnstype.Resolver) {
|
||||
routes := make([]route, 0, len(routesBySuffix))
|
||||
for suffix, ipps := range routesBySuffix {
|
||||
for suffix, rs := range routesBySuffix {
|
||||
routes = append(routes, route{
|
||||
Suffix: suffix,
|
||||
Resolvers: resolversWithDelays(ipps),
|
||||
Resolvers: resolversWithDelays(rs),
|
||||
})
|
||||
}
|
||||
// Sort from longest prefix to shortest.
|
||||
@@ -296,18 +299,19 @@ func (f *forwarder) packetListener(ip netaddr.IP) (packetListener, error) {
|
||||
return lc, nil
|
||||
}
|
||||
|
||||
func (f *forwarder) getDoHClient(ip netaddr.IP) (urlBase string, c *http.Client, ok bool) {
|
||||
func (f *forwarder) getKnownDoHClient(ip netaddr.IP) (urlBase string, c *http.Client, ok bool) {
|
||||
urlBase, ok = knownDoH[ip]
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
f.mu.Lock()
|
||||
defer f.mu.Unlock()
|
||||
if c, ok := f.dohClient[ip]; ok {
|
||||
if c, ok := f.dohClient[urlBase]; ok {
|
||||
return urlBase, c, true
|
||||
}
|
||||
if f.dohClient == nil {
|
||||
f.dohClient = map[netaddr.IP]*http.Client{}
|
||||
f.dohClient = map[string]*http.Client{}
|
||||
}
|
||||
nsDialer := netns.NewDialer()
|
||||
c = &http.Client{
|
||||
@@ -330,7 +334,7 @@ func (f *forwarder) getDoHClient(ip netaddr.IP) (urlBase string, c *http.Client,
|
||||
},
|
||||
},
|
||||
}
|
||||
f.dohClient[ip] = c
|
||||
f.dohClient[urlBase] = c
|
||||
return urlBase, c, true
|
||||
}
|
||||
|
||||
@@ -380,20 +384,32 @@ func (f *forwarder) sendDoH(ctx context.Context, urlBase string, c *http.Client,
|
||||
// send sends packet to dst. It is best effort.
|
||||
//
|
||||
// send expects the reply to have the same txid as txidOut.
|
||||
//
|
||||
func (f *forwarder) send(ctx context.Context, fq *forwardQuery, dst netaddr.IPPort) ([]byte, error) {
|
||||
ip := dst.IP()
|
||||
func (f *forwarder) send(ctx context.Context, fq *forwardQuery, rr resolverAndDelay) ([]byte, error) {
|
||||
if strings.HasPrefix(rr.name.Addr, "http://") {
|
||||
return nil, fmt.Errorf("http:// resolvers not supported yet")
|
||||
}
|
||||
if strings.HasPrefix(rr.name.Addr, "https://") {
|
||||
return nil, fmt.Errorf("https:// resolvers not supported yet")
|
||||
}
|
||||
if strings.HasPrefix(rr.name.Addr, "tls://") {
|
||||
return nil, fmt.Errorf("tls:// resolvers not supported yet")
|
||||
}
|
||||
ipp, err := netaddr.ParseIPPort(rr.name.Addr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Upgrade known DNS IPs to DoH (DNS-over-HTTPs).
|
||||
if urlBase, dc, ok := f.getDoHClient(ip); ok {
|
||||
// All known DoH is over port 53.
|
||||
if urlBase, dc, ok := f.getKnownDoHClient(ipp.IP()); ok {
|
||||
res, err := f.sendDoH(ctx, urlBase, dc, fq.packet)
|
||||
if err == nil || ctx.Err() != nil {
|
||||
return res, err
|
||||
}
|
||||
f.logf("DoH error from %v: %v", ip, err)
|
||||
f.logf("DoH error from %v: %v", ipp.IP(), err)
|
||||
}
|
||||
|
||||
ln, err := f.packetListener(ip)
|
||||
ln, err := f.packetListener(ipp.IP())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -407,7 +423,7 @@ func (f *forwarder) send(ctx context.Context, fq *forwardQuery, dst netaddr.IPPo
|
||||
fq.closeOnCtxDone.Add(conn)
|
||||
defer fq.closeOnCtxDone.Remove(conn)
|
||||
|
||||
if _, err := conn.WriteTo(fq.packet, dst.UDPAddr()); err != nil {
|
||||
if _, err := conn.WriteTo(fq.packet, ipp.UDPAddr()); err != nil {
|
||||
if err := ctx.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -525,8 +541,8 @@ func (f *forwarder) forward(query packet) error {
|
||||
firstErr error
|
||||
)
|
||||
|
||||
for _, rr := range resolvers {
|
||||
go func(rr resolverAndDelay) {
|
||||
for i := range resolvers {
|
||||
go func(rr *resolverAndDelay) {
|
||||
if rr.startDelay > 0 {
|
||||
timer := time.NewTimer(rr.startDelay)
|
||||
select {
|
||||
@@ -536,7 +552,7 @@ func (f *forwarder) forward(query packet) error {
|
||||
return
|
||||
}
|
||||
}
|
||||
resb, err := f.send(ctx, fq, rr.ipp)
|
||||
resb, err := f.send(ctx, fq, *rr)
|
||||
if err != nil {
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
@@ -549,7 +565,7 @@ func (f *forwarder) forward(query packet) error {
|
||||
case resc <- resb:
|
||||
default:
|
||||
}
|
||||
}(rr)
|
||||
}(&resolvers[i])
|
||||
}
|
||||
|
||||
select {
|
||||
@@ -638,7 +654,7 @@ func (p *closePool) Close() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
var knownDoH = map[netaddr.IP]string{}
|
||||
var knownDoH = map[netaddr.IP]string{} // 8.8.8.8 => "https://..."
|
||||
|
||||
var dohIPsOfBase = map[string][]netaddr.IP{}
|
||||
|
||||
|
@@ -6,23 +6,28 @@ package resolver
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"reflect"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"inet.af/netaddr"
|
||||
"tailscale.com/types/dnstype"
|
||||
)
|
||||
|
||||
func (rr resolverAndDelay) String() string {
|
||||
return fmt.Sprintf("%v+%v", rr.ipp, rr.startDelay)
|
||||
return fmt.Sprintf("%v+%v", rr.name, rr.startDelay)
|
||||
}
|
||||
|
||||
func TestResolversWithDelays(t *testing.T) {
|
||||
// query
|
||||
q := func(ss ...string) (ipps []netaddr.IPPort) {
|
||||
q := func(ss ...string) (ipps []dnstype.Resolver) {
|
||||
for _, s := range ss {
|
||||
ipps = append(ipps, netaddr.MustParseIPPort(s))
|
||||
host, _, err := net.SplitHostPort(s)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
ipps = append(ipps, dnstype.Resolver{Addr: host})
|
||||
}
|
||||
return
|
||||
}
|
||||
@@ -38,8 +43,12 @@ func TestResolversWithDelays(t *testing.T) {
|
||||
}
|
||||
s = s[:i]
|
||||
}
|
||||
host, _, err := net.SplitHostPort(s)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
rr = append(rr, resolverAndDelay{
|
||||
ipp: netaddr.MustParseIPPort(s),
|
||||
name: dnstype.Resolver{Addr: host},
|
||||
startDelay: d,
|
||||
})
|
||||
}
|
||||
@@ -48,7 +57,7 @@ func TestResolversWithDelays(t *testing.T) {
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
in []netaddr.IPPort
|
||||
in []dnstype.Resolver
|
||||
want []resolverAndDelay
|
||||
}{
|
||||
{
|
||||
|
@@ -11,6 +11,7 @@ import (
|
||||
"encoding/hex"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"runtime"
|
||||
"sort"
|
||||
"strings"
|
||||
@@ -20,6 +21,7 @@ import (
|
||||
|
||||
dns "golang.org/x/net/dns/dnsmessage"
|
||||
"inet.af/netaddr"
|
||||
"tailscale.com/types/dnstype"
|
||||
"tailscale.com/types/logger"
|
||||
"tailscale.com/util/dnsname"
|
||||
"tailscale.com/wgengine/monitor"
|
||||
@@ -73,7 +75,7 @@ type Config struct {
|
||||
// queries within that suffix.
|
||||
// Queries only match the most specific suffix.
|
||||
// To register a "default route", add an entry for ".".
|
||||
Routes map[dnsname.FQDN][]netaddr.IPPort
|
||||
Routes map[dnsname.FQDN][]dnstype.Resolver
|
||||
// LocalHosts is a map of FQDNs to corresponding IPs.
|
||||
Hosts map[dnsname.FQDN][]netaddr.IP
|
||||
// LocalDomains is a list of DNS name suffixes that should not be
|
||||
@@ -121,9 +123,35 @@ func WriteIPPorts(w *bufio.Writer, vv []netaddr.IPPort) {
|
||||
w.WriteByte(']')
|
||||
}
|
||||
|
||||
// WriteDNSResolver writes r to w.
|
||||
func WriteDNSResolver(w *bufio.Writer, r dnstype.Resolver) {
|
||||
io.WriteString(w, r.Addr)
|
||||
if len(r.BootstrapResolution) > 0 {
|
||||
w.WriteByte('(')
|
||||
var b []byte
|
||||
for _, ip := range r.BootstrapResolution {
|
||||
ip.AppendTo(b[:0])
|
||||
w.Write(b)
|
||||
}
|
||||
w.WriteByte(')')
|
||||
}
|
||||
}
|
||||
|
||||
// WriteDNSResolvers writes resolvers to w.
|
||||
func WriteDNSResolvers(w *bufio.Writer, resolvers []dnstype.Resolver) {
|
||||
w.WriteByte('[')
|
||||
for i, r := range resolvers {
|
||||
if i > 0 {
|
||||
w.WriteByte(' ')
|
||||
}
|
||||
WriteDNSResolver(w, r)
|
||||
}
|
||||
w.WriteByte(']')
|
||||
}
|
||||
|
||||
// WriteRoutes writes routes to w, omitting *.arpa routes and instead
|
||||
// summarizing how many of them there were.
|
||||
func WriteRoutes(w *bufio.Writer, routes map[dnsname.FQDN][]netaddr.IPPort) {
|
||||
func WriteRoutes(w *bufio.Writer, routes map[dnsname.FQDN][]dnstype.Resolver) {
|
||||
var kk []dnsname.FQDN
|
||||
arpa := 0
|
||||
for k := range routes {
|
||||
@@ -141,7 +169,7 @@ func WriteRoutes(w *bufio.Writer, routes map[dnsname.FQDN][]netaddr.IPPort) {
|
||||
}
|
||||
w.WriteString(string(k))
|
||||
w.WriteByte(':')
|
||||
WriteIPPorts(w, routes[k])
|
||||
WriteDNSResolvers(w, routes[k])
|
||||
}
|
||||
w.WriteByte('}')
|
||||
if arpa > 0 {
|
||||
|
@@ -19,6 +19,7 @@ import (
|
||||
dns "golang.org/x/net/dns/dnsmessage"
|
||||
"inet.af/netaddr"
|
||||
"tailscale.com/tstest"
|
||||
"tailscale.com/types/dnstype"
|
||||
"tailscale.com/util/dnsname"
|
||||
"tailscale.com/wgengine/monitor"
|
||||
)
|
||||
@@ -466,10 +467,10 @@ func TestDelegate(t *testing.T) {
|
||||
defer r.Close()
|
||||
|
||||
cfg := dnsCfg
|
||||
cfg.Routes = map[dnsname.FQDN][]netaddr.IPPort{
|
||||
cfg.Routes = map[dnsname.FQDN][]dnstype.Resolver{
|
||||
".": {
|
||||
netaddr.MustParseIPPort(v4server.PacketConn.LocalAddr().String()),
|
||||
netaddr.MustParseIPPort(v6server.PacketConn.LocalAddr().String()),
|
||||
dnstype.Resolver{Addr: v4server.PacketConn.LocalAddr().String()},
|
||||
dnstype.Resolver{Addr: v6server.PacketConn.LocalAddr().String()},
|
||||
},
|
||||
}
|
||||
r.SetConfig(cfg)
|
||||
@@ -641,9 +642,9 @@ func TestDelegateSplitRoute(t *testing.T) {
|
||||
defer r.Close()
|
||||
|
||||
cfg := dnsCfg
|
||||
cfg.Routes = map[dnsname.FQDN][]netaddr.IPPort{
|
||||
".": {netaddr.MustParseIPPort(server1.PacketConn.LocalAddr().String())},
|
||||
"other.": {netaddr.MustParseIPPort(server2.PacketConn.LocalAddr().String())},
|
||||
cfg.Routes = map[dnsname.FQDN][]dnstype.Resolver{
|
||||
".": {{Addr: server1.PacketConn.LocalAddr().String()}},
|
||||
"other.": {{Addr: server2.PacketConn.LocalAddr().String()}},
|
||||
}
|
||||
r.SetConfig(cfg)
|
||||
|
||||
@@ -698,10 +699,8 @@ func TestDelegateCollision(t *testing.T) {
|
||||
defer r.Close()
|
||||
|
||||
cfg := dnsCfg
|
||||
cfg.Routes = map[dnsname.FQDN][]netaddr.IPPort{
|
||||
".": {
|
||||
netaddr.MustParseIPPort(server.PacketConn.LocalAddr().String()),
|
||||
},
|
||||
cfg.Routes = map[dnsname.FQDN][]dnstype.Resolver{
|
||||
".": {{Addr: server.PacketConn.LocalAddr().String()}},
|
||||
}
|
||||
r.SetConfig(cfg)
|
||||
|
||||
@@ -1005,10 +1004,8 @@ func BenchmarkFull(b *testing.B) {
|
||||
defer r.Close()
|
||||
|
||||
cfg := dnsCfg
|
||||
cfg.Routes = map[dnsname.FQDN][]netaddr.IPPort{
|
||||
".": {
|
||||
netaddr.MustParseIPPort(server.PacketConn.LocalAddr().String()),
|
||||
},
|
||||
cfg.Routes = map[dnsname.FQDN][]dnstype.Resolver{
|
||||
".": {{Addr: server.PacketConn.LocalAddr().String()}},
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
|
Reference in New Issue
Block a user