diff --git a/net/dns/resolver/forwarder.go b/net/dns/resolver/forwarder.go index 846ca3d5e..c00dea1ae 100644 --- a/net/dns/resolver/forwarder.go +++ b/net/dns/resolver/forwarder.go @@ -487,6 +487,10 @@ func (f *forwarder) sendDoH(ctx context.Context, urlBase string, c *http.Client, defer hres.Body.Close() if hres.StatusCode != 200 { metricDNSFwdDoHErrorStatus.Add(1) + if hres.StatusCode/100 == 5 { + // Translate 5xx HTTP server errors into SERVFAIL DNS responses. + return nil, fmt.Errorf("%w: %s", errServerFailure, hres.Status) + } return nil, errors.New(hres.Status) } if ct := hres.Header.Get("Content-Type"); ct != dohType { @@ -916,10 +920,7 @@ func (f *forwarder) forwardWithDestChan(ctx context.Context, query packet, respo metricDNSFwdDropBonjour.Add(1) res, err := nxDomainResponse(query) if err != nil { - f.logf("error parsing bonjour query: %v", err) - // Returning an error will cause an internal retry, there is - // nothing we can do if parsing failed. Just drop the packet. - return nil + return err } select { case <-ctx.Done(): @@ -951,10 +952,7 @@ func (f *forwarder) forwardWithDestChan(ctx context.Context, query packet, respo res, err := servfailResponse(query) if err != nil { - f.logf("building servfail response: %v", err) - // Returning an error will cause an internal retry, there is - // nothing we can do if parsing failed. Just drop the packet. - return nil + return err } select { case <-ctx.Done(): @@ -1053,6 +1051,7 @@ func (f *forwarder) forwardWithDestChan(ctx context.Context, query packet, respo if verboseDNSForward() { f.logf("forwarder response(%d, %v, %d) = %d, %v", fq.txid, typ, len(domain), len(res.bs), firstErr) } + return nil } } return firstErr diff --git a/net/dns/resolver/forwarder_test.go b/net/dns/resolver/forwarder_test.go index 09d810901..e341186ec 100644 --- a/net/dns/resolver/forwarder_test.go +++ b/net/dns/resolver/forwarder_test.go @@ -7,7 +7,6 @@ import ( "bytes" "context" "encoding/binary" - "errors" "flag" "fmt" "io" @@ -450,7 +449,7 @@ func makeLargeResponse(tb testing.TB, domain string) (request, response []byte) return } -func runTestQuery(tb testing.TB, port uint16, request []byte, modify func(*forwarder)) ([]byte, error) { +func runTestQuery(tb testing.TB, request []byte, modify func(*forwarder), ports ...uint16) ([]byte, error) { netMon, err := netmon.New(tb.Logf) if err != nil { tb.Fatal(err) @@ -464,8 +463,9 @@ func runTestQuery(tb testing.TB, port uint16, request []byte, modify func(*forwa modify(fwd) } - rr := resolverAndDelay{ - name: &dnstype.Resolver{Addr: fmt.Sprintf("127.0.0.1:%d", port)}, + resolvers := make([]resolverAndDelay, len(ports)) + for i, port := range ports { + resolvers[i].name = &dnstype.Resolver{Addr: fmt.Sprintf("127.0.0.1:%d", port)} } rpkt := packet{ @@ -477,7 +477,7 @@ func runTestQuery(tb testing.TB, port uint16, request []byte, modify func(*forwa rchan := make(chan packet, 1) ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) tb.Cleanup(cancel) - err = fwd.forwardWithDestChan(ctx, rpkt, rchan, rr) + err = fwd.forwardWithDestChan(ctx, rpkt, rchan, resolvers...) select { case res := <-rchan: return res.bs, err @@ -486,8 +486,62 @@ func runTestQuery(tb testing.TB, port uint16, request []byte, modify func(*forwa } } -func mustRunTestQuery(tb testing.TB, port uint16, request []byte, modify func(*forwarder)) []byte { - resp, err := runTestQuery(tb, port, request, modify) +// makeTestRequest returns a new TypeA request for the given domain. +func makeTestRequest(tb testing.TB, domain string) []byte { + tb.Helper() + name := dns.MustNewName(domain) + builder := dns.NewBuilder(nil, dns.Header{}) + builder.StartQuestions() + builder.Question(dns.Question{ + Name: name, + Type: dns.TypeA, + Class: dns.ClassINET, + }) + request, err := builder.Finish() + if err != nil { + tb.Fatal(err) + } + return request +} + +// makeTestResponse returns a new Type A response for the given domain, +// with the specified status code and zero or more addresses. +func makeTestResponse(tb testing.TB, domain string, code dns.RCode, addrs ...netip.Addr) []byte { + tb.Helper() + name := dns.MustNewName(domain) + builder := dns.NewBuilder(nil, dns.Header{ + Response: true, + Authoritative: true, + RCode: code, + }) + builder.StartQuestions() + q := dns.Question{ + Name: name, + Type: dns.TypeA, + Class: dns.ClassINET, + } + builder.Question(q) + if len(addrs) > 0 { + builder.StartAnswers() + for _, addr := range addrs { + builder.AResource(dns.ResourceHeader{ + Name: q.Name, + Class: q.Class, + TTL: 120, + }, dns.AResource{ + A: addr.As4(), + }) + } + } + response, err := builder.Finish() + if err != nil { + tb.Fatal(err) + } + return response +} + +func mustRunTestQuery(tb testing.TB, request []byte, modify func(*forwarder), ports ...uint16) []byte { + resp, err := runTestQuery(tb, request, modify, ports...) if err != nil { tb.Fatalf("error making request: %v", err) } @@ -516,7 +570,7 @@ func TestForwarderTCPFallback(t *testing.T) { } }) - resp := mustRunTestQuery(t, port, request, nil) + resp := mustRunTestQuery(t, request, nil, port) if !bytes.Equal(resp, largeResponse) { t.Errorf("invalid response\ngot: %+v\nwant: %+v", resp, largeResponse) } @@ -554,7 +608,7 @@ func TestForwarderTCPFallbackTimeout(t *testing.T) { } }) - resp := mustRunTestQuery(t, port, request, nil) + resp := mustRunTestQuery(t, request, nil, port) if !bytes.Equal(resp, largeResponse) { t.Errorf("invalid response\ngot: %+v\nwant: %+v", resp, largeResponse) } @@ -585,11 +639,11 @@ func TestForwarderTCPFallbackDisabled(t *testing.T) { } }) - resp := mustRunTestQuery(t, port, request, func(fwd *forwarder) { + resp := mustRunTestQuery(t, request, func(fwd *forwarder) { // Disable retries for this test. fwd.controlKnobs = &controlknobs.Knobs{} fwd.controlKnobs.DisableDNSForwarderTCPRetries.Store(true) - }) + }, port) wantResp := append([]byte(nil), largeResponse[:maxResponseBytes]...) @@ -613,41 +667,10 @@ func TestForwarderTCPFallbackError(t *testing.T) { const domain = "error-response.tailscale.com." // Our response is a SERVFAIL - response := func() []byte { - name := dns.MustNewName(domain) - - builder := dns.NewBuilder(nil, dns.Header{ - Response: true, - RCode: dns.RCodeServerFailure, - }) - builder.StartQuestions() - builder.Question(dns.Question{ - Name: name, - Type: dns.TypeA, - Class: dns.ClassINET, - }) - response, err := builder.Finish() - if err != nil { - t.Fatal(err) - } - return response - }() + response := makeTestResponse(t, domain, dns.RCodeServerFailure) // Our request is a single A query for the domain in the answer, above. - request := func() []byte { - builder := dns.NewBuilder(nil, dns.Header{}) - builder.StartQuestions() - builder.Question(dns.Question{ - Name: dns.MustNewName(domain), - Type: dns.TypeA, - Class: dns.ClassINET, - }) - request, err := builder.Finish() - if err != nil { - t.Fatal(err) - } - return request - }() + request := makeTestRequest(t, domain) var sawRequest atomic.Bool port := runDNSServer(t, nil, response, func(isTCP bool, gotRequest []byte) { @@ -657,14 +680,141 @@ func TestForwarderTCPFallbackError(t *testing.T) { } }) - _, err := runTestQuery(t, port, request, nil) + resp, err := runTestQuery(t, request, nil, port) if !sawRequest.Load() { t.Error("did not see DNS request") } - if err == nil { - t.Error("wanted error, got nil") - } else if !errors.Is(err, errServerFailure) { - t.Errorf("wanted errServerFailure, got: %v", err) + if err != nil { + t.Fatalf("wanted nil, got %v", err) + } + var parser dns.Parser + respHeader, err := parser.Start(resp) + if err != nil { + t.Fatalf("parser.Start() failed: %v", err) + } + if got, want := respHeader.RCode, dns.RCodeServerFailure; got != want { + t.Errorf("wanted %v, got %v", want, got) + } +} + +// Test to ensure that if we have more than one resolver, and at least one of them +// returns a successful response, we propagate it. +func TestForwarderWithManyResolvers(t *testing.T) { + enableDebug(t) + + const domain = "example.com." + request := makeTestRequest(t, domain) + + tests := []struct { + name string + responses [][]byte // upstream responses + wantResponses [][]byte // we should receive one of these from the forwarder + }{ + { + name: "Success", + responses: [][]byte{ // All upstream servers returned successful, but different, response. + makeTestResponse(t, domain, dns.RCodeSuccess, netip.MustParseAddr("127.0.0.1")), + makeTestResponse(t, domain, dns.RCodeSuccess, netip.MustParseAddr("127.0.0.2")), + makeTestResponse(t, domain, dns.RCodeSuccess, netip.MustParseAddr("127.0.0.3")), + }, + wantResponses: [][]byte{ // We may forward whichever response is received first. + makeTestResponse(t, domain, dns.RCodeSuccess, netip.MustParseAddr("127.0.0.1")), + makeTestResponse(t, domain, dns.RCodeSuccess, netip.MustParseAddr("127.0.0.2")), + makeTestResponse(t, domain, dns.RCodeSuccess, netip.MustParseAddr("127.0.0.3")), + }, + }, + { + name: "ServFail", + responses: [][]byte{ // All upstream servers returned a SERVFAIL. + makeTestResponse(t, domain, dns.RCodeServerFailure), + makeTestResponse(t, domain, dns.RCodeServerFailure), + makeTestResponse(t, domain, dns.RCodeServerFailure), + }, + wantResponses: [][]byte{ + makeTestResponse(t, domain, dns.RCodeServerFailure), + }, + }, + { + name: "ServFail+Success", + responses: [][]byte{ // All upstream servers fail except for one. + makeTestResponse(t, domain, dns.RCodeServerFailure), + makeTestResponse(t, domain, dns.RCodeServerFailure), + makeTestResponse(t, domain, dns.RCodeSuccess, netip.MustParseAddr("127.0.0.1")), + makeTestResponse(t, domain, dns.RCodeServerFailure), + }, + wantResponses: [][]byte{ // We should forward the successful response. + makeTestResponse(t, domain, dns.RCodeSuccess, netip.MustParseAddr("127.0.0.1")), + }, + }, + { + name: "NXDomain", + responses: [][]byte{ // All upstream servers returned NXDOMAIN. + makeTestResponse(t, domain, dns.RCodeNameError), + makeTestResponse(t, domain, dns.RCodeNameError), + makeTestResponse(t, domain, dns.RCodeNameError), + }, + wantResponses: [][]byte{ + makeTestResponse(t, domain, dns.RCodeNameError), + }, + }, + { + name: "NXDomain+Success", + responses: [][]byte{ // All upstream servers returned NXDOMAIN except for one. + makeTestResponse(t, domain, dns.RCodeNameError), + makeTestResponse(t, domain, dns.RCodeNameError), + makeTestResponse(t, domain, dns.RCodeSuccess, netip.MustParseAddr("127.0.0.1")), + }, + wantResponses: [][]byte{ // However, only SERVFAIL are considered to be errors. Therefore, we may forward any response. + makeTestResponse(t, domain, dns.RCodeNameError), + makeTestResponse(t, domain, dns.RCodeSuccess, netip.MustParseAddr("127.0.0.1")), + }, + }, + { + name: "Refused", + responses: [][]byte{ // All upstream servers return different failures. + makeTestResponse(t, domain, dns.RCodeRefused), + makeTestResponse(t, domain, dns.RCodeRefused), + makeTestResponse(t, domain, dns.RCodeRefused), + makeTestResponse(t, domain, dns.RCodeRefused), + makeTestResponse(t, domain, dns.RCodeRefused), + makeTestResponse(t, domain, dns.RCodeSuccess, netip.MustParseAddr("127.0.0.1")), + }, + wantResponses: [][]byte{ // Refused is not considered to be an error and can be forwarded. + makeTestResponse(t, domain, dns.RCodeRefused), + makeTestResponse(t, domain, dns.RCodeSuccess, netip.MustParseAddr("127.0.0.1")), + }, + }, + { + name: "MixFail", + responses: [][]byte{ // All upstream servers return different failures. + makeTestResponse(t, domain, dns.RCodeServerFailure), + makeTestResponse(t, domain, dns.RCodeNameError), + makeTestResponse(t, domain, dns.RCodeRefused), + }, + wantResponses: [][]byte{ // Both NXDomain and Refused can be forwarded. + makeTestResponse(t, domain, dns.RCodeNameError), + makeTestResponse(t, domain, dns.RCodeRefused), + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ports := make([]uint16, len(tt.responses)) + for i := range tt.responses { + ports[i] = runDNSServer(t, nil, tt.responses[i], func(isTCP bool, gotRequest []byte) {}) + } + gotResponse, err := runTestQuery(t, request, nil, ports...) + if err != nil { + t.Fatalf("wanted nil, got %v", err) + } + responseOk := slices.ContainsFunc(tt.wantResponses, func(wantResponse []byte) bool { + return slices.Equal(gotResponse, wantResponse) + }) + if !responseOk { + t.Errorf("invalid response\ngot: %+v\nwant: %+v", gotResponse, tt.wantResponses[0]) + } + }) } } @@ -713,7 +863,7 @@ func TestNXDOMAINIncludesQuestion(t *testing.T) { port := runDNSServer(t, nil, response, func(isTCP bool, gotRequest []byte) { }) - res, err := runTestQuery(t, port, request, nil) + res, err := runTestQuery(t, request, nil, port) if err != nil { t.Fatal(err) } diff --git a/net/dns/resolver/tsdns.go b/net/dns/resolver/tsdns.go index d196ad4d6..43ba0acf1 100644 --- a/net/dns/resolver/tsdns.go +++ b/net/dns/resolver/tsdns.go @@ -321,15 +321,7 @@ func (r *Resolver) Query(ctx context.Context, bs []byte, family string, from net defer cancel() err = r.forwarder.forwardWithDestChan(ctx, packet{bs, family, from}, responses) if err != nil { - select { - // Best effort: use any error response sent by forwardWithDestChan. - // This is present in some errors paths, such as when all upstream - // DNS servers replied with an error. - case resp := <-responses: - return resp.bs, err - default: - return nil, err - } + return nil, err } return (<-responses).bs, nil } diff --git a/net/dns/resolver/tsdns_test.go b/net/dns/resolver/tsdns_test.go index e2c4750b5..d7b9fb360 100644 --- a/net/dns/resolver/tsdns_test.go +++ b/net/dns/resolver/tsdns_test.go @@ -1503,8 +1503,8 @@ func TestServfail(t *testing.T) { r.SetConfig(cfg) pkt, err := syncRespond(r, dnspacket("test.site.", dns.TypeA, noEdns)) - if !errors.Is(err, errServerFailure) { - t.Errorf("err = %v, want %v", err, errServerFailure) + if err != nil { + t.Fatalf("err = %v, want nil", err) } wantPkt := []byte{