mirror of
https://github.com/tailscale/tailscale.git
synced 2025-08-25 20:57:31 +00:00
net/dns: retry forwarder requests over TCP
We weren't correctly retrying truncated requests to an upstream DNS server with TCP. Instead, we'd return a truncated response to the user, even if the user was querying us over TCP and thus able to handle a large response. Also, add an envknob and controlknob to allow users/us to disable this behaviour if it turns out to be buggy (✨ DNS ✨). Updates #9264 Signed-off-by: Andrew Dunham <andrew@du.nham.ca> Change-Id: Ifb04b563839a9614c0ba03e9c564e8924c1a2bfd
This commit is contained in:
@@ -4,14 +4,26 @@
|
||||
package resolver
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/binary"
|
||||
"flag"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/netip"
|
||||
"os"
|
||||
"reflect"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
dns "golang.org/x/net/dns/dnsmessage"
|
||||
"tailscale.com/envknob"
|
||||
"tailscale.com/net/netmon"
|
||||
"tailscale.com/net/tsdial"
|
||||
"tailscale.com/types/dnstype"
|
||||
)
|
||||
|
||||
@@ -240,3 +252,224 @@ func FuzzClampEDNSSize(f *testing.F) {
|
||||
clampEDNSSize(data, maxResponseBytes)
|
||||
})
|
||||
}
|
||||
|
||||
// runDNSServer starts a throwaway DNS server for tests that listens on the
// same 127.0.0.1 port for both TCP and UDP, and returns that port.
//
// Every query, whatever its content, is answered with response: verbatim over
// UDP, and prefixed with a 2-byte big-endian length over TCP. onRequest is
// invoked once per received query with isTCP reporting the transport and the
// raw query bytes (without the TCP length prefix). It is called from
// per-request goroutines, so it must be safe for concurrent use.
//
// If a TCP listener cannot be created the test fails; if no port can be bound
// for both protocols after several attempts the test is skipped. The listeners
// are closed by a tb.Cleanup function, which also waits for the accept/read
// loops and for every in-flight request handler, so that neither tb nor
// onRequest is used after the test has completed.
func runDNSServer(tb testing.TB, response []byte, onRequest func(bool, []byte)) (port uint16) {
	// Upper bound on the I/O performed by a single TCP handler. Cleanup waits
	// for handlers, so a client that connects and then stalls must not be
	// able to block it forever.
	const tcpHandlerTimeout = 10 * time.Second

	// DNS-over-TCP frames each message with a 2-byte big-endian length.
	tcpResponse := make([]byte, len(response)+2)
	binary.BigEndian.PutUint16(tcpResponse, uint16(len(response)))
	copy(tcpResponse[2:], response)

	// Repeatedly listen until we can get the same port.
	const tries = 25
	var (
		tcpLn *net.TCPListener
		udpLn *net.UDPConn
		err   error
	)
	for try := 0; try < tries; try++ {
		// Drop the TCP listener from a previous attempt whose port turned
		// out to be unavailable for UDP.
		if tcpLn != nil {
			tcpLn.Close()
			tcpLn = nil
		}

		tcpLn, err = net.ListenTCP("tcp4", &net.TCPAddr{
			IP:   net.IPv4(127, 0, 0, 1),
			Port: 0, // Choose one
		})
		if err != nil {
			tb.Fatal(err)
		}
		udpLn, err = net.ListenUDP("udp4", &net.UDPAddr{
			IP:   net.IPv4(127, 0, 0, 1),
			Port: tcpLn.Addr().(*net.TCPAddr).Port,
		})
		if err == nil {
			break
		}
	}
	if tcpLn == nil || udpLn == nil {
		if tcpLn != nil {
			tcpLn.Close()
		}
		if udpLn != nil {
			udpLn.Close()
		}

		// Skip instead of being fatal to avoid flaking on extremely
		// heavily-loaded CI systems.
		tb.Skipf("failed to listen on same port for TCP/UDP after %d tries", tries)
	}

	port = uint16(tcpLn.Addr().(*net.TCPAddr).Port)

	// wg tracks the two listener loops and every request handler. Handlers
	// are only ever added from inside a loop goroutine, which itself holds a
	// count, so the counter is never zero when Add is called concurrently
	// with the Wait in Cleanup.
	var wg sync.WaitGroup

	// handleConn serves a single length-prefixed query on conn.
	handleConn := func(conn net.Conn) {
		defer conn.Close()

		if err := conn.SetDeadline(time.Now().Add(tcpHandlerTimeout)); err != nil {
			tb.Logf("error setting deadline: %v", err)
			return
		}

		// Read the length header, then the buffer
		var length uint16
		if err := binary.Read(conn, binary.BigEndian, &length); err != nil {
			tb.Logf("error reading length header: %v", err)
			return
		}
		req := make([]byte, length)
		if _, err := io.ReadFull(conn, req); err != nil {
			tb.Logf("error reading query: %v", err)
			return
		}
		onRequest(true, req)

		// Write response
		if _, err := conn.Write(tcpResponse); err != nil {
			tb.Logf("error writing response: %v", err)
			return
		}
	}

	// TCP accept loop; exits once tcpLn is closed.
	wg.Add(1)
	go func() {
		defer wg.Done()
		for {
			conn, err := tcpLn.Accept()
			if err != nil {
				return
			}
			wg.Add(1)
			go func() {
				defer wg.Done()
				handleConn(conn)
			}()
		}
	}()

	// handleUDP answers a single datagram received from addr.
	handleUDP := func(addr netip.AddrPort, req []byte) {
		onRequest(false, req)
		if _, err := udpLn.WriteToUDPAddrPort(response, addr); err != nil {
			tb.Logf("error writing response: %v", err)
		}
	}

	// UDP read loop; exits once udpLn is closed.
	wg.Add(1)
	go func() {
		defer wg.Done()
		for {
			// A fresh buffer per datagram: the handler goroutine keeps a
			// reference to it after this loop moves on.
			buf := make([]byte, 65535)
			n, addr, err := udpLn.ReadFromUDPAddrPort(buf)
			if err != nil {
				return
			}
			req := buf[:n]
			wg.Add(1)
			go func() {
				defer wg.Done()
				handleUDP(addr, req)
			}()
		}
	}()

	tb.Cleanup(func() {
		tcpLn.Close()
		udpLn.Close()
		tb.Logf("waiting for listeners to finish...")
		wg.Wait()
	})
	return port
}
|
||||
|
||||
func TestForwarderTCPFallback(t *testing.T) {
|
||||
const debugKnob = "TS_DEBUG_DNS_FORWARD_SEND"
|
||||
oldVal := os.Getenv(debugKnob)
|
||||
envknob.Setenv(debugKnob, "true")
|
||||
t.Cleanup(func() { envknob.Setenv(debugKnob, oldVal) })
|
||||
|
||||
const domain = "large-dns-response.tailscale.com."
|
||||
|
||||
// Make a response that's very large, containing a bunch of localhost addresses.
|
||||
largeResponse := func() []byte {
|
||||
name := dns.MustNewName(domain)
|
||||
|
||||
builder := dns.NewBuilder(nil, dns.Header{})
|
||||
builder.StartQuestions()
|
||||
builder.Question(dns.Question{
|
||||
Name: name,
|
||||
Type: dns.TypeA,
|
||||
Class: dns.ClassINET,
|
||||
})
|
||||
builder.StartAnswers()
|
||||
for i := 0; i < 120; i++ {
|
||||
builder.AResource(dns.ResourceHeader{
|
||||
Name: name,
|
||||
Class: dns.ClassINET,
|
||||
TTL: 300,
|
||||
}, dns.AResource{
|
||||
A: [4]byte{127, 0, 0, byte(i)},
|
||||
})
|
||||
}
|
||||
|
||||
msg, err := builder.Finish()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
return msg
|
||||
}()
|
||||
if len(largeResponse) <= maxResponseBytes {
|
||||
t.Fatalf("got len(largeResponse)=%d, want > %d", len(largeResponse), maxResponseBytes)
|
||||
}
|
||||
|
||||
// Our request is a single A query for the domain in the answer, above.
|
||||
request := func() []byte {
|
||||
builder := dns.NewBuilder(nil, dns.Header{})
|
||||
builder.StartQuestions()
|
||||
builder.Question(dns.Question{
|
||||
Name: dns.MustNewName(domain),
|
||||
Type: dns.TypeA,
|
||||
Class: dns.ClassINET,
|
||||
})
|
||||
msg, err := builder.Finish()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
return msg
|
||||
}()
|
||||
|
||||
var sawUDPRequest, sawTCPRequest atomic.Bool
|
||||
port := runDNSServer(t, largeResponse, func(isTCP bool, gotRequest []byte) {
|
||||
if isTCP {
|
||||
sawTCPRequest.Store(true)
|
||||
} else {
|
||||
sawUDPRequest.Store(true)
|
||||
}
|
||||
|
||||
if !bytes.Equal(request, gotRequest) {
|
||||
t.Errorf("invalid request\ngot: %+v\nwant: %+v", gotRequest, request)
|
||||
}
|
||||
})
|
||||
|
||||
netMon, err := netmon.New(t.Logf)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
var dialer tsdial.Dialer
|
||||
dialer.SetNetMon(netMon)
|
||||
|
||||
fwd := newForwarder(t.Logf, netMon, nil, &dialer, nil)
|
||||
|
||||
fq := &forwardQuery{
|
||||
txid: getTxID(request),
|
||||
packet: request,
|
||||
closeOnCtxDone: new(closePool),
|
||||
}
|
||||
defer fq.closeOnCtxDone.Close()
|
||||
|
||||
rr := resolverAndDelay{
|
||||
name: &dnstype.Resolver{Addr: fmt.Sprintf("127.0.0.1:%d", port)},
|
||||
}
|
||||
|
||||
resp, err := fwd.send(context.Background(), fq, rr)
|
||||
if err != nil {
|
||||
t.Fatalf("error making request: %v", err)
|
||||
}
|
||||
if !bytes.Equal(resp, largeResponse) {
|
||||
t.Errorf("invalid response\ngot: %+v\nwant: %+v", resp, largeResponse)
|
||||
}
|
||||
if !sawTCPRequest.Load() {
|
||||
t.Errorf("DNS server never saw TCP request")
|
||||
}
|
||||
if !sawUDPRequest.Load() {
|
||||
t.Errorf("DNS server never saw UDP request")
|
||||
}
|
||||
}
|
||||
|
Reference in New Issue
Block a user