diff --git a/derp/derphttp/derphttp_client.go b/derp/derphttp/derphttp_client.go index 2d1a821d7..cfe6d0225 100644 --- a/derp/derphttp/derphttp_client.go +++ b/derp/derphttp/derphttp_client.go @@ -26,6 +26,7 @@ import ( "time" "tailscale.com/derp" + "tailscale.com/net/dnscache" "tailscale.com/types/key" "tailscale.com/types/logger" ) @@ -37,7 +38,8 @@ import ( // Send/Recv will completely re-establish the connection (unless Close // has been called). type Client struct { - TLSConfig *tls.Config // for sever connection, optional, nil means default + TLSConfig *tls.Config // for sever connection, optional, nil means default + DNSCache *dnscache.Resolver // optional; if nil, no caching privateKey key.Private logf logger.Logf @@ -137,11 +139,23 @@ func (c *Client) connect(ctx context.Context, caller string) (client *derp.Clien } }() + host := c.url.Hostname() + hostOrIP := host + var d net.Dialer - log.Printf("Dialing: %q", net.JoinHostPort(c.url.Hostname(), urlPort(c.url))) - tcpConn, err = d.DialContext(ctx, "tcp", net.JoinHostPort(c.url.Hostname(), urlPort(c.url))) + log.Printf("Dialing: %q", net.JoinHostPort(host, urlPort(c.url))) + + if c.DNSCache != nil { + ip, err := c.DNSCache.LookupIP(ctx, host) + if err != nil { + return nil, err + } + hostOrIP = ip.String() + } + + tcpConn, err = d.DialContext(ctx, "tcp", net.JoinHostPort(hostOrIP, urlPort(c.url))) if err != nil { - return nil, err + return nil, fmt.Errorf("Dial of %q: %v", host, err) } // Now that we have a TCP connection, force close it. diff --git a/net/dnscache/dnscache.go b/net/dnscache/dnscache.go new file mode 100644 index 000000000..b435e5aaa --- /dev/null +++ b/net/dnscache/dnscache.go @@ -0,0 +1,151 @@ +// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package dnscache contains a minimal DNS cache that makes a bunch of +// assumptions that are only valid for us. Not recommended for general use. +package dnscache + +import ( + "context" + "fmt" + "net" + "sync" + "time" + + "golang.org/x/sync/singleflight" +) + +var single = &Resolver{ + Forward: &net.Resolver{PreferGo: true}, +} + +// Get returns a caching Resolver singleton. +func Get() *Resolver { return single } + +const fixedTTL = 10 * time.Minute + +// Resolver is a minimal DNS caching resolver. +// +// The TTL is always fixed for now. It's not intended for general use. +// Cache entries are never cleaned up so it's intended that this is +// only used with a fixed set of hostnames. +type Resolver struct { + // Forward is the resolver to use to populate the cache. + // If nil, net.DefaultResolver is used. + Forward *net.Resolver + + sf singleflight.Group + + mu sync.Mutex + ipCache map[string]ipCacheEntry +} + +type ipCacheEntry struct { + ip net.IP + expires time.Time +} + +func (r *Resolver) fwd() *net.Resolver { + if r.Forward != nil { + return r.Forward + } + return net.DefaultResolver +} + +// LookupIP returns the first IPv4 address found, otherwise the first IPv6 address. +func (r *Resolver) LookupIP(ctx context.Context, host string) (net.IP, error) { + if ip := net.ParseIP(host); ip != nil { + if ip4 := ip.To4(); ip4 != nil { + return ip4, nil + } + return ip, nil + } + + if ip, ok := r.lookupIPCache(host); ok { + return ip, nil + } + + ch := r.sf.DoChan(host, func() (interface{}, error) { + ip, err := r.lookupIP(host) + if err != nil { + return nil, err + } + return ip, nil + }) + select { + case res := <-ch: + if res.Err != nil { + return nil, res.Err + } + return res.Val.(net.IP), nil + case <-ctx.Done(): + return nil, ctx.Err() + } +} + +func (r *Resolver) lookupIPCache(host string) (ip net.IP, ok bool) { + r.mu.Lock() + defer r.mu.Unlock() + if ent, ok := r.ipCache[host]; ok && ent.expires.After(time.Now()) { + return ent.ip, true + } + return nil, false +} + +func (r *Resolver) lookupIP(host string) (net.IP, error) { + if ip, ok := r.lookupIPCache(host); ok { + return ip, nil + } + + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + ips, err := r.fwd().LookupIPAddr(ctx, host) + if err != nil { + return nil, err + } + if len(ips) == 0 { + return nil, fmt.Errorf("no IPs for %q found", host) + } + + for _, ipa := range ips { + if ip4 := ipa.IP.To4(); ip4 != nil { + return r.addIPCache(host, ip4, fixedTTL), nil + } + } + return r.addIPCache(host, ips[0].IP, fixedTTL), nil +} + +func (r *Resolver) addIPCache(host string, ip net.IP, d time.Duration) net.IP { + if isPrivateIP(ip) { + // Don't cache obviously wrong entries from captive portals. + // TODO: use DoH or DoT for the forwarding resolver? + return ip + } + + r.mu.Lock() + defer r.mu.Unlock() + if r.ipCache == nil { + r.ipCache = make(map[string]ipCacheEntry) + } + r.ipCache[host] = ipCacheEntry{ip: ip, expires: time.Now().Add(d)} + return ip +} + +func mustCIDR(s string) *net.IPNet { + _, ipNet, err := net.ParseCIDR("100.64.0.0/10") + if err != nil { + panic(err) + } + return ipNet +} + +func isPrivateIP(ip net.IP) bool { + return private1.Contains(ip) || private2.Contains(ip) || private3.Contains(ip) +} + +var ( + private1 = mustCIDR("10.0.0.0/8") + private2 = mustCIDR("172.16.0.0/12") + private3 = mustCIDR("192.168.0.0/16") +) diff --git a/netcheck/netcheck.go b/netcheck/netcheck.go index 34d57f450..3bd314d71 100644 --- a/netcheck/netcheck.go +++ b/netcheck/netcheck.go @@ -17,6 +17,7 @@ import ( "golang.org/x/sync/errgroup" "tailscale.com/interfaces" + "tailscale.com/net/dnscache" "tailscale.com/stun" "tailscale.com/stunner" "tailscale.com/types/logger" @@ -181,6 +182,7 @@ func GetReport(ctx context.Context, logf logger.Logf) (*Report, error) { Endpoint: add, Servers: stunServers, Logf: logf, + DNSCache: dnscache.Get(), } grp.Go(func() error { return s4.Run(ctx) }) go reader(s4, pc4, unlimited) @@ -190,6 +192,7 @@ func GetReport(ctx context.Context, logf logger.Logf) (*Report, error) { Endpoint: addHair, Servers: stunServers, Logf: logf, + DNSCache: dnscache.Get(), } grp.Go(func() error { return s4Hair.Run(ctx) }) go reader(s4Hair, pc4Hair, 2) @@ -201,6 +204,7 @@ func GetReport(ctx context.Context, logf logger.Logf) (*Report, error) { Servers: stunServers6, Logf: logf, OnlyIPv6: true, + DNSCache: dnscache.Get(), } grp.Go(func() error { return s6.Run(ctx) }) go reader(s6, pc6, unlimited) diff --git a/stunner/stunner.go b/stunner/stunner.go index 6d2c0e4e4..54fa1360c 100644 --- a/stunner/stunner.go +++ b/stunner/stunner.go @@ -12,6 +12,7 @@ import ( "sync" "time" + "tailscale.com/net/dnscache" "tailscale.com/stun" ) @@ -38,9 +39,9 @@ type Stunner struct { Servers []string // STUN servers to contact - // Resolver optionally specifies a resolver to use for DNS lookups. - // If nil, net.DefaultResolver is used. - Resolver *net.Resolver + // DNSCache optionally specifies a DNSCache to use. + // If nil, a DNS cache is not used. + DNSCache *dnscache.Resolver // Logf optionally specifies a log function. If nil, logging is disabled. Logf func(format string, args ...interface{}) @@ -118,9 +119,6 @@ func (s *Stunner) Receive(p []byte, fromAddr *net.UDPAddr) { } func (s *Stunner) resolver() *net.Resolver { - if s.Resolver != nil { - return s.Resolver - } return net.DefaultResolver } @@ -192,9 +190,18 @@ func (s *Stunner) sendSTUN(ctx context.Context, server string) error { } addr := &net.UDPAddr{Port: addrPort} - ipAddrs, err := s.resolver().LookupIPAddr(ctx, host) - if err != nil { - return fmt.Errorf("lookup ip addr: %v", err) + var ipAddrs []net.IPAddr + if s.DNSCache != nil { + ip, err := s.DNSCache.LookupIP(ctx, host) + if err != nil { + return fmt.Errorf("lookup ip addr: %v", err) + } + ipAddrs = []net.IPAddr{{IP: ip}} + } else { + ipAddrs, err = s.resolver().LookupIPAddr(ctx, host) + if err != nil { + return fmt.Errorf("lookup ip addr: %v", err) + } } for _, ipAddr := range ipAddrs { ip4 := ipAddr.IP.To4() diff --git a/wgengine/magicsock/magicsock.go b/wgengine/magicsock/magicsock.go index 20fe7073e..b54b0774d 100644 --- a/wgengine/magicsock/magicsock.go +++ b/wgengine/magicsock/magicsock.go @@ -31,6 +31,7 @@ import ( "tailscale.com/derp" "tailscale.com/derp/derphttp" "tailscale.com/interfaces" + "tailscale.com/net/dnscache" "tailscale.com/netcheck" "tailscale.com/stun" "tailscale.com/stunner" @@ -638,6 +639,7 @@ func (c *Conn) derpWriteChanOfAddr(addr *net.UDPAddr) chan<- derpWriteRequest { c.logf("derphttp.NewClient: port %d, host %q invalid? err: %v", addr.Port, host, err) return nil } + dc.DNSCache = dnscache.Get() dc.TLSConfig = c.derpTLSConfig ctx, cancel := context.WithCancel(context.Background())