From 2cff9016e4490a80eb99dbba13684ff7ae9c340b Mon Sep 17 00:00:00 2001 From: Brad Fitzpatrick Date: Thu, 5 Mar 2020 10:29:19 -0800 Subject: [PATCH] net/dnscache: add overly simplistic DNS cache package for selective use I started to write a full DNS caching resolver and I realized it was overkill and wouldn't work on Windows even in Go 1.14 yet, so I'm doing this tiny one instead for now, just for all our netcheck STUN derp lookups, and connections to DERP servers. (This will be caching a exactly 8 DNS entries, all ours.) Fixes #145 (can be better later, of course) --- derp/derphttp/derphttp_client.go | 22 ++++- net/dnscache/dnscache.go | 151 +++++++++++++++++++++++++++++++ netcheck/netcheck.go | 4 + stunner/stunner.go | 25 +++-- wgengine/magicsock/magicsock.go | 2 + 5 files changed, 191 insertions(+), 13 deletions(-) create mode 100644 net/dnscache/dnscache.go diff --git a/derp/derphttp/derphttp_client.go b/derp/derphttp/derphttp_client.go index 2d1a821d7..cfe6d0225 100644 --- a/derp/derphttp/derphttp_client.go +++ b/derp/derphttp/derphttp_client.go @@ -26,6 +26,7 @@ "time" "tailscale.com/derp" + "tailscale.com/net/dnscache" "tailscale.com/types/key" "tailscale.com/types/logger" ) @@ -37,7 +38,8 @@ // Send/Recv will completely re-establish the connection (unless Close // has been called). type Client struct { - TLSConfig *tls.Config // for sever connection, optional, nil means default + TLSConfig *tls.Config // for sever connection, optional, nil means default + DNSCache *dnscache.Resolver // optional; if nil, no caching privateKey key.Private logf logger.Logf @@ -137,11 +139,23 @@ func (c *Client) connect(ctx context.Context, caller string) (client *derp.Clien } }() + host := c.url.Hostname() + hostOrIP := host + var d net.Dialer - log.Printf("Dialing: %q", net.JoinHostPort(c.url.Hostname(), urlPort(c.url))) - tcpConn, err = d.DialContext(ctx, "tcp", net.JoinHostPort(c.url.Hostname(), urlPort(c.url))) + log.Printf("Dialing: %q", net.JoinHostPort(host, urlPort(c.url))) + + if c.DNSCache != nil { + ip, err := c.DNSCache.LookupIP(ctx, host) + if err != nil { + return nil, err + } + hostOrIP = ip.String() + } + + tcpConn, err = d.DialContext(ctx, "tcp", net.JoinHostPort(hostOrIP, urlPort(c.url))) if err != nil { - return nil, err + return nil, fmt.Errorf("Dial of %q: %v", host, err) } // Now that we have a TCP connection, force close it. diff --git a/net/dnscache/dnscache.go b/net/dnscache/dnscache.go new file mode 100644 index 000000000..b435e5aaa --- /dev/null +++ b/net/dnscache/dnscache.go @@ -0,0 +1,151 @@ +// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package dnscache contains a minimal DNS cache that makes a bunch of +// assumptions that are only valid for us. Not recommended for general use. +package dnscache + +import ( + "context" + "fmt" + "net" + "sync" + "time" + + "golang.org/x/sync/singleflight" +) + +var single = &Resolver{ + Forward: &net.Resolver{PreferGo: true}, +} + +// Get returns a caching Resolver singleton. +func Get() *Resolver { return single } + +const fixedTTL = 10 * time.Minute + +// Resolver is a minimal DNS caching resolver. +// +// The TTL is always fixed for now. It's not intended for general use. +// Cache entries are never cleaned up so it's intended that this is +// only used with a fixed set of hostnames. +type Resolver struct { + // Forward is the resolver to use to populate the cache. + // If nil, net.DefaultResolver is used. + Forward *net.Resolver + + sf singleflight.Group + + mu sync.Mutex + ipCache map[string]ipCacheEntry +} + +type ipCacheEntry struct { + ip net.IP + expires time.Time +} + +func (r *Resolver) fwd() *net.Resolver { + if r.Forward != nil { + return r.Forward + } + return net.DefaultResolver +} + +// LookupIP returns the first IPv4 address found, otherwise the first IPv6 address. +func (r *Resolver) LookupIP(ctx context.Context, host string) (net.IP, error) { + if ip := net.ParseIP(host); ip != nil { + if ip4 := ip.To4(); ip4 != nil { + return ip4, nil + } + return ip, nil + } + + if ip, ok := r.lookupIPCache(host); ok { + return ip, nil + } + + ch := r.sf.DoChan(host, func() (interface{}, error) { + ip, err := r.lookupIP(host) + if err != nil { + return nil, err + } + return ip, nil + }) + select { + case res := <-ch: + if res.Err != nil { + return nil, res.Err + } + return res.Val.(net.IP), nil + case <-ctx.Done(): + return nil, ctx.Err() + } +} + +func (r *Resolver) lookupIPCache(host string) (ip net.IP, ok bool) { + r.mu.Lock() + defer r.mu.Unlock() + if ent, ok := r.ipCache[host]; ok && ent.expires.After(time.Now()) { + return ent.ip, true + } + return nil, false +} + +func (r *Resolver) lookupIP(host string) (net.IP, error) { + if ip, ok := r.lookupIPCache(host); ok { + return ip, nil + } + + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + ips, err := r.fwd().LookupIPAddr(ctx, host) + if err != nil { + return nil, err + } + if len(ips) == 0 { + return nil, fmt.Errorf("no IPs for %q found", host) + } + + for _, ipa := range ips { + if ip4 := ipa.IP.To4(); ip4 != nil { + return r.addIPCache(host, ip4, fixedTTL), nil + } + } + return r.addIPCache(host, ips[0].IP, fixedTTL), nil +} + +func (r *Resolver) addIPCache(host string, ip net.IP, d time.Duration) net.IP { + if isPrivateIP(ip) { + // Don't cache obviously wrong entries from captive portals. + // TODO: use DoH or DoT for the forwarding resolver? + return ip + } + + r.mu.Lock() + defer r.mu.Unlock() + if r.ipCache == nil { + r.ipCache = make(map[string]ipCacheEntry) + } + r.ipCache[host] = ipCacheEntry{ip: ip, expires: time.Now().Add(d)} + return ip +} + +func mustCIDR(s string) *net.IPNet { + _, ipNet, err := net.ParseCIDR("100.64.0.0/10") + if err != nil { + panic(err) + } + return ipNet +} + +func isPrivateIP(ip net.IP) bool { + return private1.Contains(ip) || private2.Contains(ip) || private3.Contains(ip) +} + +var ( + private1 = mustCIDR("10.0.0.0/8") + private2 = mustCIDR("172.16.0.0/12") + private3 = mustCIDR("192.168.0.0/16") +) diff --git a/netcheck/netcheck.go b/netcheck/netcheck.go index 34d57f450..3bd314d71 100644 --- a/netcheck/netcheck.go +++ b/netcheck/netcheck.go @@ -17,6 +17,7 @@ "golang.org/x/sync/errgroup" "tailscale.com/interfaces" + "tailscale.com/net/dnscache" "tailscale.com/stun" "tailscale.com/stunner" "tailscale.com/types/logger" @@ -181,6 +182,7 @@ func GetReport(ctx context.Context, logf logger.Logf) (*Report, error) { Endpoint: add, Servers: stunServers, Logf: logf, + DNSCache: dnscache.Get(), } grp.Go(func() error { return s4.Run(ctx) }) go reader(s4, pc4, unlimited) @@ -190,6 +192,7 @@ func GetReport(ctx context.Context, logf logger.Logf) (*Report, error) { Endpoint: addHair, Servers: stunServers, Logf: logf, + DNSCache: dnscache.Get(), } grp.Go(func() error { return s4Hair.Run(ctx) }) go reader(s4Hair, pc4Hair, 2) @@ -201,6 +204,7 @@ func GetReport(ctx context.Context, logf logger.Logf) (*Report, error) { Servers: stunServers6, Logf: logf, OnlyIPv6: true, + DNSCache: dnscache.Get(), } grp.Go(func() error { return s6.Run(ctx) }) go reader(s6, pc6, unlimited) diff --git a/stunner/stunner.go b/stunner/stunner.go index 6d2c0e4e4..54fa1360c 100644 --- a/stunner/stunner.go +++ b/stunner/stunner.go @@ -12,6 +12,7 @@ "sync" "time" + "tailscale.com/net/dnscache" "tailscale.com/stun" ) @@ -38,9 +39,9 @@ type Stunner struct { Servers []string // STUN servers to contact - // Resolver optionally specifies a resolver to use for DNS lookups. - // If nil, net.DefaultResolver is used. - Resolver *net.Resolver + // DNSCache optionally specifies a DNSCache to use. + // If nil, a DNS cache is not used. + DNSCache *dnscache.Resolver // Logf optionally specifies a log function. If nil, logging is disabled. Logf func(format string, args ...interface{}) @@ -118,9 +119,6 @@ func (s *Stunner) Receive(p []byte, fromAddr *net.UDPAddr) { } func (s *Stunner) resolver() *net.Resolver { - if s.Resolver != nil { - return s.Resolver - } return net.DefaultResolver } @@ -192,9 +190,18 @@ func (s *Stunner) sendSTUN(ctx context.Context, server string) error { } addr := &net.UDPAddr{Port: addrPort} - ipAddrs, err := s.resolver().LookupIPAddr(ctx, host) - if err != nil { - return fmt.Errorf("lookup ip addr: %v", err) + var ipAddrs []net.IPAddr + if s.DNSCache != nil { + ip, err := s.DNSCache.LookupIP(ctx, host) + if err != nil { + return fmt.Errorf("lookup ip addr: %v", err) + } + ipAddrs = []net.IPAddr{{IP: ip}} + } else { + ipAddrs, err = s.resolver().LookupIPAddr(ctx, host) + if err != nil { + return fmt.Errorf("lookup ip addr: %v", err) + } } for _, ipAddr := range ipAddrs { ip4 := ipAddr.IP.To4() diff --git a/wgengine/magicsock/magicsock.go b/wgengine/magicsock/magicsock.go index 20fe7073e..b54b0774d 100644 --- a/wgengine/magicsock/magicsock.go +++ b/wgengine/magicsock/magicsock.go @@ -31,6 +31,7 @@ "tailscale.com/derp" "tailscale.com/derp/derphttp" "tailscale.com/interfaces" + "tailscale.com/net/dnscache" "tailscale.com/netcheck" "tailscale.com/stun" "tailscale.com/stunner" @@ -638,6 +639,7 @@ func (c *Conn) derpWriteChanOfAddr(addr *net.UDPAddr) chan<- derpWriteRequest { c.logf("derphttp.NewClient: port %d, host %q invalid? err: %v", addr.Port, host, err) return nil } + dc.DNSCache = dnscache.Get() dc.TLSConfig = c.derpTLSConfig ctx, cancel := context.WithCancel(context.Background())