net/dns/resolver: forward SERVFAIL responses over PeerDNS

Cherry-pick #13691 into the release branch.

Updates #13571

Signed-off-by: Nick Khyl <nickk@tailscale.com>
This commit is contained in:
Nick Khyl 2024-10-15 10:35:15 -05:00 committed by Nick Khyl
parent f4d76fb46d
commit 78c8f7ec58
4 changed files with 210 additions and 69 deletions

View File

@ -487,6 +487,10 @@ func (f *forwarder) sendDoH(ctx context.Context, urlBase string, c *http.Client,
defer hres.Body.Close()
if hres.StatusCode != 200 {
metricDNSFwdDoHErrorStatus.Add(1)
if hres.StatusCode/100 == 5 {
// Translate 5xx HTTP server errors into SERVFAIL DNS responses.
return nil, fmt.Errorf("%w: %s", errServerFailure, hres.Status)
}
return nil, errors.New(hres.Status)
}
if ct := hres.Header.Get("Content-Type"); ct != dohType {
@ -916,10 +920,7 @@ func (f *forwarder) forwardWithDestChan(ctx context.Context, query packet, respo
metricDNSFwdDropBonjour.Add(1)
res, err := nxDomainResponse(query)
if err != nil {
f.logf("error parsing bonjour query: %v", err)
// Returning an error will cause an internal retry, there is
// nothing we can do if parsing failed. Just drop the packet.
return nil
return err
}
select {
case <-ctx.Done():
@ -951,10 +952,7 @@ func (f *forwarder) forwardWithDestChan(ctx context.Context, query packet, respo
res, err := servfailResponse(query)
if err != nil {
f.logf("building servfail response: %v", err)
// Returning an error will cause an internal retry, there is
// nothing we can do if parsing failed. Just drop the packet.
return nil
return err
}
select {
case <-ctx.Done():
@ -1053,6 +1051,7 @@ func (f *forwarder) forwardWithDestChan(ctx context.Context, query packet, respo
if verboseDNSForward() {
f.logf("forwarder response(%d, %v, %d) = %d, %v", fq.txid, typ, len(domain), len(res.bs), firstErr)
}
return nil
}
}
return firstErr

View File

@ -7,7 +7,6 @@ import (
"bytes"
"context"
"encoding/binary"
"errors"
"flag"
"fmt"
"io"
@ -450,7 +449,7 @@ func makeLargeResponse(tb testing.TB, domain string) (request, response []byte)
return
}
func runTestQuery(tb testing.TB, port uint16, request []byte, modify func(*forwarder)) ([]byte, error) {
func runTestQuery(tb testing.TB, request []byte, modify func(*forwarder), ports ...uint16) ([]byte, error) {
netMon, err := netmon.New(tb.Logf)
if err != nil {
tb.Fatal(err)
@ -464,8 +463,9 @@ func runTestQuery(tb testing.TB, port uint16, request []byte, modify func(*forwa
modify(fwd)
}
rr := resolverAndDelay{
name: &dnstype.Resolver{Addr: fmt.Sprintf("127.0.0.1:%d", port)},
resolvers := make([]resolverAndDelay, len(ports))
for i, port := range ports {
resolvers[i].name = &dnstype.Resolver{Addr: fmt.Sprintf("127.0.0.1:%d", port)}
}
rpkt := packet{
@ -477,7 +477,7 @@ func runTestQuery(tb testing.TB, port uint16, request []byte, modify func(*forwa
rchan := make(chan packet, 1)
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
tb.Cleanup(cancel)
err = fwd.forwardWithDestChan(ctx, rpkt, rchan, rr)
err = fwd.forwardWithDestChan(ctx, rpkt, rchan, resolvers...)
select {
case res := <-rchan:
return res.bs, err
@ -486,8 +486,62 @@ func runTestQuery(tb testing.TB, port uint16, request []byte, modify func(*forwa
}
}
func mustRunTestQuery(tb testing.TB, port uint16, request []byte, modify func(*forwarder)) []byte {
resp, err := runTestQuery(tb, port, request, modify)
// makeTestRequest returns a new TypeA request for the given domain.
func makeTestRequest(tb testing.TB, domain string) []byte {
tb.Helper()
name := dns.MustNewName(domain)
builder := dns.NewBuilder(nil, dns.Header{})
builder.StartQuestions()
builder.Question(dns.Question{
Name: name,
Type: dns.TypeA,
Class: dns.ClassINET,
})
request, err := builder.Finish()
if err != nil {
tb.Fatal(err)
}
return request
}
// makeTestResponse returns a new Type A response for the given domain,
// with the specified status code and zero or more addresses.
func makeTestResponse(tb testing.TB, domain string, code dns.RCode, addrs ...netip.Addr) []byte {
tb.Helper()
name := dns.MustNewName(domain)
builder := dns.NewBuilder(nil, dns.Header{
Response: true,
Authoritative: true,
RCode: code,
})
builder.StartQuestions()
q := dns.Question{
Name: name,
Type: dns.TypeA,
Class: dns.ClassINET,
}
builder.Question(q)
if len(addrs) > 0 {
builder.StartAnswers()
for _, addr := range addrs {
builder.AResource(dns.ResourceHeader{
Name: q.Name,
Class: q.Class,
TTL: 120,
}, dns.AResource{
A: addr.As4(),
})
}
}
response, err := builder.Finish()
if err != nil {
tb.Fatal(err)
}
return response
}
func mustRunTestQuery(tb testing.TB, request []byte, modify func(*forwarder), ports ...uint16) []byte {
resp, err := runTestQuery(tb, request, modify, ports...)
if err != nil {
tb.Fatalf("error making request: %v", err)
}
@ -516,7 +570,7 @@ func TestForwarderTCPFallback(t *testing.T) {
}
})
resp := mustRunTestQuery(t, port, request, nil)
resp := mustRunTestQuery(t, request, nil, port)
if !bytes.Equal(resp, largeResponse) {
t.Errorf("invalid response\ngot: %+v\nwant: %+v", resp, largeResponse)
}
@ -554,7 +608,7 @@ func TestForwarderTCPFallbackTimeout(t *testing.T) {
}
})
resp := mustRunTestQuery(t, port, request, nil)
resp := mustRunTestQuery(t, request, nil, port)
if !bytes.Equal(resp, largeResponse) {
t.Errorf("invalid response\ngot: %+v\nwant: %+v", resp, largeResponse)
}
@ -585,11 +639,11 @@ func TestForwarderTCPFallbackDisabled(t *testing.T) {
}
})
resp := mustRunTestQuery(t, port, request, func(fwd *forwarder) {
resp := mustRunTestQuery(t, request, func(fwd *forwarder) {
// Disable retries for this test.
fwd.controlKnobs = &controlknobs.Knobs{}
fwd.controlKnobs.DisableDNSForwarderTCPRetries.Store(true)
})
}, port)
wantResp := append([]byte(nil), largeResponse[:maxResponseBytes]...)
@ -613,41 +667,10 @@ func TestForwarderTCPFallbackError(t *testing.T) {
const domain = "error-response.tailscale.com."
// Our response is a SERVFAIL
response := func() []byte {
name := dns.MustNewName(domain)
builder := dns.NewBuilder(nil, dns.Header{
Response: true,
RCode: dns.RCodeServerFailure,
})
builder.StartQuestions()
builder.Question(dns.Question{
Name: name,
Type: dns.TypeA,
Class: dns.ClassINET,
})
response, err := builder.Finish()
if err != nil {
t.Fatal(err)
}
return response
}()
response := makeTestResponse(t, domain, dns.RCodeServerFailure)
// Our request is a single A query for the domain in the answer, above.
request := func() []byte {
builder := dns.NewBuilder(nil, dns.Header{})
builder.StartQuestions()
builder.Question(dns.Question{
Name: dns.MustNewName(domain),
Type: dns.TypeA,
Class: dns.ClassINET,
})
request, err := builder.Finish()
if err != nil {
t.Fatal(err)
}
return request
}()
request := makeTestRequest(t, domain)
var sawRequest atomic.Bool
port := runDNSServer(t, nil, response, func(isTCP bool, gotRequest []byte) {
@ -657,14 +680,141 @@ func TestForwarderTCPFallbackError(t *testing.T) {
}
})
_, err := runTestQuery(t, port, request, nil)
resp, err := runTestQuery(t, request, nil, port)
if !sawRequest.Load() {
t.Error("did not see DNS request")
}
if err == nil {
t.Error("wanted error, got nil")
} else if !errors.Is(err, errServerFailure) {
t.Errorf("wanted errServerFailure, got: %v", err)
if err != nil {
t.Fatalf("wanted nil, got %v", err)
}
var parser dns.Parser
respHeader, err := parser.Start(resp)
if err != nil {
t.Fatalf("parser.Start() failed: %v", err)
}
if got, want := respHeader.RCode, dns.RCodeServerFailure; got != want {
t.Errorf("wanted %v, got %v", want, got)
}
}
// Test to ensure that if we have more than one resolver, and at least one of them
// returns a successful response, we propagate it.
func TestForwarderWithManyResolvers(t *testing.T) {
enableDebug(t)
const domain = "example.com."
request := makeTestRequest(t, domain)
tests := []struct {
name string
responses [][]byte // upstream responses
wantResponses [][]byte // we should receive one of these from the forwarder
}{
{
name: "Success",
responses: [][]byte{ // All upstream servers returned successful, but different, response.
makeTestResponse(t, domain, dns.RCodeSuccess, netip.MustParseAddr("127.0.0.1")),
makeTestResponse(t, domain, dns.RCodeSuccess, netip.MustParseAddr("127.0.0.2")),
makeTestResponse(t, domain, dns.RCodeSuccess, netip.MustParseAddr("127.0.0.3")),
},
wantResponses: [][]byte{ // We may forward whichever response is received first.
makeTestResponse(t, domain, dns.RCodeSuccess, netip.MustParseAddr("127.0.0.1")),
makeTestResponse(t, domain, dns.RCodeSuccess, netip.MustParseAddr("127.0.0.2")),
makeTestResponse(t, domain, dns.RCodeSuccess, netip.MustParseAddr("127.0.0.3")),
},
},
{
name: "ServFail",
responses: [][]byte{ // All upstream servers returned a SERVFAIL.
makeTestResponse(t, domain, dns.RCodeServerFailure),
makeTestResponse(t, domain, dns.RCodeServerFailure),
makeTestResponse(t, domain, dns.RCodeServerFailure),
},
wantResponses: [][]byte{
makeTestResponse(t, domain, dns.RCodeServerFailure),
},
},
{
name: "ServFail+Success",
responses: [][]byte{ // All upstream servers fail except for one.
makeTestResponse(t, domain, dns.RCodeServerFailure),
makeTestResponse(t, domain, dns.RCodeServerFailure),
makeTestResponse(t, domain, dns.RCodeSuccess, netip.MustParseAddr("127.0.0.1")),
makeTestResponse(t, domain, dns.RCodeServerFailure),
},
wantResponses: [][]byte{ // We should forward the successful response.
makeTestResponse(t, domain, dns.RCodeSuccess, netip.MustParseAddr("127.0.0.1")),
},
},
{
name: "NXDomain",
responses: [][]byte{ // All upstream servers returned NXDOMAIN.
makeTestResponse(t, domain, dns.RCodeNameError),
makeTestResponse(t, domain, dns.RCodeNameError),
makeTestResponse(t, domain, dns.RCodeNameError),
},
wantResponses: [][]byte{
makeTestResponse(t, domain, dns.RCodeNameError),
},
},
{
name: "NXDomain+Success",
responses: [][]byte{ // All upstream servers returned NXDOMAIN except for one.
makeTestResponse(t, domain, dns.RCodeNameError),
makeTestResponse(t, domain, dns.RCodeNameError),
makeTestResponse(t, domain, dns.RCodeSuccess, netip.MustParseAddr("127.0.0.1")),
},
wantResponses: [][]byte{ // However, only SERVFAIL are considered to be errors. Therefore, we may forward any response.
makeTestResponse(t, domain, dns.RCodeNameError),
makeTestResponse(t, domain, dns.RCodeSuccess, netip.MustParseAddr("127.0.0.1")),
},
},
{
name: "Refused",
responses: [][]byte{ // All upstream servers return different failures.
makeTestResponse(t, domain, dns.RCodeRefused),
makeTestResponse(t, domain, dns.RCodeRefused),
makeTestResponse(t, domain, dns.RCodeRefused),
makeTestResponse(t, domain, dns.RCodeRefused),
makeTestResponse(t, domain, dns.RCodeRefused),
makeTestResponse(t, domain, dns.RCodeSuccess, netip.MustParseAddr("127.0.0.1")),
},
wantResponses: [][]byte{ // Refused is not considered to be an error and can be forwarded.
makeTestResponse(t, domain, dns.RCodeRefused),
makeTestResponse(t, domain, dns.RCodeSuccess, netip.MustParseAddr("127.0.0.1")),
},
},
{
name: "MixFail",
responses: [][]byte{ // All upstream servers return different failures.
makeTestResponse(t, domain, dns.RCodeServerFailure),
makeTestResponse(t, domain, dns.RCodeNameError),
makeTestResponse(t, domain, dns.RCodeRefused),
},
wantResponses: [][]byte{ // Both NXDomain and Refused can be forwarded.
makeTestResponse(t, domain, dns.RCodeNameError),
makeTestResponse(t, domain, dns.RCodeRefused),
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ports := make([]uint16, len(tt.responses))
for i := range tt.responses {
ports[i] = runDNSServer(t, nil, tt.responses[i], func(isTCP bool, gotRequest []byte) {})
}
gotResponse, err := runTestQuery(t, request, nil, ports...)
if err != nil {
t.Fatalf("wanted nil, got %v", err)
}
responseOk := slices.ContainsFunc(tt.wantResponses, func(wantResponse []byte) bool {
return slices.Equal(gotResponse, wantResponse)
})
if !responseOk {
t.Errorf("invalid response\ngot: %+v\nwant: %+v", gotResponse, tt.wantResponses[0])
}
})
}
}
@ -713,7 +863,7 @@ func TestNXDOMAINIncludesQuestion(t *testing.T) {
port := runDNSServer(t, nil, response, func(isTCP bool, gotRequest []byte) {
})
res, err := runTestQuery(t, port, request, nil)
res, err := runTestQuery(t, request, nil, port)
if err != nil {
t.Fatal(err)
}

View File

@ -321,15 +321,7 @@ func (r *Resolver) Query(ctx context.Context, bs []byte, family string, from net
defer cancel()
err = r.forwarder.forwardWithDestChan(ctx, packet{bs, family, from}, responses)
if err != nil {
select {
// Best effort: use any error response sent by forwardWithDestChan.
// This is present in some errors paths, such as when all upstream
// DNS servers replied with an error.
case resp := <-responses:
return resp.bs, err
default:
return nil, err
}
return nil, err
}
return (<-responses).bs, nil
}

View File

@ -1503,8 +1503,8 @@ func TestServfail(t *testing.T) {
r.SetConfig(cfg)
pkt, err := syncRespond(r, dnspacket("test.site.", dns.TypeA, noEdns))
if !errors.Is(err, errServerFailure) {
t.Errorf("err = %v, want %v", err, errServerFailure)
if err != nil {
t.Fatalf("err = %v, want nil", err)
}
wantPkt := []byte{