// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

package resolver

import (
	"bytes"
	"context"
	"encoding/binary"
	"errors"
	"hash/crc32"
	"io"
	"math/rand"
	"net"
	"sync"
	"time"

	dns "golang.org/x/net/dns/dnsmessage"
	"inet.af/netaddr"
	"tailscale.com/types/logger"
	"tailscale.com/util/dnsname"
	"tailscale.com/wgengine/monitor"
)

// headerBytes is the number of bytes in a DNS message header.
const headerBytes = 12

const (
	// responseTimeout is the maximal amount of time to wait for a DNS response.
	responseTimeout = 5 * time.Second
)

var errNoUpstreams = errors.New("upstream nameservers not set")

// txid identifies a DNS transaction.
//
// As the standard DNS Request ID is only 16 bits, we extend it:
// the lower 32 bits are the zero-extended bits of the DNS Request ID;
// the upper 32 bits are the CRC32 checksum of the request's question section
// (every question, each including its QTYPE and QCLASS).
// This makes probability of txid collision negligible.
type txid uint64

// getTxID computes the txid of the given DNS packet.
//
// If packet is shorter than a DNS header, it returns 0. If packet has no
// questions, or its question section is malformed (a name with no terminating
// NUL byte, or a question that runs past the end of the packet), only the
// 16-bit DNS Request ID is returned and the upper 32 bits are zero.
func getTxID(packet []byte) txid {
	if len(packet) < headerBytes {
		return 0
	}

	dnsid := binary.BigEndian.Uint16(packet[0:2])
	qcount := binary.BigEndian.Uint16(packet[4:6])
	if qcount == 0 {
		return txid(dnsid)
	}

	offset := headerBytes
	for i := uint16(0); i < qcount; i++ {
		// Note: this relies on the fact that names are not compressed in questions,
		// so they are guaranteed to end with a NUL byte.
		//
		// Justification:
		// RFC 1035 doesn't seem to explicitly prohibit compressing names in questions,
		// but this is exceedingly unlikely to be done in practice. A DNS request
		// with multiple questions is ill-defined (which questions do the header flags apply to?)
		// and a single question would have to contain a pointer to an *answer*,
		// which would be excessively smart, pointless (an answer can just as well refer to the question)
		// and perhaps even prohibited: a draft RFC (draft-ietf-dnsind-local-compression-05) states:
		//
		// > It is important that these pointers always point backwards.
		//
		// This is said in summarizing RFC 1035, although that phrase does not appear in the original RFC.
		// Additionally, (https://cr.yp.to/djbdns/notes.html) states:
		//
		// > The precise rule is that a name can be compressed if it is a response owner name,
		// > the name in NS data, the name in CNAME data, the name in PTR data, the name in MX data,
		// > or one of the names in SOA data.
		namebytes := bytes.IndexByte(packet[offset:], 0)
		if namebytes < 0 {
			// No NUL byte in the rest of the packet: the name is unterminated,
			// so the packet is corrupt. Without this check IndexByte's -1 would
			// be folded into the offset arithmetic below and arbitrary bytes
			// would end up in the hash.
			return txid(dnsid)
		}
		// ... | name | NUL | type | class
		//        ??     1      2      2
		offset += namebytes + 5
		if len(packet) < offset {
			// Corrupt packet; don't crash.
			return txid(dnsid)
		}
	}

	hash := crc32.ChecksumIEEE(packet[headerBytes:offset])
	return (txid(hash) << 32) | txid(dnsid)
}

// clampEDNSSize attempts to limit the maximum EDNS response size advertised in
// packet to maxSize, rewriting packet in place. This is not an exhaustive
// solution, instead only easy cases are currently handled in the interest of
// speed and reduced complexity. Only OPT records at the very end of the
// message with no option codes are addressed. In every other case (no
// additional records, no trailing OPT record, an unknown EDNS version,
// non-empty option data, or a requested size already <= maxSize) packet is
// left unchanged.
//
// The trailing OPT record it recognizes is 11 bytes long and laid out as
// (RFC 6891, section 6.1.2):
//
//	opt[0]     NAME, must be 0 (the root domain)
//	opt[1:3]   TYPE, must be OPT
//	opt[3:5]   CLASS, the requestor's UDP payload size
//	opt[5]     extended RCODE
//	opt[6]     EDNS VERSION
//	opt[7:9]   flags
//	opt[9:11]  RDLEN, must be 0 (no options)
//
// TODO: handle more situations if we discover that they happen often
func clampEDNSSize(packet []byte, maxSize uint16) {
	// optFixedBytes is the size of an OPT record with no option codes.
	const optFixedBytes = 11
	const edns0Version = 0

	if len(packet) < headerBytes+optFixedBytes {
		return
	}

	arCount := binary.BigEndian.Uint16(packet[10:12])
	if arCount == 0 {
		// OPT shows up in an AR, so there must be no OPT
		return
	}

	// The full slice expression caps opt at the end of packet, so opt exposes
	// exactly the 11 record bytes and never anything in packet's spare
	// capacity beyond its length.
	opt := packet[len(packet)-optFixedBytes : len(packet) : len(packet)]

	if opt[0] != 0 {
		// OPT NAME must be 0 (root domain)
		return
	}
	if dns.Type(binary.BigEndian.Uint16(opt[1:3])) != dns.TypeOPT {
		// Not an OPT record
		return
	}
	requestedSize := binary.BigEndian.Uint16(opt[3:5])
	// Ignore extended RCODE in opt[5]
	if opt[6] != edns0Version {
		// Be conservative and don't touch unknown versions.
		return
	}
	// Ignore flags in opt[7:9]
	if binary.BigEndian.Uint16(opt[9:11]) != 0 {
		// RDLEN (the last two bytes of the record) must be 0 (no variable
		// length data). We're at the end of the packet so this should be 0
		// anyway.
		return
	}

	if requestedSize <= maxSize {
		return
	}

	// Clamp the maximum size
	binary.BigEndian.PutUint16(opt[3:5], maxSize)
}

// route maps a DNS name suffix to the upstream resolvers that queries for
// names under that suffix are forwarded to.
type route struct {
	// Suffix is the domain suffix this route applies to. The root suffix "."
	// matches every name (see forwarder.resolvers).
	Suffix    dnsname.FQDN
	// Resolvers are the upstream nameservers to query for names under Suffix.
	Resolvers []netaddr.IPPort
}

// forwarder forwards DNS packets to a number of upstream nameservers.
//
// A forwarder is created with newForwarder and stopped with Close.
type forwarder struct {
	// logf is the logger; newForwarder prefixes it with "forward: ".
	logf    logger.Logf
	// linkMon is handed to initListenConfig when an upstream socket must be
	// set up for a specific link (see packetListener).
	linkMon *monitor.Mon
	// linkSel, if non-nil, picks the link name to use for a given upstream
	// IP; an empty name means "no specific link" (see packetListener).
	linkSel ForwardLinkSelector

	ctx       context.Context    // canceled by Close; parent of each forward call's timeout context
	ctxCancel context.CancelFunc // cancels ctx; called by Close

	// responses is the channel on which forward delivers the first upstream
	// response received for each query.
	responses chan packet

	mu sync.Mutex // guards following

	// routes are per-suffix resolvers to use, with
	// the most specific routes first. resolvers returns the first matching
	// route, so this ordering determines which route wins.
	routes []route
}

// init seeds the process-wide math/rand source from the current time so that
// its sequence differs between runs.
//
// NOTE(review): nothing in this file draws from math/rand; presumably other
// files in this package rely on this seeding — confirm before removing.
func init() {
	seed := time.Now().UnixNano()
	rand.Seed(seed)
}

// newForwarder returns a forwarder that logs via logf (with a "forward: "
// prefix), delivers responses on the responses channel, and uses linkMon and
// linkSel to choose the network link for upstream sockets. The returned
// forwarder's context stays live until Close is called.
func newForwarder(logf logger.Logf, responses chan packet, linkMon *monitor.Mon, linkSel ForwardLinkSelector) *forwarder {
	ctx, cancel := context.WithCancel(context.Background())
	return &forwarder{
		logf:      logger.WithPrefix(logf, "forward: "),
		linkMon:   linkMon,
		linkSel:   linkSel,
		ctx:       ctx,
		ctxCancel: cancel,
		responses: responses,
	}
}

// Close cancels the forwarder's context. Every forward call derives its
// per-query context from it, so in-flight forwards are aborted. Close always
// returns nil.
func (f *forwarder) Close() error {
	cancel := f.ctxCancel
	cancel()
	return nil
}

// setRoutes replaces the forwarder's routing table with routes, which should
// be ordered most specific suffix first. The slice is stored as-is, not
// copied.
func (f *forwarder) setRoutes(routes []route) {
	f.mu.Lock()
	f.routes = routes
	f.mu.Unlock()
}

// stdNetPacketListener is the default packetListener: a zero-value
// net.ListenConfig with no link-specific setup. packetListener returns it
// whenever no specific link has been selected.
var stdNetPacketListener packetListener = new(net.ListenConfig)

// packetListener is the subset of *net.ListenConfig that send needs in order
// to open the UDP socket used for an upstream query.
type packetListener interface {
	ListenPacket(ctx context.Context, network, address string) (net.PacketConn, error)
}

// packetListener returns the listener send should use to reach the upstream
// nameserver at ip.
//
// If the forwarder has a link selector, initListenConfig is set, and the
// selector names a link for ip, a fresh net.ListenConfig is prepared for that
// link via initListenConfig. Otherwise stdNetPacketListener is returned.
func (f *forwarder) packetListener(ip netaddr.IP) (packetListener, error) {
	var linkName string
	if f.linkSel != nil && initListenConfig != nil {
		linkName = f.linkSel.PickLink(ip)
	}
	if linkName == "" {
		// No selector, no way to configure a link, or no link chosen for ip.
		return stdNetPacketListener, nil
	}

	lc := &net.ListenConfig{}
	if err := initListenConfig(lc, f.linkMon, linkName); err != nil {
		return nil, err
	}
	return lc, nil
}

// send sends packet to dst over a fresh UDP socket and waits for a single
// reply. It is best effort.
//
// send expects the reply to have the same txid as txidOut. A reply with a
// different txid, or one too short to hold a DNS header, is rejected with an
// error. A reply longer than maxResponseBytes is cut to maxResponseBytes and
// has its TC (truncated) header bit set. The EDNS size in the returned reply
// is clamped to maxResponseBytes with clampEDNSSize.
//
// The provided closeOnCtxDone lets send register values to Close if
// the caller's ctx expires. This avoids send from allocating its own
// waiting goroutine to interrupt the ReadFrom, as memory is tight on
// iOS and we want the number of pending DNS lookups to be bursty
// without too much associated goroutine/memory cost.
func (f *forwarder) send(ctx context.Context, txidOut txid, closeOnCtxDone *closePool, packet []byte, dst netaddr.IPPort) ([]byte, error) {
	// TODO(bradfitz): if dst.IP is 8.8.8.8 or 8.8.4.4 or 1.1.1.1, etc, or
	// something dynamically probed earlier to support DoH or DoT,
	// do that here instead.

	ln, err := f.packetListener(dst.IP())
	if err != nil {
		return nil, err
	}
	conn, err := ln.ListenPacket(ctx, "udp", ":0")
	if err != nil {
		f.logf("ListenPacket failed: %v", err)
		return nil, err
	}
	defer conn.Close()

	closeOnCtxDone.Add(conn)
	defer closeOnCtxDone.Remove(conn)

	if _, err := conn.WriteTo(packet, dst.UDPAddr()); err != nil {
		// If ctx is already done, report that rather than the I/O error,
		// which is then most likely a consequence of conn being closed.
		if ctxErr := ctx.Err(); ctxErr != nil {
			return nil, ctxErr
		}
		return nil, err
	}

	// The 1 extra byte is to detect packet truncation.
	out := make([]byte, maxResponseBytes+1)
	n, _, err := conn.ReadFrom(out)
	if err != nil {
		if ctxErr := ctx.Err(); ctxErr != nil {
			return nil, ctxErr
		}
		if !packetWasTruncated(err) {
			return nil, err
		}
		// A truncation error is not fatal: continue with the n bytes read.
	}
	truncated := n > maxResponseBytes
	if truncated {
		n = maxResponseBytes
	}
	if n < headerBytes {
		// Without a full header there is no txid or flags to look at; reject
		// explicitly instead of relying on getTxID returning 0.
		f.logf("recv: packet too small (%d bytes)", n)
		return nil, errors.New("recv: packet too small")
	}
	out = out[:n]
	if getTxID(out) != txidOut {
		return nil, errors.New("txid doesn't match")
	}

	if truncated {
		const dnsFlagTruncated = 0x200
		flags := binary.BigEndian.Uint16(out[2:4])
		flags |= dnsFlagTruncated
		binary.BigEndian.PutUint16(out[2:4], flags)

		// TODO(#2067): Remove any incomplete records? RFC 1035 section 6.2
		// states that truncation should head drop so that the authority
		// section can be preserved if possible. However, the UDP read with
		// a too-small buffer has already dropped the end, so that's the
		// best we can do.
	}

	clampEDNSSize(out, maxResponseBytes)

	return out, nil
}

// resolvers returns the upstream resolvers of the first route whose suffix is
// "." or contains domain, or nil if no route matches. The routing table is
// snapshotted under f.mu and scanned without holding the lock.
func (f *forwarder) resolvers(domain dnsname.FQDN) []netaddr.IPPort {
	f.mu.Lock()
	snapshot := f.routes
	f.mu.Unlock()

	for i := range snapshot {
		r := &snapshot[i]
		if r.Suffix == "." || r.Suffix.Contains(domain) {
			return r.Resolvers
		}
	}
	return nil
}

// forward forwards the query to all upstream nameservers for its domain
// concurrently and delivers the first successful response, addressed to
// query.addr, on f.responses.
//
// It returns nil once a response has been handed to f.responses. Otherwise it
// returns one of the following:
//   - the parse error, if the query name cannot be parsed;
//   - errNoUpstreams, if no route covers the name;
//   - the first upstream error, as soon as every upstream has failed;
//   - the first upstream error seen, if any, when responseTimeout elapses or
//     the forwarder is closed, and the context's error otherwise.
func (f *forwarder) forward(query packet) error {
	domain, err := nameFromQuery(query.bs)
	if err != nil {
		return err
	}

	qid := getTxID(query.bs)
	clampEDNSSize(query.bs, maxResponseBytes)

	resolvers := f.resolvers(domain)
	if len(resolvers) == 0 {
		return errNoUpstreams
	}

	closeOnCtxDone := new(closePool)
	defer closeOnCtxDone.Close()

	ctx, cancel := context.WithTimeout(f.ctx, responseTimeout)
	defer cancel()

	// Each upstream goroutine reports exactly one outcome: either a response
	// on resc or an error on errc. errc has room for every resolver, so error
	// sends never block. Response sends are non-blocking; only the first
	// response is kept.
	resc := make(chan []byte, 1)
	errc := make(chan error, len(resolvers))

	for _, ipp := range resolvers {
		go func(ipp netaddr.IPPort) {
			resb, err := f.send(ctx, qid, closeOnCtxDone, query.bs, ipp)
			if err != nil {
				errc <- err
				return
			}
			select {
			case resc <- resb:
			default:
				// Another upstream already answered.
			}
		}(ipp)
	}

	var firstErr error
	numErr := 0
	for {
		select {
		case v := <-resc:
			select {
			case <-ctx.Done():
				return ctx.Err()
			case f.responses <- packet{v, query.addr}:
				return nil
			}
		case err := <-errc:
			if firstErr == nil {
				firstErr = err
			}
			numErr++
			if numErr == len(resolvers) {
				// Every upstream has failed; there is no point in waiting
				// out the rest of the timeout.
				return firstErr
			}
		case <-ctx.Done():
			if firstErr == nil {
				// Pick up an error that may already be queued but not yet
				// received, so it is preferred over the bare context error.
				select {
				case firstErr = <-errc:
				default:
				}
			}
			if firstErr != nil {
				return firstErr
			}
			return ctx.Err()
		}
	}
}

// initListenConfig, if non-nil, prepares a net.ListenConfig for opening
// upstream sockets on the link named tunName (packetListener passes the link
// name chosen by the forwarder's ForwardLinkSelector), using the link monitor.
// When it is nil, packetListener always returns stdNetPacketListener.
//
// NOTE(review): it is not assigned anywhere in this file; presumably a
// platform-specific file in this package sets it — confirm.
var initListenConfig func(_ *net.ListenConfig, _ *monitor.Mon, tunName string) error

// nameFromQuery extracts the normalized query name from bs.
//
// It parses the DNS header and the first question. It fails with
// errNotQuery when the message is a response, or with the parser's error
// when the header or question is malformed. The name is lowercased with
// rawNameToLower and converted to an FQDN.
func nameFromQuery(bs []byte) (dnsname.FQDN, error) {
	var p dns.Parser

	hdr, err := p.Start(bs)
	switch {
	case err != nil:
		return "", err
	case hdr.Response:
		return "", errNotQuery
	}

	q, err := p.Question()
	if err != nil {
		return "", err
	}

	raw := q.Name.Data[:q.Name.Length]
	return dnsname.ToFQDN(rawNameToLower(raw))
}

// closePool is a dynamic set of io.Closers to close as a group.
// It's intended to be Closed at most once; further Close calls do nothing.
//
// The zero value is ready for use. All methods take p.mu.
type closePool struct {
	mu     sync.Mutex
	m      map[io.Closer]bool
	closed bool
}

// Add registers c to be closed by a later Close. If the pool has already
// been closed, c is closed right away instead (its error is ignored).
func (p *closePool) Add(c io.Closer) {
	p.mu.Lock()
	defer p.mu.Unlock()
	if !p.closed {
		if p.m == nil {
			p.m = make(map[io.Closer]bool)
		}
		p.m[c] = true
		return
	}
	_ = c.Close() // best effort
}

// Remove unregisters c so that Close will not close it. It does nothing once
// the pool has been closed.
func (p *closePool) Remove(c io.Closer) {
	p.mu.Lock()
	defer p.mu.Unlock()
	if !p.closed {
		delete(p.m, c)
	}
}

// Close marks the pool closed and closes every registered io.Closer,
// ignoring their errors. Only the first call does anything; Close always
// returns nil.
func (p *closePool) Close() error {
	p.mu.Lock()
	defer p.mu.Unlock()
	if !p.closed {
		p.closed = true
		for c := range p.m {
			_ = c.Close() // best effort
		}
	}
	return nil
}