net/dns: retry forwarder requests over TCP

We weren't correctly retrying truncated requests to an upstream DNS
server with TCP. Instead, we'd return a truncated response to the user,
even if the user was querying us over TCP and thus able to handle a
large response.

Also, add an envknob and controlknob to allow users/us to disable this
behaviour if it turns out to be buggy (it's DNS, after all).

Updates #9264

Signed-off-by: Andrew Dunham <andrew@du.nham.ca>
Change-Id: Ifb04b563839a9614c0ba03e9c564e8924c1a2bfd
This commit is contained in:
Andrew Dunham
2023-09-07 16:27:50 -04:00
parent 098d110746
commit 530aaa52f1
13 changed files with 448 additions and 49 deletions

View File

@@ -4,14 +4,26 @@
package resolver
import (
"bytes"
"context"
"encoding/binary"
"flag"
"fmt"
"io"
"net"
"net/netip"
"os"
"reflect"
"strings"
"sync"
"sync/atomic"
"testing"
"time"
dns "golang.org/x/net/dns/dnsmessage"
"tailscale.com/envknob"
"tailscale.com/net/netmon"
"tailscale.com/net/tsdial"
"tailscale.com/types/dnstype"
)
@@ -240,3 +252,224 @@ func FuzzClampEDNSSize(f *testing.F) {
clampEDNSSize(data, maxResponseBytes)
})
}
// runDNSServer starts a test DNS server on 127.0.0.1 that listens on the same
// port number for both TCP and UDP, and returns that port.
//
// Every query, regardless of its contents, is answered with response: as-is
// over UDP, and prefixed with a 2-byte big-endian length over TCP. Each TCP
// connection serves a single query and is then closed. onRequest is called
// once per query, before the answer is written, with the first argument
// reporting whether the query arrived over TCP and the second holding the raw
// query bytes (without the TCP length prefix). It is called from per-request
// goroutines, so it must be safe for concurrent use.
//
// The server is shut down through tb.Cleanup, which closes the listeners and
// any still-open TCP connections and then waits for every goroutine started
// here to exit, so that none of them can call tb.Logf or onRequest after the
// test has completed. If no common TCP/UDP port can be found, the test is
// skipped.
func runDNSServer(tb testing.TB, response []byte, onRequest func(bool, []byte)) (port uint16) {
	// The TCP length prefix is 16 bits wide; a longer response would have its
	// length silently truncated by the conversion below.
	if len(response) > 0xffff {
		tb.Fatalf("response is %d bytes; want at most %d", len(response), 0xffff)
	}
	tcpResponse := make([]byte, len(response)+2)
	binary.BigEndian.PutUint16(tcpResponse, uint16(len(response)))
	copy(tcpResponse[2:], response)

	// Repeatedly listen until we can get the same port for TCP and UDP.
	const tries = 25
	var (
		tcpLn *net.TCPListener
		udpLn *net.UDPConn
	)
	for try := 0; try < tries; try++ {
		ln, err := net.ListenTCP("tcp4", &net.TCPAddr{
			IP:   net.IPv4(127, 0, 0, 1),
			Port: 0, // Choose one
		})
		if err != nil {
			tb.Fatal(err)
		}
		pc, err := net.ListenUDP("udp4", &net.UDPAddr{
			IP:   net.IPv4(127, 0, 0, 1),
			Port: ln.Addr().(*net.TCPAddr).Port,
		})
		if err != nil {
			// The UDP side of this port is taken; release the TCP side and
			// try again with a fresh port.
			ln.Close()
			continue
		}
		tcpLn, udpLn = ln, pc
		break
	}
	if tcpLn == nil {
		// Skip instead of being fatal to avoid flaking on extremely
		// heavily-loaded CI systems.
		tb.Skipf("failed to listen on same port for TCP/UDP after %d tries", tries)
	}
	port = uint16(tcpLn.Addr().(*net.TCPAddr).Port)

	var (
		wg sync.WaitGroup // counts every goroutine started below

		mu       sync.Mutex
		conns    = make(map[net.Conn]struct{}) // accepted TCP conns still being served
		shutdown bool                          // set by the cleanup; no new conns are tracked afterwards
	)

	handleConn := func(conn net.Conn) {
		defer func() {
			conn.Close()
			mu.Lock()
			delete(conns, conn)
			mu.Unlock()
		}()

		// Read the length header, then the query itself.
		var length uint16
		if err := binary.Read(conn, binary.BigEndian, &length); err != nil {
			tb.Logf("error reading length header: %v", err)
			return
		}
		req := make([]byte, length)
		if _, err := io.ReadFull(conn, req); err != nil {
			tb.Logf("error reading query: %v", err)
			return
		}

		onRequest(true, req)

		// Write response
		if _, err := conn.Write(tcpResponse); err != nil {
			tb.Logf("error writing response: %v", err)
			return
		}
	}

	wg.Add(1)
	go func() {
		defer wg.Done()
		for {
			conn, err := tcpLn.Accept()
			if err != nil {
				return
			}

			mu.Lock()
			if shutdown {
				mu.Unlock()
				conn.Close()
				return
			}
			conns[conn] = struct{}{}
			mu.Unlock()

			// Adding to wg here is safe: this accept goroutine still holds
			// its own count, so the counter cannot be zero.
			wg.Add(1)
			go func() {
				defer wg.Done()
				handleConn(conn)
			}()
		}
	}()

	handleUDP := func(addr netip.AddrPort, req []byte) {
		onRequest(false, req)
		if _, err := udpLn.WriteToUDPAddrPort(response, addr); err != nil {
			tb.Logf("error writing response: %v", err)
		}
	}

	wg.Add(1)
	go func() {
		defer wg.Done()
		for {
			// A fresh buffer per datagram, since it is handed off to another
			// goroutine.
			buf := make([]byte, 65535)
			n, addr, err := udpLn.ReadFromUDPAddrPort(buf)
			if err != nil {
				return
			}
			req := buf[:n]

			wg.Add(1)
			go func() {
				defer wg.Done()
				handleUDP(addr, req)
			}()
		}
	}()

	tb.Cleanup(func() {
		tcpLn.Close()
		udpLn.Close()

		// Unblock handlers that are still waiting on a client so that the
		// wait below cannot hang on an idle connection.
		mu.Lock()
		shutdown = true
		for conn := range conns {
			conn.Close()
		}
		mu.Unlock()

		tb.Logf("waiting for listeners to finish...")
		wg.Wait()
	})
	return port
}
// TestForwarderTCPFallback verifies that when an upstream DNS server's answer
// is larger than maxResponseBytes, forwarder.send returns the complete answer
// and the upstream server sees the query over both UDP and TCP.
func TestForwarderTCPFallback(t *testing.T) {
	// Turn on the debug knob for the duration of the test and put the
	// previous value back afterwards.
	const debugKnob = "TS_DEBUG_DNS_FORWARD_SEND"
	oldVal := os.Getenv(debugKnob)
	envknob.Setenv(debugKnob, "true")
	t.Cleanup(func() { envknob.Setenv(debugKnob, oldVal) })

	const domain = "large-dns-response.tailscale.com."

	// Make a response that's very large, containing a bunch of localhost addresses.
	//
	// The Builder's per-section errors are not reported again by Finish, so
	// each call is checked individually; otherwise a malformed fixture could
	// go unnoticed.
	largeResponse := func() []byte {
		name := dns.MustNewName(domain)

		builder := dns.NewBuilder(nil, dns.Header{})
		if err := builder.StartQuestions(); err != nil {
			t.Fatal(err)
		}
		if err := builder.Question(dns.Question{
			Name:  name,
			Type:  dns.TypeA,
			Class: dns.ClassINET,
		}); err != nil {
			t.Fatal(err)
		}
		if err := builder.StartAnswers(); err != nil {
			t.Fatal(err)
		}
		for i := 0; i < 120; i++ {
			err := builder.AResource(dns.ResourceHeader{
				Name:  name,
				Class: dns.ClassINET,
				TTL:   300,
			}, dns.AResource{
				A: [4]byte{127, 0, 0, byte(i)},
			})
			if err != nil {
				t.Fatalf("adding A record %d: %v", i, err)
			}
		}

		msg, err := builder.Finish()
		if err != nil {
			t.Fatal(err)
		}
		return msg
	}()
	// The test is only meaningful if the answer exceeds maxResponseBytes.
	if len(largeResponse) <= maxResponseBytes {
		t.Fatalf("got len(largeResponse)=%d, want > %d", len(largeResponse), maxResponseBytes)
	}

	// Our request is a single A query for the domain in the answer, above.
	request := func() []byte {
		builder := dns.NewBuilder(nil, dns.Header{})
		if err := builder.StartQuestions(); err != nil {
			t.Fatal(err)
		}
		if err := builder.Question(dns.Question{
			Name:  dns.MustNewName(domain),
			Type:  dns.TypeA,
			Class: dns.ClassINET,
		}); err != nil {
			t.Fatal(err)
		}
		msg, err := builder.Finish()
		if err != nil {
			t.Fatal(err)
		}
		return msg
	}()

	// The server callback runs on other goroutines, hence the atomics.
	var sawUDPRequest, sawTCPRequest atomic.Bool
	port := runDNSServer(t, largeResponse, func(isTCP bool, gotRequest []byte) {
		if isTCP {
			sawTCPRequest.Store(true)
		} else {
			sawUDPRequest.Store(true)
		}

		if !bytes.Equal(request, gotRequest) {
			t.Errorf("invalid request\ngot: %+v\nwant: %+v", gotRequest, request)
		}
	})

	netMon, err := netmon.New(t.Logf)
	if err != nil {
		t.Fatal(err)
	}

	var dialer tsdial.Dialer
	dialer.SetNetMon(netMon)

	fwd := newForwarder(t.Logf, netMon, nil, &dialer, nil)

	fq := &forwardQuery{
		txid:           getTxID(request),
		packet:         request,
		closeOnCtxDone: new(closePool),
	}
	defer fq.closeOnCtxDone.Close()

	// Point the forwarder at the local test server.
	rr := resolverAndDelay{
		name: &dnstype.Resolver{Addr: fmt.Sprintf("127.0.0.1:%d", port)},
	}

	resp, err := fwd.send(context.Background(), fq, rr)
	if err != nil {
		t.Fatalf("error making request: %v", err)
	}
	if !bytes.Equal(resp, largeResponse) {
		t.Errorf("invalid response\ngot: %+v\nwant: %+v", resp, largeResponse)
	}
	if !sawTCPRequest.Load() {
		t.Errorf("DNS server never saw TCP request")
	}
	if !sawUDPRequest.Load() {
		t.Errorf("DNS server never saw UDP request")
	}
}