tailscale/net/tsdial/dohclient.go
Brad Fitzpatrick 39ffa16853 net/dnscache, net/tsdial: add DNS caching to tsdial UserDial
This is enough to handle the DNS queries as generated by Go's
net package (which our HTTP/SOCKS client uses), and the responses
generated by the ExitDNS DoH server.

This isn't yet suitable for putting on 100.100.100.100 where a number
of different DNS clients would hit it, as this doesn't yet do
EDNS0. It might work, but it's untested and likely incomplete.

Likewise, this doesn't handle anything about truncation, as the
exchanges are entirely in memory between Go or DoH. That would also
need to be handled later, if/when it's hooked up to 100.100.100.100.

Updates #3507

Change-Id: I1736b0ad31eea85ea853b310c52c5e6bf65c6e2a
Signed-off-by: Brad Fitzpatrick <bradfitz@tailscale.com>
2021-12-09 11:34:21 -08:00

102 lines
2.5 KiB
Go

// Copyright (c) 2021 Tailscale Inc & AUTHORS All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package tsdial
import (
"bytes"
"context"
"errors"
"fmt"
"io"
"net"
"net/http"
"time"
"tailscale.com/net/dnscache"
)
// dohConn is a net.PacketConn suitable for returning from
// net.Dialer.Dial to send DNS queries over PeerAPI to exit nodes'
// ExitDNS DoH proxy service.
//
// Each Write performs one DNS query (answered from dnsCache or by an
// HTTP POST to baseURL) and stores the reply in rbuf; Read and
// ReadFrom then drain rbuf.
type dohConn struct {
// ctx is the context attached to every outgoing HTTP request.
ctx context.Context
// baseURL is the URL that DNS queries are POSTed to.
baseURL string
hc *http.Client // if nil, http.DefaultClient is used
// dnsCache, if non-nil, is consulted before making an HTTP request
// and is given each reply that was fetched over HTTP.
dnsCache *dnscache.MessageCache
// rbuf holds reply bytes that have not yet been consumed by Read.
rbuf bytes.Buffer
}
// Compile-time assertion that *dohConn implements net.Conn.
var _ net.Conn = (*dohConn)(nil)

// Compile-time assertion that *dohConn implements net.PacketConn.
// It is a PacketConn in order to change net.Resolver semantics.
var _ net.PacketConn = (*dohConn)(nil)
// Close is a no-op. It always returns nil.
func (c *dohConn) Close() error {
	return nil
}
// LocalAddr returns a placeholder address; a dohConn has no meaningful
// local address to report.
func (c *dohConn) LocalAddr() net.Addr {
	return todoAddr{}
}
// RemoteAddr returns a placeholder address, the same value that
// LocalAddr reports.
func (c *dohConn) RemoteAddr() net.Addr {
	return todoAddr{}
}
// SetDeadline is a no-op: t is ignored and nil is always returned.
func (c *dohConn) SetDeadline(t time.Time) error {
	return nil
}
// SetReadDeadline is a no-op: t is ignored and nil is always returned.
func (c *dohConn) SetReadDeadline(t time.Time) error {
	return nil
}
// SetWriteDeadline is a no-op: t is ignored and nil is always returned.
func (c *dohConn) SetWriteDeadline(t time.Time) error {
	return nil
}
// WriteTo implements net.PacketConn by handing p to Write.
// The addr argument is not used.
func (c *dohConn) WriteTo(p []byte, addr net.Addr) (n int, err error) {
	n, err = c.Write(p)
	return n, err
}
// ReadFrom implements net.PacketConn by reading into p via Read.
// The returned address is always a placeholder todoAddr, including
// when err is non-nil.
func (c *dohConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
	addr = todoAddr{}
	n, err = c.Read(p)
	return n, addr, err
}
// Read copies buffered reply bytes into p, returning whatever the
// underlying bytes.Buffer read reports.
func (c *dohConn) Read(p []byte) (n int, err error) {
	n, err = c.rbuf.Read(p)
	return n, err
}
// Write sends packet, a DNS query, to the DoH server at c.baseURL and
// buffers the reply for later retrieval via Read or ReadFrom.
//
// If c.dnsCache is non-nil and can answer the query, the reply is served
// from the cache and no HTTP request is made. Otherwise the query is
// POSTed using c.hc (http.DefaultClient if nil) under c.ctx, and a
// successful reply is added to the cache.
//
// It returns len(packet) on success. On failure it returns 0 and an
// error, and leaves no reply bytes buffered.
func (c *dohConn) Write(packet []byte) (n int, err error) {
	const (
		dohType = "application/dns-message"

		// maxReplySize is the largest DNS message a reply can
		// legitimately be; DNS message lengths are limited to 16 bits.
		maxReplySize = 65535
	)

	// Every Write starts a new exchange. Drop any unread bytes from an
	// earlier reply so they can't be returned in front of this one.
	c.rbuf.Reset()

	if c.dnsCache != nil {
		if err := c.dnsCache.ReplyFromCache(&c.rbuf, packet); err == nil {
			// Cache hit.
			// TODO(bradfitz): add clientmetric
			return len(packet), nil
		}
		// Discard anything ReplyFromCache may have written before failing.
		c.rbuf.Reset()
	}

	req, err := http.NewRequestWithContext(c.ctx, "POST", c.baseURL, bytes.NewReader(packet))
	if err != nil {
		return 0, err
	}
	req.Header.Set("Content-Type", dohType)

	hc := c.hc
	if hc == nil {
		hc = http.DefaultClient
	}
	hres, err := hc.Do(req)
	if err != nil {
		return 0, err
	}
	defer hres.Body.Close()

	if hres.StatusCode != http.StatusOK {
		return 0, errors.New(hres.Status)
	}
	if ct := hres.Header.Get("Content-Type"); ct != dohType {
		return 0, fmt.Errorf("unexpected response Content-Type %q", ct)
	}

	// The reply comes from a remote peer, so bound how much of it is
	// buffered. Reading one byte past the limit lets an oversized body
	// be detected rather than silently truncated.
	nr, err := io.Copy(&c.rbuf, io.LimitReader(hres.Body, maxReplySize+1))
	if err != nil {
		c.rbuf.Reset() // don't leave a truncated reply behind for Read
		return 0, err
	}
	if nr > maxReplySize {
		c.rbuf.Reset()
		return 0, fmt.Errorf("response body exceeds %d bytes", maxReplySize)
	}

	if c.dnsCache != nil {
		c.dnsCache.AddCacheEntry(packet, c.rbuf.Bytes())
	}
	return len(packet), nil
}
// todoAddr is a placeholder address type with fixed Network and String
// values. It carries no state.
type todoAddr struct{}

// Network reports the placeholder network name.
func (todoAddr) Network() string {
	return "unused"
}

// String reports the placeholder address text.
func (todoAddr) String() string {
	return "unused-todoAddr"
}