net/dns/resolver, ipn/ipnlocal: wire up peerapi DoH server to DNS forwarder

Updates #1713

Change-Id: Ia4ed9d8c9cef0e70aa6d30f2852eaab80f5f695a
Signed-off-by: Brad Fitzpatrick <bradfitz@tailscale.com>
This commit is contained in:
Brad Fitzpatrick
2021-11-23 09:58:34 -08:00
committed by Brad Fitzpatrick
parent 9bb91cb977
commit 25525b7754
3 changed files with 196 additions and 7 deletions

View File

@@ -29,6 +29,7 @@ import (
"unicode"
"unicode/utf8"
"golang.org/x/net/dns/dnsmessage"
"inet.af/netaddr"
"tailscale.com/client/tailscale/apitype"
"tailscale.com/hostinfo"
@@ -767,6 +768,8 @@ func (h *peerAPIHandler) replyToDNSQueries() bool {
return h.isSelf || h.ps.b.OfferingExitNode()
}
// handleDNSQuery implements a DoH server (RFC 8484) over the peerapi.
// It's not over HTTPS as the spec dictates, but rather HTTP-over-WireGuard.
func (h *peerAPIHandler) handleDNSQuery(w http.ResponseWriter, r *http.Request) {
if h.ps.resolver == nil {
http.Error(w, "DNS not wired up", http.StatusNotImplemented)
@@ -776,13 +779,45 @@ func (h *peerAPIHandler) handleDNSQuery(w http.ResponseWriter, r *http.Request)
http.Error(w, "DNS access denied", http.StatusForbidden)
return
}
pretty := false // non-DoH debug mode for humans
q, publicError := dohQuery(r)
if publicError != "" && r.Method == "GET" {
if name := r.FormValue("q"); name != "" {
pretty = true
publicError = ""
q = dnsQueryForName(name, r.FormValue("t"))
}
}
if publicError != "" {
http.Error(w, publicError, http.StatusBadRequest)
return
}
// TODO(bradfitz): owl.
fmt.Fprintf(w, "## TODO: got %d bytes of DNS query", len(q))
// Some timeout that's short enough to be noticed by humans
// but long enough that it's longer than real DNS timeouts.
const arbitraryTimeout = 5 * time.Second
ctx, cancel := context.WithTimeout(r.Context(), arbitraryTimeout)
defer cancel()
res, err := h.ps.resolver.HandleExitNodeDNSQuery(ctx, q, h.remoteAddr)
if err != nil {
h.logf("handleDNS fwd error: %v", err)
if err := ctx.Err(); err != nil {
http.Error(w, err.Error(), 500)
} else {
http.Error(w, "DNS forwarding error", 500)
}
return
}
if pretty {
// Non-standard response for interactive debugging.
w.Header().Set("Content-Type", "application/json")
writePrettyDNSReply(w, res)
return
}
w.Header().Set("Content-Type", "application/dns-message")
w.Header().Set("Content-Length", strconv.Itoa(len(q)))
w.Write(res)
}
func dohQuery(r *http.Request) (dnsQuery []byte, publicErr string) {
@@ -817,3 +852,86 @@ func dohQuery(r *http.Request) (dnsQuery []byte, publicErr string) {
return q, ""
}
}
func dnsQueryForName(name, typStr string) []byte {
typ := dnsmessage.TypeA
switch strings.ToLower(typStr) {
case "aaaa":
typ = dnsmessage.TypeAAAA
case "txt":
typ = dnsmessage.TypeTXT
}
b := dnsmessage.NewBuilder(nil, dnsmessage.Header{
OpCode: 0, // query
RecursionDesired: true,
ID: 0,
})
if !strings.HasSuffix(name, ".") {
name += "."
}
b.StartQuestions()
b.Question(dnsmessage.Question{
Name: dnsmessage.MustNewName(name),
Type: typ,
Class: dnsmessage.ClassINET,
})
msg, _ := b.Finish()
return msg
}
func writePrettyDNSReply(w io.Writer, res []byte) (err error) {
defer func() {
if err != nil {
j, _ := json.Marshal(struct {
Error string
}{err.Error()})
w.Write(j)
return
}
}()
var p dnsmessage.Parser
if _, err := p.Start(res); err != nil {
return err
}
if err := p.SkipAllQuestions(); err != nil {
return err
}
var gotIPs []string
for {
h, err := p.AnswerHeader()
if err == dnsmessage.ErrSectionDone {
break
}
if err != nil {
return err
}
if h.Class != dnsmessage.ClassINET {
continue
}
switch h.Type {
case dnsmessage.TypeA:
r, err := p.AResource()
if err != nil {
return err
}
gotIPs = append(gotIPs, net.IP(r.A[:]).String())
case dnsmessage.TypeAAAA:
r, err := p.AAAAResource()
if err != nil {
return err
}
gotIPs = append(gotIPs, net.IP(r.AAAA[:]).String())
case dnsmessage.TypeTXT:
r, err := p.TXTResource()
if err != nil {
return err
}
gotIPs = append(gotIPs, r.TXT...)
}
}
j, _ := json.Marshal(gotIPs)
j = append(j, '\n')
w.Write(j)
return nil
}