mirror of
https://github.com/tailscale/tailscale.git
synced 2025-08-25 20:57:31 +00:00
net/dns/resolver: race UDP and TCP queries (#9544)
Instead of only falling back to a TCP query to an upstream DNS server when the UDP query returns a truncated response, also start a TCP query in parallel with the UDP query after a given race timeout. This ensures that if the upstream DNS server does not reply over UDP (or if the response packet is blocked, or there's an error), we can still make queries as long as the server replies to TCP queries. This also adds a new package, util/race, which contains the logic required for racing two different functions and returning the first non-error answer. Updates tailscale/corp#14809. Signed-off-by: Andrew Dunham <andrew@du.nham.ca>. Change-Id: I4311702016c1093b1beaa31b135da1def6d86316
This commit is contained in:
@@ -7,6 +7,7 @@ import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"flag"
|
||||
"fmt"
|
||||
"io"
|
||||
@@ -21,6 +22,7 @@ import (
|
||||
"time"
|
||||
|
||||
dns "golang.org/x/net/dns/dnsmessage"
|
||||
"tailscale.com/control/controlknobs"
|
||||
"tailscale.com/envknob"
|
||||
"tailscale.com/net/netmon"
|
||||
"tailscale.com/net/tsdial"
|
||||
@@ -253,7 +255,16 @@ func FuzzClampEDNSSize(f *testing.F) {
|
||||
})
|
||||
}
|
||||
|
||||
func runDNSServer(tb testing.TB, response []byte, onRequest func(bool, []byte)) (port uint16) {
|
||||
// testDNSServerOptions configures the test DNS server started by
// runDNSServer. A nil *testDNSServerOptions is accepted by runDNSServer and
// behaves like the zero value (both transports are served).
type testDNSServerOptions struct {
	// SkipUDP, if set, makes runDNSServer not start its UDP read loop, so
	// UDP queries are never answered.
	SkipUDP bool
	// SkipTCP, if set, makes runDNSServer not start its TCP accept loop,
	// so TCP connections are never serviced.
	//
	// Setting both SkipUDP and SkipTCP is rejected by runDNSServer.
	SkipTCP bool
}
|
||||
|
||||
func runDNSServer(tb testing.TB, opts *testDNSServerOptions, response []byte, onRequest func(bool, []byte)) (port uint16) {
|
||||
if opts != nil && opts.SkipUDP && opts.SkipTCP {
|
||||
tb.Fatal("cannot skip both UDP and TCP servers")
|
||||
}
|
||||
|
||||
tcpResponse := make([]byte, len(response)+2)
|
||||
binary.BigEndian.PutUint16(tcpResponse, uint16(len(response)))
|
||||
copy(tcpResponse[2:], response)
|
||||
@@ -327,17 +338,20 @@ func runDNSServer(tb testing.TB, response []byte, onRequest func(bool, []byte))
|
||||
}
|
||||
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
for {
|
||||
conn, err := tcpLn.Accept()
|
||||
if err != nil {
|
||||
return
|
||||
|
||||
if opts == nil || !opts.SkipTCP {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
for {
|
||||
conn, err := tcpLn.Accept()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
go handleConn(conn)
|
||||
}
|
||||
go handleConn(conn)
|
||||
}
|
||||
}()
|
||||
}()
|
||||
}
|
||||
|
||||
handleUDP := func(addr netip.AddrPort, req []byte) {
|
||||
onRequest(false, req)
|
||||
@@ -346,19 +360,21 @@ func runDNSServer(tb testing.TB, response []byte, onRequest func(bool, []byte))
|
||||
}
|
||||
}
|
||||
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
for {
|
||||
buf := make([]byte, 65535)
|
||||
n, addr, err := udpLn.ReadFromUDPAddrPort(buf)
|
||||
if err != nil {
|
||||
return
|
||||
if opts == nil || !opts.SkipUDP {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
for {
|
||||
buf := make([]byte, 65535)
|
||||
n, addr, err := udpLn.ReadFromUDPAddrPort(buf)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
buf = buf[:n]
|
||||
go handleUDP(addr, buf)
|
||||
}
|
||||
buf = buf[:n]
|
||||
go handleUDP(addr, buf)
|
||||
}
|
||||
}()
|
||||
}()
|
||||
}
|
||||
|
||||
tb.Cleanup(func() {
|
||||
tcpLn.Close()
|
||||
@@ -369,84 +385,72 @@ func runDNSServer(tb testing.TB, response []byte, onRequest func(bool, []byte))
|
||||
return
|
||||
}
|
||||
|
||||
func TestForwarderTCPFallback(t *testing.T) {
|
||||
func enableDebug(tb testing.TB) {
|
||||
const debugKnob = "TS_DEBUG_DNS_FORWARD_SEND"
|
||||
oldVal := os.Getenv(debugKnob)
|
||||
envknob.Setenv(debugKnob, "true")
|
||||
t.Cleanup(func() { envknob.Setenv(debugKnob, oldVal) })
|
||||
tb.Cleanup(func() { envknob.Setenv(debugKnob, oldVal) })
|
||||
}
|
||||
|
||||
const domain = "large-dns-response.tailscale.com."
|
||||
func makeLargeResponse(tb testing.TB, domain string) (request, response []byte) {
|
||||
name := dns.MustNewName(domain)
|
||||
|
||||
// Make a response that's very large, containing a bunch of localhost addresses.
|
||||
largeResponse := func() []byte {
|
||||
name := dns.MustNewName(domain)
|
||||
|
||||
builder := dns.NewBuilder(nil, dns.Header{})
|
||||
builder.StartQuestions()
|
||||
builder.Question(dns.Question{
|
||||
builder := dns.NewBuilder(nil, dns.Header{})
|
||||
builder.StartQuestions()
|
||||
builder.Question(dns.Question{
|
||||
Name: name,
|
||||
Type: dns.TypeA,
|
||||
Class: dns.ClassINET,
|
||||
})
|
||||
builder.StartAnswers()
|
||||
for i := 0; i < 120; i++ {
|
||||
builder.AResource(dns.ResourceHeader{
|
||||
Name: name,
|
||||
Type: dns.TypeA,
|
||||
Class: dns.ClassINET,
|
||||
TTL: 300,
|
||||
}, dns.AResource{
|
||||
A: [4]byte{127, 0, 0, byte(i)},
|
||||
})
|
||||
builder.StartAnswers()
|
||||
for i := 0; i < 120; i++ {
|
||||
builder.AResource(dns.ResourceHeader{
|
||||
Name: name,
|
||||
Class: dns.ClassINET,
|
||||
TTL: 300,
|
||||
}, dns.AResource{
|
||||
A: [4]byte{127, 0, 0, byte(i)},
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
msg, err := builder.Finish()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
return msg
|
||||
}()
|
||||
if len(largeResponse) <= maxResponseBytes {
|
||||
t.Fatalf("got len(largeResponse)=%d, want > %d", len(largeResponse), maxResponseBytes)
|
||||
var err error
|
||||
response, err = builder.Finish()
|
||||
if err != nil {
|
||||
tb.Fatal(err)
|
||||
}
|
||||
if len(response) <= maxResponseBytes {
|
||||
tb.Fatalf("got len(largeResponse)=%d, want > %d", len(response), maxResponseBytes)
|
||||
}
|
||||
|
||||
// Our request is a single A query for the domain in the answer, above.
|
||||
request := func() []byte {
|
||||
builder := dns.NewBuilder(nil, dns.Header{})
|
||||
builder.StartQuestions()
|
||||
builder.Question(dns.Question{
|
||||
Name: dns.MustNewName(domain),
|
||||
Type: dns.TypeA,
|
||||
Class: dns.ClassINET,
|
||||
})
|
||||
msg, err := builder.Finish()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
return msg
|
||||
}()
|
||||
|
||||
var sawUDPRequest, sawTCPRequest atomic.Bool
|
||||
port := runDNSServer(t, largeResponse, func(isTCP bool, gotRequest []byte) {
|
||||
if isTCP {
|
||||
sawTCPRequest.Store(true)
|
||||
} else {
|
||||
sawUDPRequest.Store(true)
|
||||
}
|
||||
|
||||
if !bytes.Equal(request, gotRequest) {
|
||||
t.Errorf("invalid request\ngot: %+v\nwant: %+v", gotRequest, request)
|
||||
}
|
||||
builder = dns.NewBuilder(nil, dns.Header{})
|
||||
builder.StartQuestions()
|
||||
builder.Question(dns.Question{
|
||||
Name: dns.MustNewName(domain),
|
||||
Type: dns.TypeA,
|
||||
Class: dns.ClassINET,
|
||||
})
|
||||
|
||||
netMon, err := netmon.New(t.Logf)
|
||||
request, err = builder.Finish()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
tb.Fatal(err)
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func runTestQuery(tb testing.TB, port uint16, request []byte, modify func(*forwarder)) ([]byte, error) {
|
||||
netMon, err := netmon.New(tb.Logf)
|
||||
if err != nil {
|
||||
tb.Fatal(err)
|
||||
}
|
||||
|
||||
var dialer tsdial.Dialer
|
||||
dialer.SetNetMon(netMon)
|
||||
|
||||
fwd := newForwarder(t.Logf, netMon, nil, &dialer, nil)
|
||||
fwd := newForwarder(tb.Logf, netMon, nil, &dialer, nil)
|
||||
if modify != nil {
|
||||
modify(fwd)
|
||||
}
|
||||
|
||||
fq := &forwardQuery{
|
||||
txid: getTxID(request),
|
||||
@@ -459,10 +463,41 @@ func TestForwarderTCPFallback(t *testing.T) {
|
||||
name: &dnstype.Resolver{Addr: fmt.Sprintf("127.0.0.1:%d", port)},
|
||||
}
|
||||
|
||||
resp, err := fwd.send(context.Background(), fq, rr)
|
||||
return fwd.send(context.Background(), fq, rr)
|
||||
}
|
||||
|
||||
func mustRunTestQuery(tb testing.TB, port uint16, request []byte, modify func(*forwarder)) []byte {
|
||||
resp, err := runTestQuery(tb, port, request, modify)
|
||||
if err != nil {
|
||||
t.Fatalf("error making request: %v", err)
|
||||
tb.Fatalf("error making request: %v", err)
|
||||
}
|
||||
return resp
|
||||
}
|
||||
|
||||
func TestForwarderTCPFallback(t *testing.T) {
|
||||
enableDebug(t)
|
||||
|
||||
const domain = "large-dns-response.tailscale.com."
|
||||
|
||||
// Make a response that's very large, containing a bunch of localhost addresses.
|
||||
request, largeResponse := makeLargeResponse(t, domain)
|
||||
|
||||
var sawUDPRequest, sawTCPRequest atomic.Bool
|
||||
port := runDNSServer(t, nil, largeResponse, func(isTCP bool, gotRequest []byte) {
|
||||
if isTCP {
|
||||
t.Logf("saw TCP request")
|
||||
sawTCPRequest.Store(true)
|
||||
} else {
|
||||
t.Logf("saw UDP request")
|
||||
sawUDPRequest.Store(true)
|
||||
}
|
||||
|
||||
if !bytes.Equal(request, gotRequest) {
|
||||
t.Errorf("invalid request\ngot: %+v\nwant: %+v", gotRequest, request)
|
||||
}
|
||||
})
|
||||
|
||||
resp := mustRunTestQuery(t, port, request, nil)
|
||||
if !bytes.Equal(resp, largeResponse) {
|
||||
t.Errorf("invalid response\ngot: %+v\nwant: %+v", resp, largeResponse)
|
||||
}
|
||||
@@ -473,3 +508,141 @@ func TestForwarderTCPFallback(t *testing.T) {
|
||||
t.Errorf("DNS server never saw UDP request")
|
||||
}
|
||||
}
|
||||
|
||||
// Test to ensure that if the UDP listener is unresponsive, we always make a
|
||||
// TCP request even if we never get a response.
|
||||
func TestForwarderTCPFallbackTimeout(t *testing.T) {
|
||||
enableDebug(t)
|
||||
|
||||
const domain = "large-dns-response.tailscale.com."
|
||||
|
||||
// Make a response that's very large, containing a bunch of localhost addresses.
|
||||
request, largeResponse := makeLargeResponse(t, domain)
|
||||
|
||||
var sawTCPRequest atomic.Bool
|
||||
opts := &testDNSServerOptions{SkipUDP: true}
|
||||
port := runDNSServer(t, opts, largeResponse, func(isTCP bool, gotRequest []byte) {
|
||||
if isTCP {
|
||||
t.Logf("saw TCP request")
|
||||
sawTCPRequest.Store(true)
|
||||
} else {
|
||||
t.Error("saw unexpected UDP request")
|
||||
}
|
||||
|
||||
if !bytes.Equal(request, gotRequest) {
|
||||
t.Errorf("invalid request\ngot: %+v\nwant: %+v", gotRequest, request)
|
||||
}
|
||||
})
|
||||
|
||||
resp := mustRunTestQuery(t, port, request, nil)
|
||||
if !bytes.Equal(resp, largeResponse) {
|
||||
t.Errorf("invalid response\ngot: %+v\nwant: %+v", resp, largeResponse)
|
||||
}
|
||||
if !sawTCPRequest.Load() {
|
||||
t.Errorf("DNS server never saw TCP request")
|
||||
}
|
||||
}
|
||||
|
||||
func TestForwarderTCPFallbackDisabled(t *testing.T) {
|
||||
enableDebug(t)
|
||||
|
||||
const domain = "large-dns-response.tailscale.com."
|
||||
|
||||
// Make a response that's very large, containing a bunch of localhost addresses.
|
||||
request, largeResponse := makeLargeResponse(t, domain)
|
||||
|
||||
var sawUDPRequest atomic.Bool
|
||||
port := runDNSServer(t, nil, largeResponse, func(isTCP bool, gotRequest []byte) {
|
||||
if isTCP {
|
||||
t.Error("saw unexpected TCP request")
|
||||
} else {
|
||||
t.Logf("saw UDP request")
|
||||
sawUDPRequest.Store(true)
|
||||
}
|
||||
|
||||
if !bytes.Equal(request, gotRequest) {
|
||||
t.Errorf("invalid request\ngot: %+v\nwant: %+v", gotRequest, request)
|
||||
}
|
||||
})
|
||||
|
||||
resp := mustRunTestQuery(t, port, request, func(fwd *forwarder) {
|
||||
// Disable retries for this test.
|
||||
fwd.controlKnobs = &controlknobs.Knobs{}
|
||||
fwd.controlKnobs.DisableDNSForwarderTCPRetries.Store(true)
|
||||
})
|
||||
|
||||
wantResp := append([]byte(nil), largeResponse[:maxResponseBytes]...)
|
||||
|
||||
// Set the truncated flag on the expected response, since that's what we expect.
|
||||
flags := binary.BigEndian.Uint16(wantResp[2:4])
|
||||
flags |= dnsFlagTruncated
|
||||
binary.BigEndian.PutUint16(wantResp[2:4], flags)
|
||||
|
||||
if !bytes.Equal(resp, wantResp) {
|
||||
t.Errorf("invalid response\ngot (%d): %+v\nwant (%d): %+v", len(resp), resp, len(wantResp), wantResp)
|
||||
}
|
||||
if !sawUDPRequest.Load() {
|
||||
t.Errorf("DNS server never saw UDP request")
|
||||
}
|
||||
}
|
||||
|
||||
// Test to ensure that we propagate DNS errors
|
||||
func TestForwarderTCPFallbackError(t *testing.T) {
|
||||
enableDebug(t)
|
||||
|
||||
const domain = "error-response.tailscale.com."
|
||||
|
||||
// Our response is a SERVFAIL
|
||||
response := func() []byte {
|
||||
name := dns.MustNewName(domain)
|
||||
|
||||
builder := dns.NewBuilder(nil, dns.Header{
|
||||
RCode: dns.RCodeServerFailure,
|
||||
})
|
||||
builder.StartQuestions()
|
||||
builder.Question(dns.Question{
|
||||
Name: name,
|
||||
Type: dns.TypeA,
|
||||
Class: dns.ClassINET,
|
||||
})
|
||||
response, err := builder.Finish()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
return response
|
||||
}()
|
||||
|
||||
// Our request is a single A query for the domain in the answer, above.
|
||||
request := func() []byte {
|
||||
builder := dns.NewBuilder(nil, dns.Header{})
|
||||
builder.StartQuestions()
|
||||
builder.Question(dns.Question{
|
||||
Name: dns.MustNewName(domain),
|
||||
Type: dns.TypeA,
|
||||
Class: dns.ClassINET,
|
||||
})
|
||||
request, err := builder.Finish()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
return request
|
||||
}()
|
||||
|
||||
var sawRequest atomic.Bool
|
||||
port := runDNSServer(t, nil, response, func(isTCP bool, gotRequest []byte) {
|
||||
sawRequest.Store(true)
|
||||
if !bytes.Equal(request, gotRequest) {
|
||||
t.Errorf("invalid request\ngot: %+v\nwant: %+v", gotRequest, request)
|
||||
}
|
||||
})
|
||||
|
||||
_, err := runTestQuery(t, port, request, nil)
|
||||
if !sawRequest.Load() {
|
||||
t.Error("did not see DNS request")
|
||||
}
|
||||
if err == nil {
|
||||
t.Error("wanted error, got nil")
|
||||
} else if !errors.Is(err, errServerFailure) {
|
||||
t.Errorf("wanted errServerFailure, got: %v", err)
|
||||
}
|
||||
}
|
||||
|
Reference in New Issue
Block a user