tsdns: don't forward transient DNS errors

When a DNS server claims to be unable or unwilling to handle a request,
instead of passing that refusal along to the client, just treat it as
any other error trying to connect to the DNS server. This prevents DNS
requests from failing based on if a server can respond with a transient
error before another server is able to give an actual response. DNS
requests only failing *sometimes* is really hard to find the cause of
(#1033).

Signed-off-by: Smitty <me@smitop.com>
This commit is contained in:
Smitty 2021-09-18 20:34:33 -04:00 committed by Adrian Dewhurst
parent 92215065eb
commit b382161fe5
2 changed files with 58 additions and 0 deletions

View File

@ -74,6 +74,15 @@ func getTxID(packet []byte) txid {
return txid(dnsid) return txid(dnsid)
} }
func getRCode(packet []byte) dns.RCode {
if len(packet) < headerBytes {
// treat invalid packets as a refusal
return dns.RCode(5)
}
// get bottom 4 bits of 3rd byte
return dns.RCode(packet[3] & 0x0F)
}
// clampEDNSSize attempts to limit the maximum EDNS response size. This is not // clampEDNSSize attempts to limit the maximum EDNS response size. This is not
// an exhaustive solution, instead only easy cases are currently handled in the // an exhaustive solution, instead only easy cases are currently handled in the
// interest of speed and reduced complexity. Only OPT records at the very end of // interest of speed and reduced complexity. Only OPT records at the very end of
@ -455,6 +464,12 @@ func (f *forwarder) send(ctx context.Context, fq *forwardQuery, rr resolverAndDe
if txid != fq.txid { if txid != fq.txid {
return nil, errors.New("txid doesn't match") return nil, errors.New("txid doesn't match")
} }
rcode := getRCode(out)
// don't forward transient errors back to the client when the server fails
if rcode == dns.RCodeServerFailure {
f.logf("recv: response code indicating server failure: %d", rcode)
return nil, errors.New("response code indicates server issue")
}
if truncated { if truncated {
const dnsFlagTruncated = 0x200 const dnsFlagTruncated = 0x200

View File

@ -12,6 +12,7 @@
"testing" "testing"
"time" "time"
dns "golang.org/x/net/dns/dnsmessage"
"tailscale.com/types/dnstype" "tailscale.com/types/dnstype"
) )
@ -97,3 +98,45 @@ func TestResolversWithDelays(t *testing.T) {
} }
} }
func TestGetRCode(t *testing.T) {
tests := []struct {
name string
packet []byte
want dns.RCode
}{
{
name: "empty",
packet: []byte{},
want: dns.RCode(5),
},
{
name: "too-short",
packet: []byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0},
want: dns.RCode(5),
},
{
name: "noerror",
packet: []byte{0xC4, 0xFE, 0x81, 0xA0, 0x00, 0x01, 0x00, 0x01, 0x00, 0x00, 0x00, 0x01},
want: dns.RCode(0),
},
{
name: "refused",
packet: []byte{0xee, 0xa1, 0x81, 0x05, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01},
want: dns.RCode(5),
},
{
name: "nxdomain",
packet: []byte{0x34, 0xf4, 0x81, 0x83, 0x00, 0x01, 0x00, 0x00, 0x00, 0x01, 0x00, 0x01},
want: dns.RCode(3),
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := getRCode(tt.packet)
if got != tt.want {
t.Errorf("got %d; want %d", got, tt.want)
}
})
}
}