// Copyright (c) Tailscale Inc & AUTHORS
// SPDX-License-Identifier: BSD-3-Clause

package resolver

import (
	"fmt"
	"net"
	"net/netip"
	"strings"
	"testing"

	"github.com/miekg/dns"
)

// This file exists to isolate the test infrastructure
// that depends on github.com/miekg/dns
// from the rest, which only depends on dnsmessage.

// resolveToIP returns a handler function which responds
// to queries of type A it receives with an A record containing ipv4,
// to queries of type AAAA with an AAAA record containing ipv6,
// to queries of type NS with an NS record containing ns.
//
// Queries of any other type get a reply with an empty answer section.
// The handler panics if the request does not hold exactly one question.
func resolveToIP(ipv4, ipv6 netip.Addr, ns string) dns.HandlerFunc {
	return func(w dns.ResponseWriter, req *dns.Msg) {
		m := new(dns.Msg)
		m.SetReply(req)

		if len(req.Question) != 1 {
			panic("not a single-question request")
		}
		question := req.Question[0]

		var ans dns.RR
		switch question.Qtype {
		case dns.TypeA:
			ans = &dns.A{
				Hdr: dns.RR_Header{
					Name:   question.Name,
					Rrtype: dns.TypeA,
					Class:  dns.ClassINET,
				},
				A: ipv4.AsSlice(),
			}
		case dns.TypeAAAA:
			ans = &dns.AAAA{
				Hdr: dns.RR_Header{
					Name:   question.Name,
					Rrtype: dns.TypeAAAA,
					Class:  dns.ClassINET,
				},
				AAAA: ipv6.AsSlice(),
			}
		case dns.TypeNS:
			ans = &dns.NS{
				Hdr: dns.RR_Header{
					Name:   question.Name,
					Rrtype: dns.TypeNS,
					Class:  dns.ClassINET,
				},
				Ns: ns,
			}
		}

		// Only append a record that was actually built: miekg/dns refuses
		// to pack a nil RR, so WriteMsg would fail, send nothing, and leave
		// the client waiting for a timeout.
		if ans != nil {
			m.Answer = append(m.Answer, ans)
		}
		w.WriteMsg(m)
	}
}

// resolveToIPLowercase returns a handler function which canonicalizes responses
// by lowercasing the question and answer names, and responds
// to queries of type A it receives with an A record containing ipv4,
// to queries of type AAAA with an AAAA record containing ipv6,
// to queries of type NS with an NS record containing ns.
//
// Queries of any other type get a reply with an empty answer section.
// The handler panics if the request does not hold exactly one question.
func resolveToIPLowercase(ipv4, ipv6 netip.Addr, ns string) dns.HandlerFunc {
	return func(w dns.ResponseWriter, req *dns.Msg) {
		m := new(dns.Msg)
		m.SetReply(req)

		if len(req.Question) != 1 {
			panic("not a single-question request")
		}
		// SetReply copied the question into m by value, so lowercasing m's
		// copy leaves req untouched. Build the answers from m's copy so their
		// owner names are lowercased as well, as documented above.
		m.Question[0].Name = strings.ToLower(m.Question[0].Name)
		question := m.Question[0]

		var ans dns.RR
		switch question.Qtype {
		case dns.TypeA:
			ans = &dns.A{
				Hdr: dns.RR_Header{
					Name:   question.Name,
					Rrtype: dns.TypeA,
					Class:  dns.ClassINET,
				},
				A: ipv4.AsSlice(),
			}
		case dns.TypeAAAA:
			ans = &dns.AAAA{
				Hdr: dns.RR_Header{
					Name:   question.Name,
					Rrtype: dns.TypeAAAA,
					Class:  dns.ClassINET,
				},
				AAAA: ipv6.AsSlice(),
			}
		case dns.TypeNS:
			ans = &dns.NS{
				Hdr: dns.RR_Header{
					Name:   question.Name,
					Rrtype: dns.TypeNS,
					Class:  dns.ClassINET,
				},
				Ns: ns,
			}
		}

		// Only append a record that was actually built: miekg/dns refuses
		// to pack a nil RR, so WriteMsg would fail and send nothing.
		if ans != nil {
			m.Answer = append(m.Answer, ans)
		}
		w.WriteMsg(m)
	}
}

// resolveToTXT returns a handler function that answers TXT queries with a
// single TXT record holding txts.
//
// Every TXT reply also carries a "query-info.test." TXT record in its
// additional section that describes the request's EDNS(0) state. It holds
// "EDNS=false" when the request had no OPT record. Otherwise it holds
// "EDNS=true" and "maxSize=<UDP size the request advertised>". When
// ednsMaxSize is non-zero, the reply also gets its own OPT record advertising
// that size.
//
// Queries of any other type get an empty reply. The handler panics if the
// request does not hold exactly one question, or if writing a TXT reply fails.
func resolveToTXT(txts []string, ednsMaxSize uint16) dns.HandlerFunc {
	return func(w dns.ResponseWriter, req *dns.Msg) {
		if len(req.Question) != 1 {
			panic("not a single-question request")
		}
		reply := new(dns.Msg).SetReply(req)
		q := req.Question[0]

		if q.Qtype != dns.TypeTXT {
			w.WriteMsg(reply)
			return
		}

		reply.Answer = append(reply.Answer, &dns.TXT{
			Hdr: dns.RR_Header{Name: q.Name, Rrtype: dns.TypeTXT, Class: dns.ClassINET},
			Txt: txts,
		})

		// Echo back what the server observed about the request's EDNS(0)
		// OPT record, so tests can inspect it.
		info := []string{"EDNS=false"}
		if opt := req.IsEdns0(); opt != nil {
			info = []string{"EDNS=true", fmt.Sprintf("maxSize=%v", opt.UDPSize())}
		}
		reply.Extra = append(reply.Extra, &dns.TXT{
			Hdr: dns.RR_Header{Name: "query-info.test.", Rrtype: dns.TypeTXT, Class: dns.ClassINET},
			Txt: info,
		})

		if ednsMaxSize != 0 {
			reply.SetEdns0(ednsMaxSize, false)
		}

		if err := w.WriteMsg(reply); err != nil {
			panic(err)
		}
	}
}

// resolveToNXDOMAIN is a handler that answers every query with a reply
// whose rcode is NXDOMAIN (dns.RcodeNameError).
var resolveToNXDOMAIN dns.HandlerFunc = func(w dns.ResponseWriter, req *dns.Msg) {
	nx := new(dns.Msg).SetRcode(req, dns.RcodeNameError)
	w.WriteMsg(nx)
}

// weirdoGoCNAMEHandler returns a DNS handler that satisfies
// Go's weird Resolver.LookupCNAME (read its godoc carefully!).
//
// A queries are answered with a single CNAME record whose owner name and
// target are both target. AAAA queries are answered with an AAAA record
// for target pointing at 1::2. All other query types get an empty reply.
// Every answer is owned by target rather than by the queried name.
// NOTE(review): presumably this is so Go reports target as the canonical
// name derived from the address records; confirm against LookupCNAME's godoc.
//
// The handler panics if the request does not hold exactly one question.
func weirdoGoCNAMEHandler(target string) dns.HandlerFunc {
	return func(w dns.ResponseWriter, req *dns.Msg) {
		m := new(dns.Msg)
		m.SetReply(req)
		// Match the other handlers in this file: fail loudly with a clear
		// message instead of an index-out-of-range panic.
		if len(req.Question) != 1 {
			panic("not a single-question request")
		}
		question := req.Question[0]

		switch question.Qtype {
		case dns.TypeA:
			m.Answer = append(m.Answer, &dns.CNAME{
				Hdr: dns.RR_Header{
					Name:   target,
					Rrtype: dns.TypeCNAME,
					Class:  dns.ClassINET,
					Ttl:    600,
				},
				Target: target,
			})
		case dns.TypeAAAA:
			m.Answer = append(m.Answer, &dns.AAAA{
				Hdr: dns.RR_Header{
					Name:   target,
					Rrtype: dns.TypeAAAA,
					Class:  dns.ClassINET,
					Ttl:    600,
				},
				AAAA: net.ParseIP("1::2"),
			})
		}
		w.WriteMsg(m)
	}
}

// dnsHandler returns a handler that replies to every query with the
// answers provided, in the order given. Each answer's header is replaced
// with one owned by the queried name, class IN. Answers are not filtered
// by the query's type.
//
// Supported answer types:
//   - netip.Addr: an A record for an IPv4 address, or an AAAA record for an
//     IPv6 address (an Addr that is neither is silently skipped).
//   - dns.PTR, dns.TXT, dns.SRV, dns.NS: the record as given, TTL 0.
//   - dns.CNAME: the record as given, TTL 600.
//
// Any other argument type panics when a query arrives, as does a request
// that does not hold exactly one question.
func dnsHandler(answers ...any) dns.HandlerFunc {
	return func(w dns.ResponseWriter, req *dns.Msg) {
		m := new(dns.Msg)
		m.SetReply(req)
		if len(req.Question) != 1 {
			panic("not a single-question request")
		}
		m.RecursionAvailable = true // to stop net package's errLameReferral on empty replies

		question := req.Question[0]
		// hdr builds the header every answer gets: owned by the queried
		// name, class IN, TTL 0.
		hdr := func(rrtype uint16) dns.RR_Header {
			return dns.RR_Header{
				Name:   question.Name,
				Rrtype: rrtype,
				Class:  dns.ClassINET,
			}
		}

		for _, a := range answers {
			// In each case below, a is a fresh copy of the argument for
			// this switch execution, so mutating it and taking its address
			// neither touches answers nor aliases earlier replies.
			switch a := a.(type) {
			case netip.Addr:
				switch {
				case a.Is4():
					m.Answer = append(m.Answer, &dns.A{Hdr: hdr(dns.TypeA), A: a.AsSlice()})
				case a.Is6():
					m.Answer = append(m.Answer, &dns.AAAA{Hdr: hdr(dns.TypeAAAA), AAAA: a.AsSlice()})
				}
			case dns.PTR:
				a.Hdr = hdr(dns.TypePTR)
				m.Answer = append(m.Answer, &a)
			case dns.CNAME:
				a.Hdr = hdr(dns.TypeCNAME)
				a.Hdr.Ttl = 600
				m.Answer = append(m.Answer, &a)
			case dns.TXT:
				a.Hdr = hdr(dns.TypeTXT)
				m.Answer = append(m.Answer, &a)
			case dns.SRV:
				a.Hdr = hdr(dns.TypeSRV)
				m.Answer = append(m.Answer, &a)
			case dns.NS:
				a.Hdr = hdr(dns.TypeNS)
				m.Answer = append(m.Answer, &a)
			default:
				panic(fmt.Sprintf("unsupported dnsHandler arg %T", a))
			}
		}
		w.WriteMsg(m)
	}
}

// serveDNS starts a UDP DNS server on addr and returns once it is serving.
//
// records is a flat list of (name, handler) pairs. Each name (a string) is
// registered on a dns.ServeMux with the dns.Handler that follows it. The
// caller is responsible for shutting the returned server down.
//
// Misuse (an odd number of record values, or a pair of the wrong types) and
// a failure to start listening fail the test via tb.Fatalf. They do not
// crash the test binary.
func serveDNS(tb testing.TB, addr string, records ...any) *dns.Server {
	tb.Helper()
	if len(records)%2 != 0 {
		tb.Fatalf("serveDNS: must have an even number of record values, got %d", len(records))
	}
	mux := dns.NewServeMux()
	for i := 0; i < len(records); i += 2 {
		name, ok := records[i].(string)
		if !ok {
			tb.Fatalf("serveDNS: records[%d] is %T, want string", i, records[i])
		}
		handler, ok := records[i+1].(dns.Handler)
		if !ok {
			tb.Fatalf("serveDNS: records[%d] is %T, want dns.Handler", i+1, records[i+1])
		}
		mux.Handle(name, handler)
	}

	waitch := make(chan struct{})
	server := &dns.Server{
		Addr:              addr,
		Net:               "udp",
		Handler:           mux,
		NotifyStartedFunc: func() { close(waitch) },
		ReusePort:         true,
	}

	// startErr carries ListenAndServe's error when it fails before the
	// server ever started. That lets the failure be reported on the test
	// goroutine, where tb.Fatalf is allowed, instead of hanging or crashing.
	startErr := make(chan error, 1)
	go func() {
		err := server.ListenAndServe()
		if err == nil {
			return
		}
		select {
		case <-waitch:
			// The server was already up, so no test goroutine is waiting
			// on startErr. Keep the original loud failure.
			panic(fmt.Sprintf("ListenAndServe(%q): %v", addr, err))
		default:
			startErr <- err
		}
	}()

	select {
	case <-waitch:
	case err := <-startErr:
		tb.Fatalf("ListenAndServe(%q): %v", addr, err)
	}
	return server
}