2023-01-27 13:37:20 -08:00
|
|
|
// Copyright (c) Tailscale Inc & AUTHORS
|
|
|
|
// SPDX-License-Identifier: BSD-3-Clause
|
2021-08-03 08:31:20 -07:00
|
|
|
|
|
|
|
package resolver
|
|
|
|
|
|
|
|
import (
|
2023-09-07 16:27:50 -04:00
|
|
|
"bytes"
|
|
|
|
"context"
|
|
|
|
"encoding/binary"
|
2023-10-03 16:26:38 -04:00
|
|
|
"errors"
|
2022-04-18 12:50:26 -07:00
|
|
|
"flag"
|
2021-08-03 08:31:20 -07:00
|
|
|
"fmt"
|
2023-09-07 16:27:50 -04:00
|
|
|
"io"
|
|
|
|
"net"
|
|
|
|
"net/netip"
|
|
|
|
"os"
|
2021-08-03 08:31:20 -07:00
|
|
|
"reflect"
|
|
|
|
"strings"
|
2023-09-07 16:27:50 -04:00
|
|
|
"sync"
|
|
|
|
"sync/atomic"
|
2021-08-03 08:31:20 -07:00
|
|
|
"testing"
|
|
|
|
"time"
|
|
|
|
|
2021-09-18 20:34:33 -04:00
|
|
|
dns "golang.org/x/net/dns/dnsmessage"
|
2023-10-03 16:26:38 -04:00
|
|
|
"tailscale.com/control/controlknobs"
|
2023-09-07 16:27:50 -04:00
|
|
|
"tailscale.com/envknob"
|
2024-07-29 13:48:46 -04:00
|
|
|
"tailscale.com/health"
|
2023-09-07 16:27:50 -04:00
|
|
|
"tailscale.com/net/netmon"
|
|
|
|
"tailscale.com/net/tsdial"
|
2021-08-03 06:56:31 -07:00
|
|
|
"tailscale.com/types/dnstype"
|
2021-08-03 08:31:20 -07:00
|
|
|
)
|
|
|
|
|
|
|
|
func (rr resolverAndDelay) String() string {
|
2021-08-03 06:56:31 -07:00
|
|
|
return fmt.Sprintf("%v+%v", rr.name, rr.startDelay)
|
2021-08-03 08:31:20 -07:00
|
|
|
}
|
|
|
|
|
|
|
|
func TestResolversWithDelays(t *testing.T) {
|
|
|
|
// query
|
2022-05-03 14:41:58 -07:00
|
|
|
q := func(ss ...string) (ipps []*dnstype.Resolver) {
|
2022-04-18 21:58:00 -07:00
|
|
|
for _, host := range ss {
|
2022-05-03 14:41:58 -07:00
|
|
|
ipps = append(ipps, &dnstype.Resolver{Addr: host})
|
2021-08-03 08:31:20 -07:00
|
|
|
}
|
|
|
|
return
|
|
|
|
}
|
|
|
|
// output
|
|
|
|
o := func(ss ...string) (rr []resolverAndDelay) {
|
|
|
|
for _, s := range ss {
|
|
|
|
var d time.Duration
|
2022-03-19 12:42:46 -07:00
|
|
|
s, durStr, hasPlus := strings.Cut(s, "+")
|
|
|
|
if hasPlus {
|
2021-08-03 08:31:20 -07:00
|
|
|
var err error
|
2022-03-19 12:42:46 -07:00
|
|
|
d, err = time.ParseDuration(durStr)
|
2021-08-03 08:31:20 -07:00
|
|
|
if err != nil {
|
|
|
|
panic(fmt.Sprintf("parsing duration in %q: %v", s, err))
|
|
|
|
}
|
|
|
|
}
|
|
|
|
rr = append(rr, resolverAndDelay{
|
2022-05-03 14:41:58 -07:00
|
|
|
name: &dnstype.Resolver{Addr: s},
|
2021-08-03 08:31:20 -07:00
|
|
|
startDelay: d,
|
|
|
|
})
|
|
|
|
}
|
|
|
|
return
|
|
|
|
}
|
|
|
|
|
|
|
|
tests := []struct {
|
|
|
|
name string
|
2022-05-03 14:41:58 -07:00
|
|
|
in []*dnstype.Resolver
|
2021-08-03 08:31:20 -07:00
|
|
|
want []resolverAndDelay
|
|
|
|
}{
|
|
|
|
{
|
|
|
|
name: "unknown-no-delays",
|
2022-04-18 21:58:00 -07:00
|
|
|
in: q("1.2.3.4", "2.3.4.5"),
|
|
|
|
want: o("1.2.3.4", "2.3.4.5"),
|
2021-08-03 08:31:20 -07:00
|
|
|
},
|
|
|
|
{
|
|
|
|
name: "google-all-ipv4",
|
2022-04-18 21:58:00 -07:00
|
|
|
in: q("8.8.8.8", "8.8.4.4"),
|
|
|
|
want: o("https://dns.google/dns-query", "8.8.8.8+0.5s", "8.8.4.4+0.7s"),
|
2021-08-03 08:31:20 -07:00
|
|
|
},
|
|
|
|
{
|
|
|
|
name: "google-only-ipv6",
|
2022-04-18 21:58:00 -07:00
|
|
|
in: q("2001:4860:4860::8888", "2001:4860:4860::8844"),
|
|
|
|
want: o("https://dns.google/dns-query", "2001:4860:4860::8888+0.5s", "2001:4860:4860::8844+0.7s"),
|
2021-08-03 08:31:20 -07:00
|
|
|
},
|
|
|
|
{
|
|
|
|
name: "google-all-four",
|
2022-04-18 21:58:00 -07:00
|
|
|
in: q("8.8.8.8", "8.8.4.4", "2001:4860:4860::8888", "2001:4860:4860::8844"),
|
|
|
|
want: o("https://dns.google/dns-query", "8.8.8.8+0.5s", "8.8.4.4+0.7s", "2001:4860:4860::8888+0.5s", "2001:4860:4860::8844+0.7s"),
|
2021-08-03 08:31:20 -07:00
|
|
|
},
|
|
|
|
{
|
|
|
|
name: "quad9-one-v4-one-v6",
|
2022-04-18 21:58:00 -07:00
|
|
|
in: q("9.9.9.9", "2620:fe::fe"),
|
|
|
|
want: o("https://dns.quad9.net/dns-query", "9.9.9.9+0.5s", "2620:fe::fe+0.5s"),
|
2022-09-06 11:15:30 -07:00
|
|
|
},
|
|
|
|
{
|
|
|
|
name: "nextdns-ipv6-expand",
|
|
|
|
in: q("2a07:a8c0::c3:a884"),
|
|
|
|
want: o("https://dns.nextdns.io/c3a884"),
|
2022-09-08 15:54:29 -07:00
|
|
|
},
|
|
|
|
{
|
|
|
|
name: "nextdns-doh-input",
|
|
|
|
in: q("https://dns.nextdns.io/c3a884"),
|
|
|
|
want: o("https://dns.nextdns.io/c3a884"),
|
2021-08-03 08:31:20 -07:00
|
|
|
},
|
2023-06-22 20:47:36 -04:00
|
|
|
{
|
|
|
|
name: "controld-ipv6-expand",
|
|
|
|
in: q("2606:1a40:0:6:7b5b:5949:35ad:0"),
|
|
|
|
want: o("https://dns.controld.com/hyq3ipr2ct"),
|
|
|
|
},
|
|
|
|
{
|
|
|
|
name: "controld-doh-input",
|
|
|
|
in: q("https://dns.controld.com/hyq3ipr2ct"),
|
|
|
|
want: o("https://dns.controld.com/hyq3ipr2ct"),
|
|
|
|
},
|
2021-08-03 08:31:20 -07:00
|
|
|
}
|
|
|
|
|
|
|
|
for _, tt := range tests {
|
|
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
|
|
got := resolversWithDelays(tt.in)
|
|
|
|
if !reflect.DeepEqual(got, tt.want) {
|
|
|
|
t.Errorf("got %v; want %v", got, tt.want)
|
|
|
|
}
|
|
|
|
})
|
|
|
|
}
|
|
|
|
|
|
|
|
}
|
2021-09-18 20:34:33 -04:00
|
|
|
|
|
|
|
func TestGetRCode(t *testing.T) {
|
|
|
|
tests := []struct {
|
|
|
|
name string
|
|
|
|
packet []byte
|
|
|
|
want dns.RCode
|
|
|
|
}{
|
|
|
|
{
|
|
|
|
name: "empty",
|
|
|
|
packet: []byte{},
|
|
|
|
want: dns.RCode(5),
|
|
|
|
},
|
|
|
|
{
|
|
|
|
name: "too-short",
|
|
|
|
packet: []byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0},
|
|
|
|
want: dns.RCode(5),
|
|
|
|
},
|
|
|
|
{
|
|
|
|
name: "noerror",
|
|
|
|
packet: []byte{0xC4, 0xFE, 0x81, 0xA0, 0x00, 0x01, 0x00, 0x01, 0x00, 0x00, 0x00, 0x01},
|
|
|
|
want: dns.RCode(0),
|
|
|
|
},
|
|
|
|
{
|
|
|
|
name: "refused",
|
|
|
|
packet: []byte{0xee, 0xa1, 0x81, 0x05, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01},
|
|
|
|
want: dns.RCode(5),
|
|
|
|
},
|
|
|
|
{
|
|
|
|
name: "nxdomain",
|
|
|
|
packet: []byte{0x34, 0xf4, 0x81, 0x83, 0x00, 0x01, 0x00, 0x00, 0x00, 0x01, 0x00, 0x01},
|
|
|
|
want: dns.RCode(3),
|
|
|
|
},
|
|
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
|
|
got := getRCode(tt.packet)
|
|
|
|
if got != tt.want {
|
|
|
|
t.Errorf("got %d; want %d", got, tt.want)
|
|
|
|
}
|
|
|
|
})
|
|
|
|
}
|
|
|
|
}
|
2021-10-14 22:39:11 -04:00
|
|
|
|
2022-04-18 12:50:26 -07:00
|
|
|
// testDNS gates tests that need network access to a real, working DNS server
// (for example, the HEAD request to dns.google in
// TestGetKnownDoHClientForProvider). Such tests skip unless --test-dns is set.
var testDNS = flag.Bool("test-dns", false, "run tests that require a working DNS server")
|
|
|
|
|
|
|
|
func TestGetKnownDoHClientForProvider(t *testing.T) {
|
|
|
|
var fwd forwarder
|
|
|
|
c, ok := fwd.getKnownDoHClientForProvider("https://dns.google/dns-query")
|
|
|
|
if !ok {
|
|
|
|
t.Fatal("not found")
|
|
|
|
}
|
|
|
|
if !*testDNS {
|
|
|
|
t.Skip("skipping without --test-dns")
|
|
|
|
}
|
|
|
|
res, err := c.Head("https://dns.google/")
|
|
|
|
if err != nil {
|
|
|
|
t.Fatal(err)
|
|
|
|
}
|
|
|
|
defer res.Body.Close()
|
|
|
|
t.Logf("Got: %+v", res)
|
|
|
|
}
|
|
|
|
|
2021-12-18 19:03:38 -08:00
|
|
|
func BenchmarkNameFromQuery(b *testing.B) {
|
|
|
|
builder := dns.NewBuilder(nil, dns.Header{})
|
|
|
|
builder.StartQuestions()
|
|
|
|
builder.Question(dns.Question{
|
|
|
|
Name: dns.MustNewName("foo.example."),
|
|
|
|
Type: dns.TypeA,
|
|
|
|
Class: dns.ClassINET,
|
|
|
|
})
|
|
|
|
msg, err := builder.Finish()
|
|
|
|
if err != nil {
|
|
|
|
b.Fatal(err)
|
|
|
|
}
|
|
|
|
b.ResetTimer()
|
|
|
|
b.ReportAllocs()
|
2024-04-16 13:15:13 -07:00
|
|
|
for range b.N {
|
2024-08-08 15:41:08 -05:00
|
|
|
_, _, err := nameFromQuery(msg)
|
2021-12-18 19:03:38 -08:00
|
|
|
if err != nil {
|
|
|
|
b.Fatal(err)
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
2022-07-08 14:22:50 +01:00
|
|
|
|
|
|
|
// Reproduces https://github.com/tailscale/tailscale/issues/2533
|
|
|
|
// Fixed by https://github.com/tailscale/tailscale/commit/f414a9cc01f3264912513d07c0244ff4f3e4ba54
|
|
|
|
//
|
|
|
|
// NOTE: fuzz tests act like unit tests when run without `-fuzz`
|
|
|
|
func FuzzClampEDNSSize(f *testing.F) {
|
|
|
|
// Empty DNS packet
|
|
|
|
f.Add([]byte{
|
|
|
|
// query id
|
|
|
|
0x12, 0x34,
|
|
|
|
// flags: standard query, recurse
|
|
|
|
0x01, 0x20,
|
|
|
|
// num questions
|
|
|
|
0x00, 0x00,
|
|
|
|
// num answers
|
|
|
|
0x00, 0x00,
|
|
|
|
// num authority RRs
|
|
|
|
0x00, 0x00,
|
|
|
|
// num additional RRs
|
|
|
|
0x00, 0x00,
|
|
|
|
})
|
|
|
|
|
|
|
|
// Empty OPT
|
|
|
|
f.Add([]byte{
|
|
|
|
// header
|
|
|
|
0xaf, 0x66, 0x01, 0x20, 0x00, 0x01, 0x00, 0x00,
|
|
|
|
0x00, 0x00, 0x00, 0x01,
|
|
|
|
// query
|
|
|
|
0x06, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x03, 0x63, 0x6f,
|
|
|
|
0x6d, 0x00, 0x00, 0x01, 0x00, 0x01,
|
|
|
|
// OPT
|
|
|
|
0x00, // name: <root>
|
|
|
|
0x00, 0x29, // type: OPT
|
|
|
|
0x10, 0x00, // UDP payload size
|
|
|
|
0x00, // higher bits in extended RCODE
|
|
|
|
0x00, // EDNS0 version
|
|
|
|
0x80, 0x00, // "Z" field
|
|
|
|
0x00, 0x00, // data length
|
|
|
|
})
|
|
|
|
|
|
|
|
// Query for "google.com"
|
|
|
|
f.Add([]byte{
|
|
|
|
// header
|
|
|
|
0xaf, 0x66, 0x01, 0x20, 0x00, 0x01, 0x00, 0x00,
|
|
|
|
0x00, 0x00, 0x00, 0x01,
|
|
|
|
// query
|
|
|
|
0x06, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x03, 0x63, 0x6f,
|
|
|
|
0x6d, 0x00, 0x00, 0x01, 0x00, 0x01,
|
|
|
|
// OPT
|
|
|
|
0x00, 0x00, 0x29, 0x10, 0x00, 0x00, 0x00, 0x80, 0x00, 0x00,
|
|
|
|
0x0c, 0x00, 0x0a, 0x00, 0x08, 0x62, 0x18, 0x1a, 0xcb, 0x19,
|
|
|
|
0xd7, 0xee, 0x23,
|
|
|
|
})
|
|
|
|
|
|
|
|
f.Fuzz(func(t *testing.T, data []byte) {
|
|
|
|
clampEDNSSize(data, maxResponseBytes)
|
|
|
|
})
|
|
|
|
}
|
2023-09-07 16:27:50 -04:00
|
|
|
|
2023-10-03 16:26:38 -04:00
|
|
|
// testDNSServerOptions tweaks the fake server started by runDNSServer. A nil
// pointer or the zero value serves on both UDP and TCP.
type testDNSServerOptions struct {
	// SkipUDP and SkipTCP disable serving on the respective protocol. The
	// socket is still bound, so the port stays reserved, but nothing reads
	// from or accepts on it. Setting both is a test setup error.
	SkipUDP, SkipTCP bool
}
|
|
|
|
|
|
|
|
func runDNSServer(tb testing.TB, opts *testDNSServerOptions, response []byte, onRequest func(bool, []byte)) (port uint16) {
|
|
|
|
if opts != nil && opts.SkipUDP && opts.SkipTCP {
|
|
|
|
tb.Fatal("cannot skip both UDP and TCP servers")
|
|
|
|
}
|
|
|
|
|
2023-09-07 16:27:50 -04:00
|
|
|
tcpResponse := make([]byte, len(response)+2)
|
|
|
|
binary.BigEndian.PutUint16(tcpResponse, uint16(len(response)))
|
|
|
|
copy(tcpResponse[2:], response)
|
|
|
|
|
|
|
|
// Repeatedly listen until we can get the same port.
|
|
|
|
const tries = 25
|
|
|
|
var (
|
|
|
|
tcpLn *net.TCPListener
|
|
|
|
udpLn *net.UDPConn
|
|
|
|
err error
|
|
|
|
)
|
|
|
|
for try := 0; try < tries; try++ {
|
|
|
|
if tcpLn != nil {
|
|
|
|
tcpLn.Close()
|
|
|
|
tcpLn = nil
|
|
|
|
}
|
|
|
|
|
|
|
|
tcpLn, err = net.ListenTCP("tcp4", &net.TCPAddr{
|
|
|
|
IP: net.IPv4(127, 0, 0, 1),
|
|
|
|
Port: 0, // Choose one
|
|
|
|
})
|
|
|
|
if err != nil {
|
|
|
|
tb.Fatal(err)
|
|
|
|
}
|
|
|
|
udpLn, err = net.ListenUDP("udp4", &net.UDPAddr{
|
|
|
|
IP: net.IPv4(127, 0, 0, 1),
|
|
|
|
Port: tcpLn.Addr().(*net.TCPAddr).Port,
|
|
|
|
})
|
|
|
|
if err == nil {
|
|
|
|
break
|
|
|
|
}
|
|
|
|
}
|
|
|
|
if tcpLn == nil || udpLn == nil {
|
|
|
|
if tcpLn != nil {
|
|
|
|
tcpLn.Close()
|
|
|
|
}
|
|
|
|
if udpLn != nil {
|
|
|
|
udpLn.Close()
|
|
|
|
}
|
|
|
|
|
|
|
|
// Skip instead of being fatal to avoid flaking on extremely
|
|
|
|
// heavily-loaded CI systems.
|
|
|
|
tb.Skipf("failed to listen on same port for TCP/UDP after %d tries", tries)
|
|
|
|
}
|
|
|
|
|
|
|
|
port = uint16(tcpLn.Addr().(*net.TCPAddr).Port)
|
|
|
|
|
|
|
|
handleConn := func(conn net.Conn) {
|
|
|
|
defer conn.Close()
|
|
|
|
|
|
|
|
// Read the length header, then the buffer
|
|
|
|
var length uint16
|
|
|
|
if err := binary.Read(conn, binary.BigEndian, &length); err != nil {
|
|
|
|
tb.Logf("error reading length header: %v", err)
|
|
|
|
return
|
|
|
|
}
|
|
|
|
req := make([]byte, length)
|
|
|
|
n, err := io.ReadFull(conn, req)
|
|
|
|
if err != nil {
|
|
|
|
tb.Logf("error reading query: %v", err)
|
|
|
|
return
|
|
|
|
}
|
|
|
|
req = req[:n]
|
|
|
|
onRequest(true, req)
|
|
|
|
|
|
|
|
// Write response
|
|
|
|
if _, err := conn.Write(tcpResponse); err != nil {
|
|
|
|
tb.Logf("error writing response: %v", err)
|
|
|
|
return
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
var wg sync.WaitGroup
|
2023-10-03 16:26:38 -04:00
|
|
|
|
|
|
|
if opts == nil || !opts.SkipTCP {
|
|
|
|
wg.Add(1)
|
|
|
|
go func() {
|
|
|
|
defer wg.Done()
|
|
|
|
for {
|
|
|
|
conn, err := tcpLn.Accept()
|
|
|
|
if err != nil {
|
|
|
|
return
|
|
|
|
}
|
|
|
|
go handleConn(conn)
|
2023-09-07 16:27:50 -04:00
|
|
|
}
|
2023-10-03 16:26:38 -04:00
|
|
|
}()
|
|
|
|
}
|
2023-09-07 16:27:50 -04:00
|
|
|
|
|
|
|
handleUDP := func(addr netip.AddrPort, req []byte) {
|
|
|
|
onRequest(false, req)
|
|
|
|
if _, err := udpLn.WriteToUDPAddrPort(response, addr); err != nil {
|
|
|
|
tb.Logf("error writing response: %v", err)
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
2023-10-03 16:26:38 -04:00
|
|
|
if opts == nil || !opts.SkipUDP {
|
|
|
|
wg.Add(1)
|
|
|
|
go func() {
|
|
|
|
defer wg.Done()
|
|
|
|
for {
|
|
|
|
buf := make([]byte, 65535)
|
|
|
|
n, addr, err := udpLn.ReadFromUDPAddrPort(buf)
|
|
|
|
if err != nil {
|
|
|
|
return
|
|
|
|
}
|
|
|
|
buf = buf[:n]
|
|
|
|
go handleUDP(addr, buf)
|
2023-09-07 16:27:50 -04:00
|
|
|
}
|
2023-10-03 16:26:38 -04:00
|
|
|
}()
|
|
|
|
}
|
2023-09-07 16:27:50 -04:00
|
|
|
|
|
|
|
tb.Cleanup(func() {
|
|
|
|
tcpLn.Close()
|
|
|
|
udpLn.Close()
|
|
|
|
tb.Logf("waiting for listeners to finish...")
|
|
|
|
wg.Wait()
|
|
|
|
})
|
|
|
|
return
|
|
|
|
}
|
|
|
|
|
2023-10-03 16:26:38 -04:00
|
|
|
func enableDebug(tb testing.TB) {
|
2023-09-07 16:27:50 -04:00
|
|
|
const debugKnob = "TS_DEBUG_DNS_FORWARD_SEND"
|
|
|
|
oldVal := os.Getenv(debugKnob)
|
|
|
|
envknob.Setenv(debugKnob, "true")
|
2023-10-03 16:26:38 -04:00
|
|
|
tb.Cleanup(func() { envknob.Setenv(debugKnob, oldVal) })
|
|
|
|
}
|
2023-09-07 16:27:50 -04:00
|
|
|
|
2023-10-03 16:26:38 -04:00
|
|
|
func makeLargeResponse(tb testing.TB, domain string) (request, response []byte) {
|
|
|
|
name := dns.MustNewName(domain)
|
2023-09-07 16:27:50 -04:00
|
|
|
|
2023-10-03 16:26:38 -04:00
|
|
|
builder := dns.NewBuilder(nil, dns.Header{})
|
|
|
|
builder.StartQuestions()
|
|
|
|
builder.Question(dns.Question{
|
|
|
|
Name: name,
|
|
|
|
Type: dns.TypeA,
|
|
|
|
Class: dns.ClassINET,
|
|
|
|
})
|
|
|
|
builder.StartAnswers()
|
2024-04-16 13:15:13 -07:00
|
|
|
for i := range 120 {
|
2023-10-03 16:26:38 -04:00
|
|
|
builder.AResource(dns.ResourceHeader{
|
2023-09-07 16:27:50 -04:00
|
|
|
Name: name,
|
|
|
|
Class: dns.ClassINET,
|
2023-10-03 16:26:38 -04:00
|
|
|
TTL: 300,
|
|
|
|
}, dns.AResource{
|
|
|
|
A: [4]byte{127, 0, 0, byte(i)},
|
2023-09-07 16:27:50 -04:00
|
|
|
})
|
2023-10-03 16:26:38 -04:00
|
|
|
}
|
2023-09-07 16:27:50 -04:00
|
|
|
|
2023-10-03 16:26:38 -04:00
|
|
|
var err error
|
|
|
|
response, err = builder.Finish()
|
|
|
|
if err != nil {
|
|
|
|
tb.Fatal(err)
|
|
|
|
}
|
|
|
|
if len(response) <= maxResponseBytes {
|
|
|
|
tb.Fatalf("got len(largeResponse)=%d, want > %d", len(response), maxResponseBytes)
|
2023-09-07 16:27:50 -04:00
|
|
|
}
|
|
|
|
|
|
|
|
// Our request is a single A query for the domain in the answer, above.
|
2023-10-03 16:26:38 -04:00
|
|
|
builder = dns.NewBuilder(nil, dns.Header{})
|
|
|
|
builder.StartQuestions()
|
|
|
|
builder.Question(dns.Question{
|
|
|
|
Name: dns.MustNewName(domain),
|
|
|
|
Type: dns.TypeA,
|
|
|
|
Class: dns.ClassINET,
|
2023-09-07 16:27:50 -04:00
|
|
|
})
|
2023-10-03 16:26:38 -04:00
|
|
|
request, err = builder.Finish()
|
|
|
|
if err != nil {
|
|
|
|
tb.Fatal(err)
|
|
|
|
}
|
2023-09-07 16:27:50 -04:00
|
|
|
|
2023-10-03 16:26:38 -04:00
|
|
|
return
|
|
|
|
}
|
|
|
|
|
|
|
|
func runTestQuery(tb testing.TB, port uint16, request []byte, modify func(*forwarder)) ([]byte, error) {
|
|
|
|
netMon, err := netmon.New(tb.Logf)
|
2023-09-07 16:27:50 -04:00
|
|
|
if err != nil {
|
2023-10-03 16:26:38 -04:00
|
|
|
tb.Fatal(err)
|
2023-09-07 16:27:50 -04:00
|
|
|
}
|
|
|
|
|
|
|
|
var dialer tsdial.Dialer
|
|
|
|
dialer.SetNetMon(netMon)
|
|
|
|
|
2024-07-29 13:48:46 -04:00
|
|
|
fwd := newForwarder(tb.Logf, netMon, nil, &dialer, new(health.Tracker), nil)
|
2023-10-03 16:26:38 -04:00
|
|
|
if modify != nil {
|
|
|
|
modify(fwd)
|
|
|
|
}
|
2023-09-07 16:27:50 -04:00
|
|
|
|
|
|
|
fq := &forwardQuery{
|
|
|
|
txid: getTxID(request),
|
|
|
|
packet: request,
|
|
|
|
closeOnCtxDone: new(closePool),
|
2023-10-03 21:24:53 -04:00
|
|
|
family: "tcp",
|
2023-09-07 16:27:50 -04:00
|
|
|
}
|
|
|
|
defer fq.closeOnCtxDone.Close()
|
|
|
|
|
|
|
|
rr := resolverAndDelay{
|
|
|
|
name: &dnstype.Resolver{Addr: fmt.Sprintf("127.0.0.1:%d", port)},
|
|
|
|
}
|
|
|
|
|
2023-10-03 16:26:38 -04:00
|
|
|
return fwd.send(context.Background(), fq, rr)
|
|
|
|
}
|
|
|
|
|
|
|
|
func mustRunTestQuery(tb testing.TB, port uint16, request []byte, modify func(*forwarder)) []byte {
|
|
|
|
resp, err := runTestQuery(tb, port, request, modify)
|
2023-09-07 16:27:50 -04:00
|
|
|
if err != nil {
|
2023-10-03 16:26:38 -04:00
|
|
|
tb.Fatalf("error making request: %v", err)
|
2023-09-07 16:27:50 -04:00
|
|
|
}
|
2023-10-03 16:26:38 -04:00
|
|
|
return resp
|
|
|
|
}
|
|
|
|
|
|
|
|
func TestForwarderTCPFallback(t *testing.T) {
|
|
|
|
enableDebug(t)
|
|
|
|
|
|
|
|
const domain = "large-dns-response.tailscale.com."
|
|
|
|
|
|
|
|
// Make a response that's very large, containing a bunch of localhost addresses.
|
|
|
|
request, largeResponse := makeLargeResponse(t, domain)
|
|
|
|
|
2023-10-03 21:24:53 -04:00
|
|
|
var sawTCPRequest atomic.Bool
|
2023-10-03 16:26:38 -04:00
|
|
|
port := runDNSServer(t, nil, largeResponse, func(isTCP bool, gotRequest []byte) {
|
|
|
|
if isTCP {
|
|
|
|
t.Logf("saw TCP request")
|
|
|
|
sawTCPRequest.Store(true)
|
|
|
|
} else {
|
|
|
|
t.Logf("saw UDP request")
|
|
|
|
}
|
|
|
|
|
|
|
|
if !bytes.Equal(request, gotRequest) {
|
|
|
|
t.Errorf("invalid request\ngot: %+v\nwant: %+v", gotRequest, request)
|
|
|
|
}
|
|
|
|
})
|
|
|
|
|
|
|
|
resp := mustRunTestQuery(t, port, request, nil)
|
2023-09-07 16:27:50 -04:00
|
|
|
if !bytes.Equal(resp, largeResponse) {
|
|
|
|
t.Errorf("invalid response\ngot: %+v\nwant: %+v", resp, largeResponse)
|
|
|
|
}
|
|
|
|
if !sawTCPRequest.Load() {
|
|
|
|
t.Errorf("DNS server never saw TCP request")
|
|
|
|
}
|
2023-10-03 21:24:53 -04:00
|
|
|
|
|
|
|
// NOTE: can't assert that we see a UDP request here since we might
|
|
|
|
// race and run the TCP query first. We test the UDP codepath in
|
|
|
|
// TestForwarderTCPFallbackDisabled below, though.
|
2023-09-07 16:27:50 -04:00
|
|
|
}
|
2023-10-03 16:26:38 -04:00
|
|
|
|
|
|
|
// Test to ensure that if the UDP listener is unresponsive, we always make a
|
|
|
|
// TCP request even if we never get a response.
|
|
|
|
func TestForwarderTCPFallbackTimeout(t *testing.T) {
|
|
|
|
enableDebug(t)
|
|
|
|
|
|
|
|
const domain = "large-dns-response.tailscale.com."
|
|
|
|
|
|
|
|
// Make a response that's very large, containing a bunch of localhost addresses.
|
|
|
|
request, largeResponse := makeLargeResponse(t, domain)
|
|
|
|
|
|
|
|
var sawTCPRequest atomic.Bool
|
|
|
|
opts := &testDNSServerOptions{SkipUDP: true}
|
|
|
|
port := runDNSServer(t, opts, largeResponse, func(isTCP bool, gotRequest []byte) {
|
|
|
|
if isTCP {
|
|
|
|
t.Logf("saw TCP request")
|
|
|
|
sawTCPRequest.Store(true)
|
|
|
|
} else {
|
|
|
|
t.Error("saw unexpected UDP request")
|
|
|
|
}
|
|
|
|
|
|
|
|
if !bytes.Equal(request, gotRequest) {
|
|
|
|
t.Errorf("invalid request\ngot: %+v\nwant: %+v", gotRequest, request)
|
|
|
|
}
|
|
|
|
})
|
|
|
|
|
|
|
|
resp := mustRunTestQuery(t, port, request, nil)
|
|
|
|
if !bytes.Equal(resp, largeResponse) {
|
|
|
|
t.Errorf("invalid response\ngot: %+v\nwant: %+v", resp, largeResponse)
|
|
|
|
}
|
|
|
|
if !sawTCPRequest.Load() {
|
|
|
|
t.Errorf("DNS server never saw TCP request")
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
func TestForwarderTCPFallbackDisabled(t *testing.T) {
|
|
|
|
enableDebug(t)
|
|
|
|
|
|
|
|
const domain = "large-dns-response.tailscale.com."
|
|
|
|
|
|
|
|
// Make a response that's very large, containing a bunch of localhost addresses.
|
|
|
|
request, largeResponse := makeLargeResponse(t, domain)
|
|
|
|
|
|
|
|
var sawUDPRequest atomic.Bool
|
|
|
|
port := runDNSServer(t, nil, largeResponse, func(isTCP bool, gotRequest []byte) {
|
|
|
|
if isTCP {
|
|
|
|
t.Error("saw unexpected TCP request")
|
|
|
|
} else {
|
|
|
|
t.Logf("saw UDP request")
|
|
|
|
sawUDPRequest.Store(true)
|
|
|
|
}
|
|
|
|
|
|
|
|
if !bytes.Equal(request, gotRequest) {
|
|
|
|
t.Errorf("invalid request\ngot: %+v\nwant: %+v", gotRequest, request)
|
|
|
|
}
|
|
|
|
})
|
|
|
|
|
|
|
|
resp := mustRunTestQuery(t, port, request, func(fwd *forwarder) {
|
|
|
|
// Disable retries for this test.
|
|
|
|
fwd.controlKnobs = &controlknobs.Knobs{}
|
|
|
|
fwd.controlKnobs.DisableDNSForwarderTCPRetries.Store(true)
|
|
|
|
})
|
|
|
|
|
|
|
|
wantResp := append([]byte(nil), largeResponse[:maxResponseBytes]...)
|
|
|
|
|
|
|
|
// Set the truncated flag on the expected response, since that's what we expect.
|
|
|
|
flags := binary.BigEndian.Uint16(wantResp[2:4])
|
|
|
|
flags |= dnsFlagTruncated
|
|
|
|
binary.BigEndian.PutUint16(wantResp[2:4], flags)
|
|
|
|
|
|
|
|
if !bytes.Equal(resp, wantResp) {
|
|
|
|
t.Errorf("invalid response\ngot (%d): %+v\nwant (%d): %+v", len(resp), resp, len(wantResp), wantResp)
|
|
|
|
}
|
|
|
|
if !sawUDPRequest.Load() {
|
|
|
|
t.Errorf("DNS server never saw UDP request")
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
// Test to ensure that we propagate DNS errors
|
|
|
|
func TestForwarderTCPFallbackError(t *testing.T) {
|
|
|
|
enableDebug(t)
|
|
|
|
|
|
|
|
const domain = "error-response.tailscale.com."
|
|
|
|
|
|
|
|
// Our response is a SERVFAIL
|
|
|
|
response := func() []byte {
|
|
|
|
name := dns.MustNewName(domain)
|
|
|
|
|
|
|
|
builder := dns.NewBuilder(nil, dns.Header{
|
|
|
|
RCode: dns.RCodeServerFailure,
|
|
|
|
})
|
|
|
|
builder.StartQuestions()
|
|
|
|
builder.Question(dns.Question{
|
|
|
|
Name: name,
|
|
|
|
Type: dns.TypeA,
|
|
|
|
Class: dns.ClassINET,
|
|
|
|
})
|
|
|
|
response, err := builder.Finish()
|
|
|
|
if err != nil {
|
|
|
|
t.Fatal(err)
|
|
|
|
}
|
|
|
|
return response
|
|
|
|
}()
|
|
|
|
|
|
|
|
// Our request is a single A query for the domain in the answer, above.
|
|
|
|
request := func() []byte {
|
|
|
|
builder := dns.NewBuilder(nil, dns.Header{})
|
|
|
|
builder.StartQuestions()
|
|
|
|
builder.Question(dns.Question{
|
|
|
|
Name: dns.MustNewName(domain),
|
|
|
|
Type: dns.TypeA,
|
|
|
|
Class: dns.ClassINET,
|
|
|
|
})
|
|
|
|
request, err := builder.Finish()
|
|
|
|
if err != nil {
|
|
|
|
t.Fatal(err)
|
|
|
|
}
|
|
|
|
return request
|
|
|
|
}()
|
|
|
|
|
|
|
|
var sawRequest atomic.Bool
|
|
|
|
port := runDNSServer(t, nil, response, func(isTCP bool, gotRequest []byte) {
|
|
|
|
sawRequest.Store(true)
|
|
|
|
if !bytes.Equal(request, gotRequest) {
|
|
|
|
t.Errorf("invalid request\ngot: %+v\nwant: %+v", gotRequest, request)
|
|
|
|
}
|
|
|
|
})
|
|
|
|
|
|
|
|
_, err := runTestQuery(t, port, request, nil)
|
|
|
|
if !sawRequest.Load() {
|
|
|
|
t.Error("did not see DNS request")
|
|
|
|
}
|
|
|
|
if err == nil {
|
|
|
|
t.Error("wanted error, got nil")
|
|
|
|
} else if !errors.Is(err, errServerFailure) {
|
|
|
|
t.Errorf("wanted errServerFailure, got: %v", err)
|
|
|
|
}
|
|
|
|
}
|