// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved. // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. package resolver import ( "bytes" "context" "encoding/binary" "errors" "fmt" "hash/crc32" "math/rand" "net" "os" "sync" "time" dns "golang.org/x/net/dns/dnsmessage" "inet.af/netaddr" "tailscale.com/logtail/backoff" "tailscale.com/net/netns" "tailscale.com/types/logger" "tailscale.com/util/dnsname" ) // headerBytes is the number of bytes in a DNS message header. const headerBytes = 12 // connCount is the number of UDP connections to use for forwarding. const connCount = 32 const ( // cleanupInterval is the interval between purged of timed-out entries from txMap. cleanupInterval = 30 * time.Second // responseTimeout is the maximal amount of time to wait for a DNS response. responseTimeout = 5 * time.Second ) var errNoUpstreams = errors.New("upstream nameservers not set") var aLongTimeAgo = time.Unix(0, 1) type forwardingRecord struct { src netaddr.IPPort createdAt time.Time } // txid identifies a DNS transaction. // // As the standard DNS Request ID is only 16 bits, we extend it: // the lower 32 bits are the zero-extended bits of the DNS Request ID; // the upper 32 bits are the CRC32 checksum of the first question in the request. // This makes probability of txid collision negligible. type txid uint64 // getTxID computes the txid of the given DNS packet. func getTxID(packet []byte) txid { if len(packet) < headerBytes { return 0 } dnsid := binary.BigEndian.Uint16(packet[0:2]) qcount := binary.BigEndian.Uint16(packet[4:6]) if qcount == 0 { return txid(dnsid) } offset := headerBytes for i := uint16(0); i < qcount; i++ { // Note: this relies on the fact that names are not compressed in questions, // so they are guaranteed to end with a NUL byte. // // Justification: // RFC 1035 doesn't seem to explicitly prohibit compressing names in questions, // but this is exceedingly unlikely to be done in practice. A DNS request // with multiple questions is ill-defined (which questions do the header flags apply to?) // and a single question would have to contain a pointer to an *answer*, // which would be excessively smart, pointless (an answer can just as well refer to the question) // and perhaps even prohibited: a draft RFC (draft-ietf-dnsind-local-compression-05) states: // // > It is important that these pointers always point backwards. // // This is said in summarizing RFC 1035, although that phrase does not appear in the original RFC. // Additionally, (https://cr.yp.to/djbdns/notes.html) states: // // > The precise rule is that a name can be compressed if it is a response owner name, // > the name in NS data, the name in CNAME data, the name in PTR data, the name in MX data, // > or one of the names in SOA data. namebytes := bytes.IndexByte(packet[offset:], 0) // ... | name | NUL | type | class // ?? 1 2 2 offset = offset + namebytes + 5 if len(packet) < offset { // Corrupt packet; don't crash. return txid(dnsid) } } hash := crc32.ChecksumIEEE(packet[headerBytes:offset]) return (txid(hash) << 32) | txid(dnsid) } type route struct { suffix string resolvers []netaddr.IPPort } // forwarder forwards DNS packets to a number of upstream nameservers. type forwarder struct { logf logger.Logf // responses is a channel by which responses are returned. responses chan packet // closed signals all goroutines to stop. closed chan struct{} // wg signals when all goroutines have stopped. wg sync.WaitGroup // conns are the UDP connections used for forwarding. // A random one is selected for each request, regardless of the target upstream. conns []*fwdConn mu sync.Mutex // routes are per-suffix resolvers to use. routes []route // most specific routes first txMap map[txid]forwardingRecord // txids to in-flight requests } func init() { rand.Seed(time.Now().UnixNano()) } func newForwarder(logf logger.Logf, responses chan packet) *forwarder { ret := &forwarder{ logf: logger.WithPrefix(logf, "forward: "), responses: responses, closed: make(chan struct{}), conns: make([]*fwdConn, connCount), txMap: make(map[txid]forwardingRecord), } ret.wg.Add(connCount + 1) for idx := range ret.conns { ret.conns[idx] = newFwdConn(ret.logf, idx) go ret.recv(ret.conns[idx]) } go ret.cleanMap() return ret } func (f *forwarder) Close() { select { case <-f.closed: return default: // continue } close(f.closed) for _, conn := range f.conns { conn.close() } f.wg.Wait() } func (f *forwarder) rebindFromNetworkChange() { for _, c := range f.conns { c.mu.Lock() c.reconnectLocked() c.mu.Unlock() } } func (f *forwarder) setRoutes(routes []route) { f.mu.Lock() f.routes = routes f.mu.Unlock() } // send sends packet to dst. It is best effort. func (f *forwarder) send(packet []byte, dst netaddr.IPPort) { connIdx := rand.Intn(connCount) conn := f.conns[connIdx] conn.send(packet, dst) } func (f *forwarder) recv(conn *fwdConn) { defer f.wg.Done() for { select { case <-f.closed: return default: } out := make([]byte, maxResponseBytes) n := conn.read(out) if n == 0 { continue } if n < headerBytes { f.logf("recv: packet too small (%d bytes)", n) } out = out[:n] txid := getTxID(out) f.mu.Lock() record, found := f.txMap[txid] // At most one nameserver will return a response: // the first one to do so will delete txid from the map. if !found { f.mu.Unlock() continue } delete(f.txMap, txid) f.mu.Unlock() pkt := packet{out, record.src} select { case <-f.closed: return case f.responses <- pkt: // continue } } } // cleanMap periodically deletes timed-out forwarding records from f.txMap to bound growth. func (f *forwarder) cleanMap() { defer f.wg.Done() t := time.NewTicker(cleanupInterval) defer t.Stop() var now time.Time for { select { case <-f.closed: return case now = <-t.C: // continue } f.mu.Lock() for k, v := range f.txMap { if now.Sub(v.createdAt) > responseTimeout { delete(f.txMap, k) } } f.mu.Unlock() } } // forward forwards the query to all upstream nameservers and returns the first response. func (f *forwarder) forward(query packet) error { domain, err := nameFromQuery(query.bs) if err != nil { return err } txid := getTxID(query.bs) f.mu.Lock() routes := f.routes f.mu.Unlock() var resolvers []netaddr.IPPort for _, route := range routes { if route.suffix != "." && !dnsname.HasSuffix(domain, route.suffix) { continue } resolvers = route.resolvers break } if len(resolvers) == 0 { return errNoUpstreams } f.mu.Lock() f.txMap[txid] = forwardingRecord{ src: query.addr, createdAt: time.Now(), } f.mu.Unlock() for _, resolver := range resolvers { f.send(query.bs, resolver) } return nil } // A fwdConn manages a single connection used to forward DNS requests. // Net link changes can cause a *net.UDPConn to become permanently unusable, particularly on macOS. // fwdConn detects such situations and transparently creates new connections. type fwdConn struct { // logf allows a fwdConn to log. logf logger.Logf // wg tracks the number of outstanding conn.Read and conn.Write calls. wg sync.WaitGroup // change allows calls to read to block until a the network connection has been replaced. change *sync.Cond // mu protects fields that follow it; it is also change's Locker. mu sync.Mutex // closed tracks whether fwdConn has been permanently closed. closed bool // conn is the current active connection. conn net.PacketConn } func newFwdConn(logf logger.Logf, idx int) *fwdConn { c := new(fwdConn) c.logf = logger.WithPrefix(logf, fmt.Sprintf("fwdConn %d: ", idx)) c.change = sync.NewCond(&c.mu) // c.conn is created lazily in send return c } // send sends packet to dst using c's connection. // It is best effort. It is UDP, after all. Failures are logged. func (c *fwdConn) send(packet []byte, dst netaddr.IPPort) { var b *backoff.Backoff // lazily initialized, since it is not needed in the common case backOff := func(err error) { if b == nil { b = backoff.NewBackoff("dns-fwdConn-send", c.logf, 30*time.Second) } b.BackOff(context.Background(), err) } for { // Gather the current connection. // We can't hold the lock while we call WriteTo. c.mu.Lock() conn := c.conn closed := c.closed if closed { c.mu.Unlock() return } if conn == nil { c.reconnectLocked() c.mu.Unlock() continue } c.mu.Unlock() a := dst.UDPAddr() c.wg.Add(1) _, err := conn.WriteTo(packet, a) c.wg.Done() if err == nil { // Success return } if errors.Is(err, os.ErrDeadlineExceeded) { // We intentionally closed this connection. // It has been replaced by a new connection. Try again. continue } // Something else went wrong. // We have three choices here: try again, give up, or create a new connection. var opErr *net.OpError if !errors.As(err, &opErr) { // Weird. All errors from the net package should be *net.OpError. Bail. c.logf("send: non-*net.OpErr %v (%T)", err, err) return } if opErr.Temporary() || opErr.Timeout() { // I doubt that either of these can happen (this is UDP), // but go ahead and try again. backOff(err) continue } if networkIsDown(err) { // Fail. c.logf("send: network is down") return } if networkIsUnreachable(err) { // This can be caused by a link change. // Replace the existing connection with a new one. c.mu.Lock() // It's possible that multiple senders discovered simultaneously // that the network is unreachable. Avoid reconnecting multiple times: // Only reconnect if the current connection is the one that we // discovered to be problematic. if c.conn == conn { backOff(err) c.reconnectLocked() } c.mu.Unlock() // Try again with our new network connection. continue } // Unrecognized error. Fail. c.logf("send: unrecognized error: %v", err) return } } // read waits for a response from c's connection. // It returns the number of bytes read, which may be 0 // in case of an error or a closed connection. func (c *fwdConn) read(out []byte) int { for { // Gather the current connection. // We can't hold the lock while we call ReadFrom. c.mu.Lock() conn := c.conn closed := c.closed if closed { c.mu.Unlock() return 0 } if conn == nil { // There is no current connection. // Wait for the connection to change, then try again. c.change.Wait() c.mu.Unlock() continue } c.mu.Unlock() c.wg.Add(1) n, _, err := conn.ReadFrom(out) c.wg.Done() if err == nil { // Success. return n } if errors.Is(err, os.ErrDeadlineExceeded) { // We intentionally closed this connection. // It has been replaced by a new connection. Try again. continue } c.logf("read: unrecognized error: %v", err) return 0 } } // reconnectLocked replaces the current connection with a new one. // c.mu must be locked. func (c *fwdConn) reconnectLocked() { c.closeConnLocked() // Make a new connection. conn, err := netns.Listener().ListenPacket(context.Background(), "udp", "") if err != nil { c.logf("ListenPacket failed: %v", err) } else { c.conn = conn } // Broadcast that a new connection is available. c.change.Broadcast() } // closeCurrentConn closes the current connection. // c.mu must be locked. func (c *fwdConn) closeConnLocked() { if c.conn == nil { return } // Unblock all readers/writers, wait for them, close the connection. c.conn.SetDeadline(aLongTimeAgo) c.wg.Wait() c.conn.Close() c.conn = nil } // close permanently closes c. func (c *fwdConn) close() { c.mu.Lock() defer c.mu.Unlock() if c.closed { return } c.closed = true c.closeConnLocked() // Unblock any remaining readers. c.change.Broadcast() } // nameFromQuery extracts the normalized query name from bs. func nameFromQuery(bs []byte) (string, error) { var parser dns.Parser hdr, err := parser.Start(bs) if err != nil { return "", err } if hdr.Response { return "", errNotQuery } q, err := parser.Question() if err != nil { return "", err } n := q.Name.Data[:q.Name.Length] return rawNameToLower(n), nil }