tailscale/net/dns/resolver/forwarder_test.go
Andrew Dunham 7fe6e50858
Some checks failed
checklocks / checklocks (push) Waiting to run
Dockerfile build / deploy (push) Waiting to run
CI / go_generate (push) Waiting to run
CI / race-root-integration (1/4) (push) Waiting to run
CI / race-root-integration (2/4) (push) Waiting to run
CI / race-root-integration (3/4) (push) Waiting to run
CI / race-root-integration (4/4) (push) Waiting to run
CI / test (-coverprofile=/tmp/coverage.out, amd64) (push) Waiting to run
CI / test (-race, amd64, 1/3) (push) Waiting to run
CI / test (-race, amd64, 2/3) (push) Waiting to run
CI / test (-race, amd64, 3/3) (push) Waiting to run
CI / test (386) (push) Waiting to run
CI / windows (push) Waiting to run
CI / privileged (push) Waiting to run
CI / vm (push) Waiting to run
CI / race-build (push) Waiting to run
CI / go_mod_tidy (push) Waiting to run
CI / licenses (push) Waiting to run
CI / cross (386, linux) (push) Waiting to run
CI / cross (amd64, darwin) (push) Waiting to run
CI / cross (amd64, freebsd) (push) Waiting to run
CI / cross (amd64, openbsd) (push) Waiting to run
CI / cross (amd64, windows) (push) Waiting to run
CI / cross (arm, 5, linux) (push) Waiting to run
CI / cross (arm, 7, linux) (push) Waiting to run
CI / cross (arm64, darwin) (push) Waiting to run
CI / cross (arm64, linux) (push) Waiting to run
CI / cross (arm64, windows) (push) Waiting to run
CI / cross (loong64, linux) (push) Waiting to run
CI / ios (push) Waiting to run
CI / crossmin (amd64, plan9) (push) Waiting to run
CI / staticcheck (386, windows) (push) Waiting to run
CI / crossmin (ppc64, aix) (push) Waiting to run
CI / android (push) Waiting to run
CI / wasm (push) Waiting to run
CI / tailscale_go (push) Waiting to run
CI / fuzz (push) Waiting to run
CI / depaware (push) Waiting to run
CI / staticcheck (amd64, darwin) (push) Waiting to run
CI / staticcheck (amd64, linux) (push) Waiting to run
CI / staticcheck (amd64, windows) (push) Waiting to run
CI / notify_slack (push) Blocked by required conditions
CI / check_mergeability (push) Blocked by required conditions
CodeQL / Analyze (go) (push) Has been cancelled
net/dns/resolver: fix test flake
Updates #13902

Signed-off-by: Andrew Dunham <andrew@du.nham.ca>
Change-Id: Ib2def19caad17367e9a31786ac969278e65f51c6
2024-10-24 13:36:57 -05:00

879 lines
23 KiB
Go

// Copyright (c) Tailscale Inc & AUTHORS
// SPDX-License-Identifier: BSD-3-Clause
package resolver
import (
"bytes"
"context"
"encoding/binary"
"flag"
"fmt"
"io"
"net"
"net/netip"
"os"
"reflect"
"slices"
"strings"
"sync"
"sync/atomic"
"testing"
"time"
dns "golang.org/x/net/dns/dnsmessage"
"tailscale.com/control/controlknobs"
"tailscale.com/envknob"
"tailscale.com/health"
"tailscale.com/net/netmon"
"tailscale.com/net/tsdial"
"tailscale.com/tstest"
"tailscale.com/types/dnstype"
)
func (rr resolverAndDelay) String() string {
return fmt.Sprintf("%v+%v", rr.name, rr.startDelay)
}
func TestResolversWithDelays(t *testing.T) {
// query
q := func(ss ...string) (ipps []*dnstype.Resolver) {
for _, host := range ss {
ipps = append(ipps, &dnstype.Resolver{Addr: host})
}
return
}
// output
o := func(ss ...string) (rr []resolverAndDelay) {
for _, s := range ss {
var d time.Duration
s, durStr, hasPlus := strings.Cut(s, "+")
if hasPlus {
var err error
d, err = time.ParseDuration(durStr)
if err != nil {
panic(fmt.Sprintf("parsing duration in %q: %v", s, err))
}
}
rr = append(rr, resolverAndDelay{
name: &dnstype.Resolver{Addr: s},
startDelay: d,
})
}
return
}
tests := []struct {
name string
in []*dnstype.Resolver
want []resolverAndDelay
}{
{
name: "unknown-no-delays",
in: q("1.2.3.4", "2.3.4.5"),
want: o("1.2.3.4", "2.3.4.5"),
},
{
name: "google-all-ipv4",
in: q("8.8.8.8", "8.8.4.4"),
want: o("https://dns.google/dns-query", "8.8.8.8+0.5s", "8.8.4.4+0.7s"),
},
{
name: "google-only-ipv6",
in: q("2001:4860:4860::8888", "2001:4860:4860::8844"),
want: o("https://dns.google/dns-query", "2001:4860:4860::8888+0.5s", "2001:4860:4860::8844+0.7s"),
},
{
name: "google-all-four",
in: q("8.8.8.8", "8.8.4.4", "2001:4860:4860::8888", "2001:4860:4860::8844"),
want: o("https://dns.google/dns-query", "8.8.8.8+0.5s", "8.8.4.4+0.7s", "2001:4860:4860::8888+0.5s", "2001:4860:4860::8844+0.7s"),
},
{
name: "quad9-one-v4-one-v6",
in: q("9.9.9.9", "2620:fe::fe"),
want: o("https://dns.quad9.net/dns-query", "9.9.9.9+0.5s", "2620:fe::fe+0.5s"),
},
{
name: "nextdns-ipv6-expand",
in: q("2a07:a8c0::c3:a884"),
want: o("https://dns.nextdns.io/c3a884"),
},
{
name: "nextdns-doh-input",
in: q("https://dns.nextdns.io/c3a884"),
want: o("https://dns.nextdns.io/c3a884"),
},
{
name: "controld-ipv6-expand",
in: q("2606:1a40:0:6:7b5b:5949:35ad:0"),
want: o("https://dns.controld.com/hyq3ipr2ct"),
},
{
name: "controld-doh-input",
in: q("https://dns.controld.com/hyq3ipr2ct"),
want: o("https://dns.controld.com/hyq3ipr2ct"),
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := resolversWithDelays(tt.in)
if !reflect.DeepEqual(got, tt.want) {
t.Errorf("got %v; want %v", got, tt.want)
}
})
}
}
func TestGetRCode(t *testing.T) {
tests := []struct {
name string
packet []byte
want dns.RCode
}{
{
name: "empty",
packet: []byte{},
want: dns.RCode(5),
},
{
name: "too-short",
packet: []byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0},
want: dns.RCode(5),
},
{
name: "noerror",
packet: []byte{0xC4, 0xFE, 0x81, 0xA0, 0x00, 0x01, 0x00, 0x01, 0x00, 0x00, 0x00, 0x01},
want: dns.RCode(0),
},
{
name: "refused",
packet: []byte{0xee, 0xa1, 0x81, 0x05, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01},
want: dns.RCode(5),
},
{
name: "nxdomain",
packet: []byte{0x34, 0xf4, 0x81, 0x83, 0x00, 0x01, 0x00, 0x00, 0x00, 0x01, 0x00, 0x01},
want: dns.RCode(3),
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := getRCode(tt.packet)
if got != tt.want {
t.Errorf("got %d; want %d", got, tt.want)
}
})
}
}
var testDNS = flag.Bool("test-dns", false, "run tests that require a working DNS server")
func TestGetKnownDoHClientForProvider(t *testing.T) {
var fwd forwarder
c, ok := fwd.getKnownDoHClientForProvider("https://dns.google/dns-query")
if !ok {
t.Fatal("not found")
}
if !*testDNS {
t.Skip("skipping without --test-dns")
}
res, err := c.Head("https://dns.google/")
if err != nil {
t.Fatal(err)
}
defer res.Body.Close()
t.Logf("Got: %+v", res)
}
func BenchmarkNameFromQuery(b *testing.B) {
builder := dns.NewBuilder(nil, dns.Header{})
builder.StartQuestions()
builder.Question(dns.Question{
Name: dns.MustNewName("foo.example."),
Type: dns.TypeA,
Class: dns.ClassINET,
})
msg, err := builder.Finish()
if err != nil {
b.Fatal(err)
}
b.ResetTimer()
b.ReportAllocs()
for range b.N {
_, _, err := nameFromQuery(msg)
if err != nil {
b.Fatal(err)
}
}
}
// Reproduces https://github.com/tailscale/tailscale/issues/2533
// Fixed by https://github.com/tailscale/tailscale/commit/f414a9cc01f3264912513d07c0244ff4f3e4ba54
//
// NOTE: fuzz tests act like unit tests when run without `-fuzz`
func FuzzClampEDNSSize(f *testing.F) {
// Empty DNS packet
f.Add([]byte{
// query id
0x12, 0x34,
// flags: standard query, recurse
0x01, 0x20,
// num questions
0x00, 0x00,
// num answers
0x00, 0x00,
// num authority RRs
0x00, 0x00,
// num additional RRs
0x00, 0x00,
})
// Empty OPT
f.Add([]byte{
// header
0xaf, 0x66, 0x01, 0x20, 0x00, 0x01, 0x00, 0x00,
0x00, 0x00, 0x00, 0x01,
// query
0x06, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x03, 0x63, 0x6f,
0x6d, 0x00, 0x00, 0x01, 0x00, 0x01,
// OPT
0x00, // name: <root>
0x00, 0x29, // type: OPT
0x10, 0x00, // UDP payload size
0x00, // higher bits in extended RCODE
0x00, // EDNS0 version
0x80, 0x00, // "Z" field
0x00, 0x00, // data length
})
// Query for "google.com"
f.Add([]byte{
// header
0xaf, 0x66, 0x01, 0x20, 0x00, 0x01, 0x00, 0x00,
0x00, 0x00, 0x00, 0x01,
// query
0x06, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x03, 0x63, 0x6f,
0x6d, 0x00, 0x00, 0x01, 0x00, 0x01,
// OPT
0x00, 0x00, 0x29, 0x10, 0x00, 0x00, 0x00, 0x80, 0x00, 0x00,
0x0c, 0x00, 0x0a, 0x00, 0x08, 0x62, 0x18, 0x1a, 0xcb, 0x19,
0xd7, 0xee, 0x23,
})
f.Fuzz(func(t *testing.T, data []byte) {
clampEDNSSize(data, maxResponseBytes)
})
}
type testDNSServerOptions struct {
SkipUDP bool
SkipTCP bool
}
func runDNSServer(tb testing.TB, opts *testDNSServerOptions, response []byte, onRequest func(bool, []byte)) (port uint16) {
if opts != nil && opts.SkipUDP && opts.SkipTCP {
tb.Fatal("cannot skip both UDP and TCP servers")
}
logf := tstest.WhileTestRunningLogger(tb)
tcpResponse := make([]byte, len(response)+2)
binary.BigEndian.PutUint16(tcpResponse, uint16(len(response)))
copy(tcpResponse[2:], response)
// Repeatedly listen until we can get the same port.
const tries = 25
var (
tcpLn *net.TCPListener
udpLn *net.UDPConn
err error
)
for try := 0; try < tries; try++ {
if tcpLn != nil {
tcpLn.Close()
tcpLn = nil
}
tcpLn, err = net.ListenTCP("tcp4", &net.TCPAddr{
IP: net.IPv4(127, 0, 0, 1),
Port: 0, // Choose one
})
if err != nil {
tb.Fatal(err)
}
udpLn, err = net.ListenUDP("udp4", &net.UDPAddr{
IP: net.IPv4(127, 0, 0, 1),
Port: tcpLn.Addr().(*net.TCPAddr).Port,
})
if err == nil {
break
}
}
if tcpLn == nil || udpLn == nil {
if tcpLn != nil {
tcpLn.Close()
}
if udpLn != nil {
udpLn.Close()
}
// Skip instead of being fatal to avoid flaking on extremely
// heavily-loaded CI systems.
tb.Skipf("failed to listen on same port for TCP/UDP after %d tries", tries)
}
port = uint16(tcpLn.Addr().(*net.TCPAddr).Port)
handleConn := func(conn net.Conn) {
defer conn.Close()
// Read the length header, then the buffer
var length uint16
if err := binary.Read(conn, binary.BigEndian, &length); err != nil {
logf("error reading length header: %v", err)
return
}
req := make([]byte, length)
n, err := io.ReadFull(conn, req)
if err != nil {
logf("error reading query: %v", err)
return
}
req = req[:n]
onRequest(true, req)
// Write response
if _, err := conn.Write(tcpResponse); err != nil {
logf("error writing response: %v", err)
return
}
}
var wg sync.WaitGroup
if opts == nil || !opts.SkipTCP {
wg.Add(1)
go func() {
defer wg.Done()
for {
conn, err := tcpLn.Accept()
if err != nil {
return
}
go handleConn(conn)
}
}()
}
handleUDP := func(addr netip.AddrPort, req []byte) {
onRequest(false, req)
if _, err := udpLn.WriteToUDPAddrPort(response, addr); err != nil {
logf("error writing response: %v", err)
}
}
if opts == nil || !opts.SkipUDP {
wg.Add(1)
go func() {
defer wg.Done()
for {
buf := make([]byte, 65535)
n, addr, err := udpLn.ReadFromUDPAddrPort(buf)
if err != nil {
return
}
buf = buf[:n]
go handleUDP(addr, buf)
}
}()
}
tb.Cleanup(func() {
tcpLn.Close()
udpLn.Close()
logf("waiting for listeners to finish...")
wg.Wait()
})
return
}
func enableDebug(tb testing.TB) {
const debugKnob = "TS_DEBUG_DNS_FORWARD_SEND"
oldVal := os.Getenv(debugKnob)
envknob.Setenv(debugKnob, "true")
tb.Cleanup(func() { envknob.Setenv(debugKnob, oldVal) })
}
func makeLargeResponse(tb testing.TB, domain string) (request, response []byte) {
name := dns.MustNewName(domain)
builder := dns.NewBuilder(nil, dns.Header{Response: true})
builder.StartQuestions()
builder.Question(dns.Question{
Name: name,
Type: dns.TypeA,
Class: dns.ClassINET,
})
builder.StartAnswers()
for i := range 120 {
builder.AResource(dns.ResourceHeader{
Name: name,
Class: dns.ClassINET,
TTL: 300,
}, dns.AResource{
A: [4]byte{127, 0, 0, byte(i)},
})
}
var err error
response, err = builder.Finish()
if err != nil {
tb.Fatal(err)
}
if len(response) <= maxResponseBytes {
tb.Fatalf("got len(largeResponse)=%d, want > %d", len(response), maxResponseBytes)
}
// Our request is a single A query for the domain in the answer, above.
builder = dns.NewBuilder(nil, dns.Header{})
builder.StartQuestions()
builder.Question(dns.Question{
Name: dns.MustNewName(domain),
Type: dns.TypeA,
Class: dns.ClassINET,
})
request, err = builder.Finish()
if err != nil {
tb.Fatal(err)
}
return
}
func runTestQuery(tb testing.TB, request []byte, modify func(*forwarder), ports ...uint16) ([]byte, error) {
logf := tstest.WhileTestRunningLogger(tb)
netMon, err := netmon.New(logf)
if err != nil {
tb.Fatal(err)
}
var dialer tsdial.Dialer
dialer.SetNetMon(netMon)
fwd := newForwarder(logf, netMon, nil, &dialer, new(health.Tracker), nil)
if modify != nil {
modify(fwd)
}
resolvers := make([]resolverAndDelay, len(ports))
for i, port := range ports {
resolvers[i].name = &dnstype.Resolver{Addr: fmt.Sprintf("127.0.0.1:%d", port)}
}
rpkt := packet{
bs: request,
family: "tcp",
addr: netip.MustParseAddrPort("127.0.0.1:12345"),
}
rchan := make(chan packet, 1)
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
tb.Cleanup(cancel)
err = fwd.forwardWithDestChan(ctx, rpkt, rchan, resolvers...)
select {
case res := <-rchan:
return res.bs, err
case <-ctx.Done():
return nil, ctx.Err()
}
}
// makeTestRequest returns a new TypeA request for the given domain.
func makeTestRequest(tb testing.TB, domain string) []byte {
tb.Helper()
name := dns.MustNewName(domain)
builder := dns.NewBuilder(nil, dns.Header{})
builder.StartQuestions()
builder.Question(dns.Question{
Name: name,
Type: dns.TypeA,
Class: dns.ClassINET,
})
request, err := builder.Finish()
if err != nil {
tb.Fatal(err)
}
return request
}
// makeTestResponse returns a new Type A response for the given domain,
// with the specified status code and zero or more addresses.
func makeTestResponse(tb testing.TB, domain string, code dns.RCode, addrs ...netip.Addr) []byte {
tb.Helper()
name := dns.MustNewName(domain)
builder := dns.NewBuilder(nil, dns.Header{
Response: true,
Authoritative: true,
RCode: code,
})
builder.StartQuestions()
q := dns.Question{
Name: name,
Type: dns.TypeA,
Class: dns.ClassINET,
}
builder.Question(q)
if len(addrs) > 0 {
builder.StartAnswers()
for _, addr := range addrs {
builder.AResource(dns.ResourceHeader{
Name: q.Name,
Class: q.Class,
TTL: 120,
}, dns.AResource{
A: addr.As4(),
})
}
}
response, err := builder.Finish()
if err != nil {
tb.Fatal(err)
}
return response
}
func mustRunTestQuery(tb testing.TB, request []byte, modify func(*forwarder), ports ...uint16) []byte {
resp, err := runTestQuery(tb, request, modify, ports...)
if err != nil {
tb.Fatalf("error making request: %v", err)
}
return resp
}
func TestForwarderTCPFallback(t *testing.T) {
enableDebug(t)
const domain = "large-dns-response.tailscale.com."
// Make a response that's very large, containing a bunch of localhost addresses.
request, largeResponse := makeLargeResponse(t, domain)
var sawTCPRequest atomic.Bool
port := runDNSServer(t, nil, largeResponse, func(isTCP bool, gotRequest []byte) {
if isTCP {
t.Logf("saw TCP request")
sawTCPRequest.Store(true)
} else {
t.Logf("saw UDP request")
}
if !bytes.Equal(request, gotRequest) {
t.Errorf("invalid request\ngot: %+v\nwant: %+v", gotRequest, request)
}
})
resp := mustRunTestQuery(t, request, nil, port)
if !bytes.Equal(resp, largeResponse) {
t.Errorf("invalid response\ngot: %+v\nwant: %+v", resp, largeResponse)
}
if !sawTCPRequest.Load() {
t.Errorf("DNS server never saw TCP request")
}
// NOTE: can't assert that we see a UDP request here since we might
// race and run the TCP query first. We test the UDP codepath in
// TestForwarderTCPFallbackDisabled below, though.
}
// Test to ensure that if the UDP listener is unresponsive, we always make a
// TCP request even if we never get a response.
func TestForwarderTCPFallbackTimeout(t *testing.T) {
enableDebug(t)
const domain = "large-dns-response.tailscale.com."
// Make a response that's very large, containing a bunch of localhost addresses.
request, largeResponse := makeLargeResponse(t, domain)
var sawTCPRequest atomic.Bool
opts := &testDNSServerOptions{SkipUDP: true}
port := runDNSServer(t, opts, largeResponse, func(isTCP bool, gotRequest []byte) {
if isTCP {
t.Logf("saw TCP request")
sawTCPRequest.Store(true)
} else {
t.Error("saw unexpected UDP request")
}
if !bytes.Equal(request, gotRequest) {
t.Errorf("invalid request\ngot: %+v\nwant: %+v", gotRequest, request)
}
})
resp := mustRunTestQuery(t, request, nil, port)
if !bytes.Equal(resp, largeResponse) {
t.Errorf("invalid response\ngot: %+v\nwant: %+v", resp, largeResponse)
}
if !sawTCPRequest.Load() {
t.Errorf("DNS server never saw TCP request")
}
}
func TestForwarderTCPFallbackDisabled(t *testing.T) {
enableDebug(t)
const domain = "large-dns-response.tailscale.com."
// Make a response that's very large, containing a bunch of localhost addresses.
request, largeResponse := makeLargeResponse(t, domain)
var sawUDPRequest atomic.Bool
port := runDNSServer(t, nil, largeResponse, func(isTCP bool, gotRequest []byte) {
if isTCP {
t.Error("saw unexpected TCP request")
} else {
t.Logf("saw UDP request")
sawUDPRequest.Store(true)
}
if !bytes.Equal(request, gotRequest) {
t.Errorf("invalid request\ngot: %+v\nwant: %+v", gotRequest, request)
}
})
resp := mustRunTestQuery(t, request, func(fwd *forwarder) {
// Disable retries for this test.
fwd.controlKnobs = &controlknobs.Knobs{}
fwd.controlKnobs.DisableDNSForwarderTCPRetries.Store(true)
}, port)
wantResp := append([]byte(nil), largeResponse[:maxResponseBytes]...)
// Set the truncated flag on the expected response, since that's what we expect.
flags := binary.BigEndian.Uint16(wantResp[2:4])
flags |= dnsFlagTruncated
binary.BigEndian.PutUint16(wantResp[2:4], flags)
if !bytes.Equal(resp, wantResp) {
t.Errorf("invalid response\ngot (%d): %+v\nwant (%d): %+v", len(resp), resp, len(wantResp), wantResp)
}
if !sawUDPRequest.Load() {
t.Errorf("DNS server never saw UDP request")
}
}
// Test to ensure that we propagate DNS errors
func TestForwarderTCPFallbackError(t *testing.T) {
enableDebug(t)
const domain = "error-response.tailscale.com."
// Our response is a SERVFAIL
response := makeTestResponse(t, domain, dns.RCodeServerFailure)
// Our request is a single A query for the domain in the answer, above.
request := makeTestRequest(t, domain)
var sawRequest atomic.Bool
port := runDNSServer(t, nil, response, func(isTCP bool, gotRequest []byte) {
sawRequest.Store(true)
if !bytes.Equal(request, gotRequest) {
t.Errorf("invalid request\ngot: %+v\nwant: %+v", gotRequest, request)
}
})
resp, err := runTestQuery(t, request, nil, port)
if !sawRequest.Load() {
t.Error("did not see DNS request")
}
if err != nil {
t.Fatalf("wanted nil, got %v", err)
}
var parser dns.Parser
respHeader, err := parser.Start(resp)
if err != nil {
t.Fatalf("parser.Start() failed: %v", err)
}
if got, want := respHeader.RCode, dns.RCodeServerFailure; got != want {
t.Errorf("wanted %v, got %v", want, got)
}
}
// Test to ensure that if we have more than one resolver, and at least one of them
// returns a successful response, we propagate it.
func TestForwarderWithManyResolvers(t *testing.T) {
enableDebug(t)
const domain = "example.com."
request := makeTestRequest(t, domain)
tests := []struct {
name string
responses [][]byte // upstream responses
wantResponses [][]byte // we should receive one of these from the forwarder
}{
{
name: "Success",
responses: [][]byte{ // All upstream servers returned successful, but different, response.
makeTestResponse(t, domain, dns.RCodeSuccess, netip.MustParseAddr("127.0.0.1")),
makeTestResponse(t, domain, dns.RCodeSuccess, netip.MustParseAddr("127.0.0.2")),
makeTestResponse(t, domain, dns.RCodeSuccess, netip.MustParseAddr("127.0.0.3")),
},
wantResponses: [][]byte{ // We may forward whichever response is received first.
makeTestResponse(t, domain, dns.RCodeSuccess, netip.MustParseAddr("127.0.0.1")),
makeTestResponse(t, domain, dns.RCodeSuccess, netip.MustParseAddr("127.0.0.2")),
makeTestResponse(t, domain, dns.RCodeSuccess, netip.MustParseAddr("127.0.0.3")),
},
},
{
name: "ServFail",
responses: [][]byte{ // All upstream servers returned a SERVFAIL.
makeTestResponse(t, domain, dns.RCodeServerFailure),
makeTestResponse(t, domain, dns.RCodeServerFailure),
makeTestResponse(t, domain, dns.RCodeServerFailure),
},
wantResponses: [][]byte{
makeTestResponse(t, domain, dns.RCodeServerFailure),
},
},
{
name: "ServFail+Success",
responses: [][]byte{ // All upstream servers fail except for one.
makeTestResponse(t, domain, dns.RCodeServerFailure),
makeTestResponse(t, domain, dns.RCodeServerFailure),
makeTestResponse(t, domain, dns.RCodeSuccess, netip.MustParseAddr("127.0.0.1")),
makeTestResponse(t, domain, dns.RCodeServerFailure),
},
wantResponses: [][]byte{ // We should forward the successful response.
makeTestResponse(t, domain, dns.RCodeSuccess, netip.MustParseAddr("127.0.0.1")),
},
},
{
name: "NXDomain",
responses: [][]byte{ // All upstream servers returned NXDOMAIN.
makeTestResponse(t, domain, dns.RCodeNameError),
makeTestResponse(t, domain, dns.RCodeNameError),
makeTestResponse(t, domain, dns.RCodeNameError),
},
wantResponses: [][]byte{
makeTestResponse(t, domain, dns.RCodeNameError),
},
},
{
name: "NXDomain+Success",
responses: [][]byte{ // All upstream servers returned NXDOMAIN except for one.
makeTestResponse(t, domain, dns.RCodeNameError),
makeTestResponse(t, domain, dns.RCodeNameError),
makeTestResponse(t, domain, dns.RCodeSuccess, netip.MustParseAddr("127.0.0.1")),
},
wantResponses: [][]byte{ // However, only SERVFAIL are considered to be errors. Therefore, we may forward any response.
makeTestResponse(t, domain, dns.RCodeNameError),
makeTestResponse(t, domain, dns.RCodeSuccess, netip.MustParseAddr("127.0.0.1")),
},
},
{
name: "Refused",
responses: [][]byte{ // All upstream servers return different failures.
makeTestResponse(t, domain, dns.RCodeRefused),
makeTestResponse(t, domain, dns.RCodeRefused),
makeTestResponse(t, domain, dns.RCodeRefused),
makeTestResponse(t, domain, dns.RCodeRefused),
makeTestResponse(t, domain, dns.RCodeRefused),
makeTestResponse(t, domain, dns.RCodeSuccess, netip.MustParseAddr("127.0.0.1")),
},
wantResponses: [][]byte{ // Refused is not considered to be an error and can be forwarded.
makeTestResponse(t, domain, dns.RCodeRefused),
makeTestResponse(t, domain, dns.RCodeSuccess, netip.MustParseAddr("127.0.0.1")),
},
},
{
name: "MixFail",
responses: [][]byte{ // All upstream servers return different failures.
makeTestResponse(t, domain, dns.RCodeServerFailure),
makeTestResponse(t, domain, dns.RCodeNameError),
makeTestResponse(t, domain, dns.RCodeRefused),
},
wantResponses: [][]byte{ // Both NXDomain and Refused can be forwarded.
makeTestResponse(t, domain, dns.RCodeNameError),
makeTestResponse(t, domain, dns.RCodeRefused),
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ports := make([]uint16, len(tt.responses))
for i := range tt.responses {
ports[i] = runDNSServer(t, nil, tt.responses[i], func(isTCP bool, gotRequest []byte) {})
}
gotResponse, err := runTestQuery(t, request, nil, ports...)
if err != nil {
t.Fatalf("wanted nil, got %v", err)
}
responseOk := slices.ContainsFunc(tt.wantResponses, func(wantResponse []byte) bool {
return slices.Equal(gotResponse, wantResponse)
})
if !responseOk {
t.Errorf("invalid response\ngot: %+v\nwant: %+v", gotResponse, tt.wantResponses[0])
}
})
}
}
// mdnsResponder at minimum has an expectation that NXDOMAIN must include the
// question, otherwise it will penalize our server (#13511).
func TestNXDOMAINIncludesQuestion(t *testing.T) {
var domain = "lb._dns-sd._udp.example.org."
// Our response is a NXDOMAIN
response := func() []byte {
name := dns.MustNewName(domain)
builder := dns.NewBuilder(nil, dns.Header{
Response: true,
RCode: dns.RCodeNameError,
})
builder.StartQuestions()
builder.Question(dns.Question{
Name: name,
Type: dns.TypePTR,
Class: dns.ClassINET,
})
response, err := builder.Finish()
if err != nil {
t.Fatal(err)
}
return response
}()
// Our request is a single PTR query for the domain in the answer, above.
request := func() []byte {
builder := dns.NewBuilder(nil, dns.Header{})
builder.StartQuestions()
builder.Question(dns.Question{
Name: dns.MustNewName(domain),
Type: dns.TypePTR,
Class: dns.ClassINET,
})
request, err := builder.Finish()
if err != nil {
t.Fatal(err)
}
return request
}()
port := runDNSServer(t, nil, response, func(isTCP bool, gotRequest []byte) {
})
res, err := runTestQuery(t, request, nil, port)
if err != nil {
t.Fatal(err)
}
if !slices.Equal(res, response) {
t.Errorf("invalid response\ngot: %+v\nwant: %+v", res, response)
}
}