net/dns: retry forwarder requests over TCP

We weren't correctly retrying queries to an upstream DNS server over TCP
when the UDP response was truncated. Instead, we'd return the truncated
response to the user, even if the user was querying us over TCP and thus
able to handle a large response.

Also, add an envknob and controlknob to allow users/us to disable this
behaviour if it turns out to be buggy (DNS).

Updates #9264

Signed-off-by: Andrew Dunham <andrew@du.nham.ca>
Change-Id: Ifb04b563839a9614c0ba03e9c564e8924c1a2bfd
This commit is contained in:
Andrew Dunham
2023-09-07 16:27:50 -04:00
parent 098d110746
commit 530aaa52f1
13 changed files with 448 additions and 49 deletions

View File

@@ -21,6 +21,7 @@ import (
"time"
dns "golang.org/x/net/dns/dnsmessage"
"tailscale.com/control/controlknobs"
"tailscale.com/envknob"
"tailscale.com/net/dns/publicdns"
"tailscale.com/net/dnscache"
@@ -68,6 +69,10 @@ const (
// DNS queries to the "fallback" DNS server IP for a known provider
// (e.g. how long to wait to query Google's 8.8.4.4 after 8.8.8.8).
wellKnownHostBackupDelay = 200 * time.Millisecond
// tcpQueryTimeout is the timeout for a DNS query performed over TCP.
// It matches the default 5sec timeout of the 'dig' utility.
tcpQueryTimeout = 5 * time.Second
)
// txid identifies a DNS transaction.
@@ -180,6 +185,8 @@ type forwarder struct {
linkSel ForwardLinkSelector // TODO(bradfitz): remove this when tsdial.Dialer absorbs it
dialer *tsdial.Dialer
controlKnobs *controlknobs.Knobs // or nil
ctx context.Context // good until Close
ctxCancel context.CancelFunc // closes ctx
@@ -206,12 +213,13 @@ func init() {
rand.Seed(time.Now().UnixNano())
}
func newForwarder(logf logger.Logf, netMon *netmon.Monitor, linkSel ForwardLinkSelector, dialer *tsdial.Dialer) *forwarder {
func newForwarder(logf logger.Logf, netMon *netmon.Monitor, linkSel ForwardLinkSelector, dialer *tsdial.Dialer, knobs *controlknobs.Knobs) *forwarder {
f := &forwarder{
logf: logger.WithPrefix(logf, "forward: "),
netMon: netMon,
linkSel: linkSel,
dialer: dialer,
logf: logger.WithPrefix(logf, "forward: "),
netMon: netMon,
linkSel: linkSel,
dialer: dialer,
controlKnobs: knobs,
}
f.ctx, f.ctxCancel = context.WithCancel(context.Background())
return f
@@ -443,7 +451,10 @@ func (f *forwarder) sendDoH(ctx context.Context, urlBase string, c *http.Client,
return res, err
}
var verboseDNSForward = envknob.RegisterBool("TS_DEBUG_DNS_FORWARD_SEND")
var (
// verboseDNSForward is the TS_DEBUG_DNS_FORWARD_SEND envknob. When it
// reports true, send logs whether its TCP retry succeeded or failed.
verboseDNSForward = envknob.RegisterBool("TS_DEBUG_DNS_FORWARD_SEND")
// skipTCPRetry is the TS_DNS_FORWARD_SKIP_TCP_RETRY envknob. When it
// reports true, send returns a truncated UDP response as-is instead of
// retrying the query over TCP.
skipTCPRetry = envknob.RegisterBool("TS_DNS_FORWARD_SKIP_TCP_RETRY")
)
// send sends packet to dst. It is best effort.
//
@@ -477,10 +488,49 @@ func (f *forwarder) send(ctx context.Context, fq *forwardQuery, rr resolverAndDe
return nil, fmt.Errorf("tls:// resolvers not supported yet")
}
return f.sendUDP(ctx, fq, rr)
ret, err = f.sendUDP(ctx, fq, rr)
if err != nil {
return nil, err
}
if !truncatedFlagSet(ret) {
// Successful, non-truncated response; return it.
return ret, nil
}
if fq.family == "udp" {
// If this is a UDP query, return it regardless of whether the
// response is truncated or not; the client can retry
// communicating with tailscaled over TCP. There's no point
// falling back to TCP for a truncated query if we can't return
// the results to the client.
return ret, nil
}
if skipTCPRetry() || (f.controlKnobs != nil && f.controlKnobs.DisableDNSForwarderTCPRetries.Load()) {
// Envknob or control knob disabled the TCP retry behaviour;
// just return what we have.
return ret, nil
}
// Don't retry if our context is done.
if err := ctx.Err(); err != nil {
return nil, err
}
// Retry over TCP, best-effort; return the truncated UDP response if we
// cannot query via TCP.
if ret2, err2 := f.sendTCP(ctx, fq, rr); err2 == nil {
if verboseDNSForward() {
f.logf("forwarder.send(%q): successfully retried via TCP", rr.name.Addr)
}
return ret2, nil
} else if verboseDNSForward() {
f.logf("forwarder.send(%q): could not retry via TCP: %v", rr.name.Addr, err2)
}
return ret, nil
}
// errServerFailure is returned by sendTCP when the upstream response carries
// the dns.RCodeServerFailure response code.
var errServerFailure = errors.New("response code indicates server issue")

// errTxIDMismatch is returned by sendUDP and sendTCP when the transaction ID
// of the upstream response differs from the transaction ID of the query.
var errTxIDMismatch = errors.New("txid doesn't match")
func (f *forwarder) sendUDP(ctx context.Context, fq *forwardQuery, rr resolverAndDelay) (ret []byte, err error) {
ipp, ok := rr.name.IPPort()
@@ -545,7 +595,7 @@ func (f *forwarder) sendUDP(ctx context.Context, fq *forwardQuery, rr resolverAn
txid := getTxID(out)
if txid != fq.txid {
metricDNSFwdUDPErrorTxID.Add(1)
return nil, errors.New("txid doesn't match")
return nil, errTxIDMismatch
}
rcode := getRCode(out)
// don't forward transient errors back to the client when the server fails
@@ -577,6 +627,92 @@ func (f *forwarder) sendUDP(ctx context.Context, fq *forwardQuery, rr resolverAn
return out, nil
}
// sendTCP sends the query in fq to the resolver rr over TCP and returns the
// raw DNS response.
//
// The query is written with a two-byte big-endian length prefix, and the
// response is read back using the same framing. An error is returned if rr
// has no IP:port, if the query is too large to be described by the 16-bit
// length prefix, if dialing, writing or reading fails, if the response's
// transaction ID differs from fq.txid, or if the response code is
// dns.RCodeServerFailure. If ctx is done when a write or read fails, ctx.Err()
// is returned in place of the I/O error.
func (f *forwarder) sendTCP(ctx context.Context, fq *forwardQuery, rr resolverAndDelay) (ret []byte, err error) {
	// maxQueryLen is the largest message the two-byte length prefix can
	// describe.
	const maxQueryLen = 1<<16 - 1

	ipp, ok := rr.name.IPPort()
	if !ok {
		metricDNSFwdErrorType.Add(1)
		return nil, fmt.Errorf("unrecognized resolver type %q", rr.name.Addr)
	}
	metricDNSFwdTCP.Add(1)

	// Refuse queries whose length would silently wrap when converted to
	// uint16; sending them would put a wrong length prefix on the wire.
	if len(fq.packet) > maxQueryLen {
		return nil, fmt.Errorf("query too large for TCP: %d bytes", len(fq.packet))
	}

	ctx = sockstats.WithSockStats(ctx, sockstats.LabelDNSForwarderTCP, f.logf)

	// Specify the exact family to work around https://github.com/golang/go/issues/52264
	tcpFam := "tcp4"
	if ipp.Addr().Is6() {
		tcpFam = "tcp6"
	}

	// NOTE(review): this timeout is passed to the dial; whether it also
	// interrupts the Write/Read calls below depends on how closeOnCtxDone
	// is driven — confirm.
	ctx, cancel := context.WithTimeout(ctx, tcpQueryTimeout)
	defer cancel()

	conn, err := f.dialer.SystemDial(ctx, tcpFam, ipp.String())
	if err != nil {
		return nil, err
	}
	defer conn.Close()

	// Register the connection with the query's close pool for the duration
	// of this call.
	fq.closeOnCtxDone.Add(conn)
	defer fq.closeOnCtxDone.Remove(conn)

	// ctxOrErr prefers the context's error over the I/O error err2, so
	// that a failure observed after ctx is done is reported as ctx.Err().
	ctxOrErr := func(err2 error) ([]byte, error) {
		if err := ctx.Err(); err != nil {
			return nil, err
		}
		return nil, err2
	}

	// Write the query to the server, prefixed with its length.
	query := make([]byte, len(fq.packet)+2)
	binary.BigEndian.PutUint16(query, uint16(len(fq.packet)))
	copy(query[2:], fq.packet)
	if _, err := conn.Write(query); err != nil {
		metricDNSFwdTCPErrorWrite.Add(1)
		return ctxOrErr(err)
	}

	metricDNSFwdTCPWrote.Add(1)

	// Read the header length back from the server
	var length uint16
	if err := binary.Read(conn, binary.BigEndian, &length); err != nil {
		metricDNSFwdTCPErrorRead.Add(1)
		return ctxOrErr(err)
	}

	// Now read the response. io.ReadFull returns an error unless exactly
	// len(out) bytes were read, so no separate short-read check is needed.
	out := make([]byte, length)
	if _, err := io.ReadFull(conn, out); err != nil {
		metricDNSFwdTCPErrorRead.Add(1)
		return ctxOrErr(err)
	}

	txid := getTxID(out)
	if txid != fq.txid {
		metricDNSFwdTCPErrorTxID.Add(1)
		return nil, errTxIDMismatch
	}

	rcode := getRCode(out)

	// don't forward transient errors back to the client when the server fails
	if rcode == dns.RCodeServerFailure {
		f.logf("sendTCP: response code indicating server failure: %d", rcode)
		metricDNSFwdTCPErrorServer.Add(1)
		return nil, errServerFailure
	}

	// TODO(andrew): do we need to do this?
	//clampEDNSSize(out, maxResponseBytes)

	metricDNSFwdTCPSuccess.Add(1)
	return out, nil
}
// resolvers returns the resolvers to use for domain.
func (f *forwarder) resolvers(domain dnsname.FQDN) []resolverAndDelay {
f.mu.Lock()
@@ -601,6 +737,7 @@ func (f *forwarder) resolvers(domain dnsname.FQDN) []resolverAndDelay {
type forwardQuery struct {
txid txid
packet []byte
family string // "tcp" or "udp"
// closeOnCtxDone lets send register values to Close if the
// caller's ctx expires. This avoids send from allocating its
@@ -686,6 +823,7 @@ func (f *forwarder) forwardWithDestChan(ctx context.Context, query packet, respo
fq := &forwardQuery{
txid: getTxID(query.bs),
packet: query.bs,
family: query.family,
closeOnCtxDone: new(closePool),
}
defer fq.closeOnCtxDone.Close()
@@ -727,7 +865,7 @@ func (f *forwarder) forwardWithDestChan(ctx context.Context, query packet, respo
case <-ctx.Done():
metricDNSFwdErrorContext.Add(1)
return ctx.Err()
case responseChan <- packet{v, query.addr}:
case responseChan <- packet{v, query.family, query.addr}:
metricDNSFwdSuccess.Add(1)
return nil
}