mirror of
https://github.com/tailscale/tailscale.git
synced 2024-11-29 13:05:46 +00:00
2cff9016e4
I started to write a full DNS caching resolver and I realized it was overkill and wouldn't work on Windows even in Go 1.14 yet, so I'm doing this tiny one instead for now, just for all our netcheck STUN DERP lookups and connections to DERP servers. (This will be caching exactly 8 DNS entries, all ours.) Fixes #145 (can be better later, of course)
152 lines
3.4 KiB
Go
152 lines
3.4 KiB
Go
// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved.
|
|
// Use of this source code is governed by a BSD-style
|
|
// license that can be found in the LICENSE file.
|
|
|
|
// Package dnscache contains a minimal DNS cache that makes a bunch of
|
|
// assumptions that are only valid for us. Not recommended for general use.
|
|
package dnscache
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"net"
|
|
"sync"
|
|
"time"
|
|
|
|
"golang.org/x/sync/singleflight"
|
|
)
|
|
|
|
var single = &Resolver{
|
|
Forward: &net.Resolver{PreferGo: true},
|
|
}
|
|
|
|
// Get returns a caching Resolver singleton.
|
|
func Get() *Resolver { return single }
|
|
|
|
// fixedTTL is how long every cache entry remains valid (ten minutes). The
// cache never consults a TTL from the DNS answer itself.
const fixedTTL = 600 * time.Second
|
|
|
|
// Resolver is a minimal DNS caching resolver.
|
|
//
|
|
// The TTL is always fixed for now. It's not intended for general use.
|
|
// Cache entries are never cleaned up so it's intended that this is
|
|
// only used with a fixed set of hostnames.
|
|
type Resolver struct {
|
|
// Forward is the resolver to use to populate the cache.
|
|
// If nil, net.DefaultResolver is used.
|
|
Forward *net.Resolver
|
|
|
|
sf singleflight.Group
|
|
|
|
mu sync.Mutex
|
|
ipCache map[string]ipCacheEntry
|
|
}
|
|
|
|
// ipCacheEntry is one cached answer: the address chosen for a hostname and
// the instant at which that answer stops being served from the cache.
type ipCacheEntry struct {
	ip      net.IP    // address handed back to callers
	expires time.Time // entry is used only while expires is after now
}
|
|
|
|
func (r *Resolver) fwd() *net.Resolver {
|
|
if r.Forward != nil {
|
|
return r.Forward
|
|
}
|
|
return net.DefaultResolver
|
|
}
|
|
|
|
// LookupIP returns the first IPv4 address found, otherwise the first IPv6 address.
|
|
func (r *Resolver) LookupIP(ctx context.Context, host string) (net.IP, error) {
|
|
if ip := net.ParseIP(host); ip != nil {
|
|
if ip4 := ip.To4(); ip4 != nil {
|
|
return ip4, nil
|
|
}
|
|
return ip, nil
|
|
}
|
|
|
|
if ip, ok := r.lookupIPCache(host); ok {
|
|
return ip, nil
|
|
}
|
|
|
|
ch := r.sf.DoChan(host, func() (interface{}, error) {
|
|
ip, err := r.lookupIP(host)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return ip, nil
|
|
})
|
|
select {
|
|
case res := <-ch:
|
|
if res.Err != nil {
|
|
return nil, res.Err
|
|
}
|
|
return res.Val.(net.IP), nil
|
|
case <-ctx.Done():
|
|
return nil, ctx.Err()
|
|
}
|
|
}
|
|
|
|
func (r *Resolver) lookupIPCache(host string) (ip net.IP, ok bool) {
|
|
r.mu.Lock()
|
|
defer r.mu.Unlock()
|
|
if ent, ok := r.ipCache[host]; ok && ent.expires.After(time.Now()) {
|
|
return ent.ip, true
|
|
}
|
|
return nil, false
|
|
}
|
|
|
|
func (r *Resolver) lookupIP(host string) (net.IP, error) {
|
|
if ip, ok := r.lookupIPCache(host); ok {
|
|
return ip, nil
|
|
}
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
|
defer cancel()
|
|
ips, err := r.fwd().LookupIPAddr(ctx, host)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
if len(ips) == 0 {
|
|
return nil, fmt.Errorf("no IPs for %q found", host)
|
|
}
|
|
|
|
for _, ipa := range ips {
|
|
if ip4 := ipa.IP.To4(); ip4 != nil {
|
|
return r.addIPCache(host, ip4, fixedTTL), nil
|
|
}
|
|
}
|
|
return r.addIPCache(host, ips[0].IP, fixedTTL), nil
|
|
}
|
|
|
|
func (r *Resolver) addIPCache(host string, ip net.IP, d time.Duration) net.IP {
|
|
if isPrivateIP(ip) {
|
|
// Don't cache obviously wrong entries from captive portals.
|
|
// TODO: use DoH or DoT for the forwarding resolver?
|
|
return ip
|
|
}
|
|
|
|
r.mu.Lock()
|
|
defer r.mu.Unlock()
|
|
if r.ipCache == nil {
|
|
r.ipCache = make(map[string]ipCacheEntry)
|
|
}
|
|
r.ipCache[host] = ipCacheEntry{ip: ip, expires: time.Now().Add(d)}
|
|
return ip
|
|
}
|
|
|
|
// mustCIDR parses s as a CIDR prefix (for example "10.0.0.0/8") and returns
// its network. It panics if s is malformed, so it is only suitable for
// package-level initialization from constant strings.
func mustCIDR(s string) *net.IPNet {
	// Parse the caller's prefix. This previously parsed the hard-coded
	// "100.64.0.0/10", which made every private range below identical.
	_, ipNet, err := net.ParseCIDR(s)
	if err != nil {
		panic(err)
	}
	return ipNet
}

// isPrivateIP reports whether ip falls in one of the RFC 1918 private IPv4
// ranges (10.0.0.0/8, 172.16.0.0/12, 192.168.0.0/16). IPv4-mapped IPv6
// addresses are matched as well; all other IPv6 addresses, and a nil ip,
// are not.
func isPrivateIP(ip net.IP) bool {
	return private1.Contains(ip) || private2.Contains(ip) || private3.Contains(ip)
}

// RFC 1918 private IPv4 address ranges used by isPrivateIP.
var (
	private1 = mustCIDR("10.0.0.0/8")
	private2 = mustCIDR("172.16.0.0/12")
	private3 = mustCIDR("192.168.0.0/16")
)
|