mirror of
https://github.com/tailscale/tailscale.git
synced 2025-08-12 05:37:32 +00:00
Move wgengine/tsdns to net/dns.
Straight move+fixup, no other changes. In prep for merging with wgengine/router/dns. Signed-off-by: David Anderson <danderson@tailscale.com>
This commit is contained in:

committed by
Dave Anderson

parent
81143b6d9a
commit
8432999835
474
net/dns/forwarder.go
Normal file
474
net/dns/forwarder.go
Normal file
@@ -0,0 +1,474 @@
|
||||
// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package dns
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"fmt"
|
||||
"hash/crc32"
|
||||
"math/rand"
|
||||
"net"
|
||||
"os"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"inet.af/netaddr"
|
||||
"tailscale.com/logtail/backoff"
|
||||
"tailscale.com/net/netns"
|
||||
"tailscale.com/types/logger"
|
||||
)
|
||||
|
||||
// headerBytes is the number of bytes in a DNS message header.
|
||||
const headerBytes = 12
|
||||
|
||||
// connCount is the number of UDP connections to use for forwarding.
|
||||
const connCount = 32
|
||||
|
||||
const (
|
||||
// cleanupInterval is the interval between purged of timed-out entries from txMap.
|
||||
cleanupInterval = 30 * time.Second
|
||||
// responseTimeout is the maximal amount of time to wait for a DNS response.
|
||||
responseTimeout = 5 * time.Second
|
||||
)
|
||||
|
||||
var errNoUpstreams = errors.New("upstream nameservers not set")
|
||||
|
||||
var aLongTimeAgo = time.Unix(0, 1)
|
||||
|
||||
type forwardingRecord struct {
|
||||
src netaddr.IPPort
|
||||
createdAt time.Time
|
||||
}
|
||||
|
||||
// txid identifies a DNS transaction.
|
||||
//
|
||||
// As the standard DNS Request ID is only 16 bits, we extend it:
|
||||
// the lower 32 bits are the zero-extended bits of the DNS Request ID;
|
||||
// the upper 32 bits are the CRC32 checksum of the first question in the request.
|
||||
// This makes probability of txid collision negligible.
|
||||
type txid uint64
|
||||
|
||||
// getTxID computes the txid of the given DNS packet.
|
||||
func getTxID(packet []byte) txid {
|
||||
if len(packet) < headerBytes {
|
||||
return 0
|
||||
}
|
||||
|
||||
dnsid := binary.BigEndian.Uint16(packet[0:2])
|
||||
qcount := binary.BigEndian.Uint16(packet[4:6])
|
||||
if qcount == 0 {
|
||||
return txid(dnsid)
|
||||
}
|
||||
|
||||
offset := headerBytes
|
||||
for i := uint16(0); i < qcount; i++ {
|
||||
// Note: this relies on the fact that names are not compressed in questions,
|
||||
// so they are guaranteed to end with a NUL byte.
|
||||
//
|
||||
// Justification:
|
||||
// RFC 1035 doesn't seem to explicitly prohibit compressing names in questions,
|
||||
// but this is exceedingly unlikely to be done in practice. A DNS request
|
||||
// with multiple questions is ill-defined (which questions do the header flags apply to?)
|
||||
// and a single question would have to contain a pointer to an *answer*,
|
||||
// which would be excessively smart, pointless (an answer can just as well refer to the question)
|
||||
// and perhaps even prohibited: a draft RFC (draft-ietf-dnsind-local-compression-05) states:
|
||||
//
|
||||
// > It is important that these pointers always point backwards.
|
||||
//
|
||||
// This is said in summarizing RFC 1035, although that phrase does not appear in the original RFC.
|
||||
// Additionally, (https://cr.yp.to/djbdns/notes.html) states:
|
||||
//
|
||||
// > The precise rule is that a name can be compressed if it is a response owner name,
|
||||
// > the name in NS data, the name in CNAME data, the name in PTR data, the name in MX data,
|
||||
// > or one of the names in SOA data.
|
||||
namebytes := bytes.IndexByte(packet[offset:], 0)
|
||||
// ... | name | NUL | type | class
|
||||
// ?? 1 2 2
|
||||
offset = offset + namebytes + 5
|
||||
if len(packet) < offset {
|
||||
// Corrupt packet; don't crash.
|
||||
return txid(dnsid)
|
||||
}
|
||||
}
|
||||
|
||||
hash := crc32.ChecksumIEEE(packet[headerBytes:offset])
|
||||
return (txid(hash) << 32) | txid(dnsid)
|
||||
}
|
||||
|
||||
// forwarder forwards DNS packets to a number of upstream nameservers.
|
||||
type forwarder struct {
|
||||
logf logger.Logf
|
||||
|
||||
// responses is a channel by which responses are returned.
|
||||
responses chan Packet
|
||||
// closed signals all goroutines to stop.
|
||||
closed chan struct{}
|
||||
// wg signals when all goroutines have stopped.
|
||||
wg sync.WaitGroup
|
||||
|
||||
// conns are the UDP connections used for forwarding.
|
||||
// A random one is selected for each request, regardless of the target upstream.
|
||||
conns []*fwdConn
|
||||
|
||||
mu sync.Mutex
|
||||
// upstreams are the nameserver addresses that should be used for forwarding.
|
||||
upstreams []net.Addr
|
||||
// txMap maps DNS txids to active forwarding records.
|
||||
txMap map[txid]forwardingRecord
|
||||
}
|
||||
|
||||
func init() {
|
||||
rand.Seed(time.Now().UnixNano())
|
||||
}
|
||||
|
||||
func newForwarder(logf logger.Logf, responses chan Packet) *forwarder {
|
||||
return &forwarder{
|
||||
logf: logger.WithPrefix(logf, "forward: "),
|
||||
responses: responses,
|
||||
closed: make(chan struct{}),
|
||||
conns: make([]*fwdConn, connCount),
|
||||
txMap: make(map[txid]forwardingRecord),
|
||||
}
|
||||
}
|
||||
|
||||
func (f *forwarder) Start() error {
|
||||
f.wg.Add(connCount + 1)
|
||||
for idx := range f.conns {
|
||||
f.conns[idx] = newFwdConn(f.logf, idx)
|
||||
go f.recv(f.conns[idx])
|
||||
}
|
||||
go f.cleanMap()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (f *forwarder) Close() {
|
||||
select {
|
||||
case <-f.closed:
|
||||
return
|
||||
default:
|
||||
// continue
|
||||
}
|
||||
close(f.closed)
|
||||
|
||||
for _, conn := range f.conns {
|
||||
conn.close()
|
||||
}
|
||||
|
||||
f.wg.Wait()
|
||||
}
|
||||
|
||||
func (f *forwarder) rebindFromNetworkChange() {
|
||||
for _, c := range f.conns {
|
||||
c.mu.Lock()
|
||||
c.reconnectLocked()
|
||||
c.mu.Unlock()
|
||||
}
|
||||
}
|
||||
|
||||
func (f *forwarder) setUpstreams(upstreams []net.Addr) {
|
||||
f.mu.Lock()
|
||||
f.upstreams = upstreams
|
||||
f.mu.Unlock()
|
||||
}
|
||||
|
||||
// send sends packet to dst. It is best effort.
|
||||
func (f *forwarder) send(packet []byte, dst net.Addr) {
|
||||
connIdx := rand.Intn(connCount)
|
||||
conn := f.conns[connIdx]
|
||||
conn.send(packet, dst)
|
||||
}
|
||||
|
||||
func (f *forwarder) recv(conn *fwdConn) {
|
||||
defer f.wg.Done()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-f.closed:
|
||||
return
|
||||
default:
|
||||
}
|
||||
out := make([]byte, maxResponseBytes)
|
||||
n := conn.read(out)
|
||||
if n == 0 {
|
||||
continue
|
||||
}
|
||||
if n < headerBytes {
|
||||
f.logf("recv: packet too small (%d bytes)", n)
|
||||
}
|
||||
|
||||
out = out[:n]
|
||||
txid := getTxID(out)
|
||||
|
||||
f.mu.Lock()
|
||||
|
||||
record, found := f.txMap[txid]
|
||||
// At most one nameserver will return a response:
|
||||
// the first one to do so will delete txid from the map.
|
||||
if !found {
|
||||
f.mu.Unlock()
|
||||
continue
|
||||
}
|
||||
delete(f.txMap, txid)
|
||||
|
||||
f.mu.Unlock()
|
||||
|
||||
packet := Packet{
|
||||
Payload: out,
|
||||
Addr: record.src,
|
||||
}
|
||||
select {
|
||||
case <-f.closed:
|
||||
return
|
||||
case f.responses <- packet:
|
||||
// continue
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// cleanMap periodically deletes timed-out forwarding records from f.txMap to bound growth.
|
||||
func (f *forwarder) cleanMap() {
|
||||
defer f.wg.Done()
|
||||
|
||||
t := time.NewTicker(cleanupInterval)
|
||||
defer t.Stop()
|
||||
|
||||
var now time.Time
|
||||
for {
|
||||
select {
|
||||
case <-f.closed:
|
||||
return
|
||||
case now = <-t.C:
|
||||
// continue
|
||||
}
|
||||
|
||||
f.mu.Lock()
|
||||
for k, v := range f.txMap {
|
||||
if now.Sub(v.createdAt) > responseTimeout {
|
||||
delete(f.txMap, k)
|
||||
}
|
||||
}
|
||||
f.mu.Unlock()
|
||||
}
|
||||
}
|
||||
|
||||
// forward forwards the query to all upstream nameservers and returns the first response.
|
||||
func (f *forwarder) forward(query Packet) error {
|
||||
txid := getTxID(query.Payload)
|
||||
|
||||
f.mu.Lock()
|
||||
|
||||
upstreams := f.upstreams
|
||||
if len(upstreams) == 0 {
|
||||
f.mu.Unlock()
|
||||
return errNoUpstreams
|
||||
}
|
||||
f.txMap[txid] = forwardingRecord{
|
||||
src: query.Addr,
|
||||
createdAt: time.Now(),
|
||||
}
|
||||
|
||||
f.mu.Unlock()
|
||||
|
||||
for _, upstream := range upstreams {
|
||||
f.send(query.Payload, upstream)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// A fwdConn manages a single connection used to forward DNS requests.
|
||||
// Net link changes can cause a *net.UDPConn to become permanently unusable, particularly on macOS.
|
||||
// fwdConn detects such situations and transparently creates new connections.
|
||||
type fwdConn struct {
|
||||
// logf allows a fwdConn to log.
|
||||
logf logger.Logf
|
||||
|
||||
// wg tracks the number of outstanding conn.Read and conn.Write calls.
|
||||
wg sync.WaitGroup
|
||||
// change allows calls to read to block until a the network connection has been replaced.
|
||||
change *sync.Cond
|
||||
|
||||
// mu protects fields that follow it; it is also change's Locker.
|
||||
mu sync.Mutex
|
||||
// closed tracks whether fwdConn has been permanently closed.
|
||||
closed bool
|
||||
// conn is the current active connection.
|
||||
conn net.PacketConn
|
||||
}
|
||||
|
||||
func newFwdConn(logf logger.Logf, idx int) *fwdConn {
|
||||
c := new(fwdConn)
|
||||
c.logf = logger.WithPrefix(logf, fmt.Sprintf("fwdConn %d: ", idx))
|
||||
c.change = sync.NewCond(&c.mu)
|
||||
// c.conn is created lazily in send
|
||||
return c
|
||||
}
|
||||
|
||||
// send sends packet to dst using c's connection.
|
||||
// It is best effort. It is UDP, after all. Failures are logged.
|
||||
func (c *fwdConn) send(packet []byte, dst net.Addr) {
|
||||
var b *backoff.Backoff // lazily initialized, since it is not needed in the common case
|
||||
backOff := func(err error) {
|
||||
if b == nil {
|
||||
b = backoff.NewBackoff("dns-fwdConn-send", c.logf, 30*time.Second)
|
||||
}
|
||||
b.BackOff(context.Background(), err)
|
||||
}
|
||||
|
||||
for {
|
||||
// Gather the current connection.
|
||||
// We can't hold the lock while we call WriteTo.
|
||||
c.mu.Lock()
|
||||
conn := c.conn
|
||||
closed := c.closed
|
||||
if closed {
|
||||
c.mu.Unlock()
|
||||
return
|
||||
}
|
||||
if conn == nil {
|
||||
c.reconnectLocked()
|
||||
c.mu.Unlock()
|
||||
continue
|
||||
}
|
||||
c.mu.Unlock()
|
||||
|
||||
c.wg.Add(1)
|
||||
_, err := conn.WriteTo(packet, dst)
|
||||
c.wg.Done()
|
||||
if err == nil {
|
||||
// Success
|
||||
return
|
||||
}
|
||||
if errors.Is(err, os.ErrDeadlineExceeded) {
|
||||
// We intentionally closed this connection.
|
||||
// It has been replaced by a new connection. Try again.
|
||||
continue
|
||||
}
|
||||
// Something else went wrong.
|
||||
// We have three choices here: try again, give up, or create a new connection.
|
||||
var opErr *net.OpError
|
||||
if !errors.As(err, &opErr) {
|
||||
// Weird. All errors from the net package should be *net.OpError. Bail.
|
||||
c.logf("send: non-*net.OpErr %v (%T)", err, err)
|
||||
return
|
||||
}
|
||||
if opErr.Temporary() || opErr.Timeout() {
|
||||
// I doubt that either of these can happen (this is UDP),
|
||||
// but go ahead and try again.
|
||||
backOff(err)
|
||||
continue
|
||||
}
|
||||
if networkIsDown(err) {
|
||||
// Fail.
|
||||
c.logf("send: network is down")
|
||||
return
|
||||
}
|
||||
if networkIsUnreachable(err) {
|
||||
// This can be caused by a link change.
|
||||
// Replace the existing connection with a new one.
|
||||
c.mu.Lock()
|
||||
// It's possible that multiple senders discovered simultaneously
|
||||
// that the network is unreachable. Avoid reconnecting multiple times:
|
||||
// Only reconnect if the current connection is the one that we
|
||||
// discovered to be problematic.
|
||||
if c.conn == conn {
|
||||
backOff(err)
|
||||
c.reconnectLocked()
|
||||
}
|
||||
c.mu.Unlock()
|
||||
// Try again with our new network connection.
|
||||
continue
|
||||
}
|
||||
// Unrecognized error. Fail.
|
||||
c.logf("send: unrecognized error: %v", err)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// read waits for a response from c's connection.
|
||||
// It returns the number of bytes read, which may be 0
|
||||
// in case of an error or a closed connection.
|
||||
func (c *fwdConn) read(out []byte) int {
|
||||
for {
|
||||
// Gather the current connection.
|
||||
// We can't hold the lock while we call ReadFrom.
|
||||
c.mu.Lock()
|
||||
conn := c.conn
|
||||
closed := c.closed
|
||||
if closed {
|
||||
c.mu.Unlock()
|
||||
return 0
|
||||
}
|
||||
if conn == nil {
|
||||
// There is no current connection.
|
||||
// Wait for the connection to change, then try again.
|
||||
c.change.Wait()
|
||||
c.mu.Unlock()
|
||||
continue
|
||||
}
|
||||
c.mu.Unlock()
|
||||
|
||||
c.wg.Add(1)
|
||||
n, _, err := conn.ReadFrom(out)
|
||||
c.wg.Done()
|
||||
if err == nil {
|
||||
// Success.
|
||||
return n
|
||||
}
|
||||
if errors.Is(err, os.ErrDeadlineExceeded) {
|
||||
// We intentionally closed this connection.
|
||||
// It has been replaced by a new connection. Try again.
|
||||
continue
|
||||
}
|
||||
|
||||
c.logf("read: unrecognized error: %v", err)
|
||||
return 0
|
||||
}
|
||||
}
|
||||
|
||||
// reconnectLocked replaces the current connection with a new one.
|
||||
// c.mu must be locked.
|
||||
func (c *fwdConn) reconnectLocked() {
|
||||
c.closeConnLocked()
|
||||
// Make a new connection.
|
||||
conn, err := netns.Listener().ListenPacket(context.Background(), "udp", "")
|
||||
if err != nil {
|
||||
c.logf("ListenPacket failed: %v", err)
|
||||
} else {
|
||||
c.conn = conn
|
||||
}
|
||||
// Broadcast that a new connection is available.
|
||||
c.change.Broadcast()
|
||||
}
|
||||
|
||||
// closeCurrentConn closes the current connection.
|
||||
// c.mu must be locked.
|
||||
func (c *fwdConn) closeConnLocked() {
|
||||
if c.conn == nil {
|
||||
return
|
||||
}
|
||||
// Unblock all readers/writers, wait for them, close the connection.
|
||||
c.conn.SetDeadline(aLongTimeAgo)
|
||||
c.wg.Wait()
|
||||
c.conn.Close()
|
||||
c.conn = nil
|
||||
}
|
||||
|
||||
// close permanently closes c.
|
||||
func (c *fwdConn) close() {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
if c.closed {
|
||||
return
|
||||
}
|
||||
c.closed = true
|
||||
c.closeConnLocked()
|
||||
// Unblock any remaining readers.
|
||||
c.change.Broadcast()
|
||||
}
|
160
net/dns/map.go
Normal file
160
net/dns/map.go
Normal file
@@ -0,0 +1,160 @@
|
||||
// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package dns
|
||||
|
||||
import (
|
||||
"sort"
|
||||
"strings"
|
||||
|
||||
"inet.af/netaddr"
|
||||
)
|
||||
|
||||
// Map is all the data Resolver needs to resolve DNS queries within the Tailscale network.
|
||||
type Map struct {
|
||||
// nameToIP is a mapping of Tailscale domain names to their IP addresses.
|
||||
// For example, monitoring.tailscale.us -> 100.64.0.1.
|
||||
nameToIP map[string]netaddr.IP
|
||||
// ipToName is the inverse of nameToIP.
|
||||
ipToName map[netaddr.IP]string
|
||||
// names are the keys of nameToIP in sorted order.
|
||||
names []string
|
||||
// rootDomains are the domains whose subdomains should always
|
||||
// be resolved locally to prevent leakage of sensitive names.
|
||||
rootDomains []string // e.g. "user.provider.beta.tailscale.net."
|
||||
}
|
||||
|
||||
// NewMap returns a new Map with name to address mapping given by nameToIP.
|
||||
//
|
||||
// rootDomains are the domains whose subdomains should always be
|
||||
// resolved locally to prevent leakage of sensitive names. They should
|
||||
// end in a period ("user-foo.tailscale.net.").
|
||||
func NewMap(initNameToIP map[string]netaddr.IP, rootDomains []string) *Map {
|
||||
// TODO(dmytro): we have to allocate names and ipToName, but nameToIP can be avoided.
|
||||
// It is here because control sends us names not in canonical form. Change this.
|
||||
names := make([]string, 0, len(initNameToIP))
|
||||
nameToIP := make(map[string]netaddr.IP, len(initNameToIP))
|
||||
ipToName := make(map[netaddr.IP]string, len(initNameToIP))
|
||||
|
||||
for name, ip := range initNameToIP {
|
||||
if len(name) == 0 {
|
||||
// Nothing useful can be done with empty names.
|
||||
continue
|
||||
}
|
||||
if name[len(name)-1] != '.' {
|
||||
name += "."
|
||||
}
|
||||
names = append(names, name)
|
||||
nameToIP[name] = ip
|
||||
ipToName[ip] = name
|
||||
}
|
||||
sort.Strings(names)
|
||||
|
||||
return &Map{
|
||||
nameToIP: nameToIP,
|
||||
ipToName: ipToName,
|
||||
names: names,
|
||||
|
||||
rootDomains: rootDomains,
|
||||
}
|
||||
}
|
||||
|
||||
func printSingleNameIP(buf *strings.Builder, name string, ip netaddr.IP) {
|
||||
buf.WriteString(name)
|
||||
buf.WriteByte('\t')
|
||||
buf.WriteString(ip.String())
|
||||
buf.WriteByte('\n')
|
||||
}
|
||||
|
||||
func (m *Map) Pretty() string {
|
||||
buf := new(strings.Builder)
|
||||
for _, name := range m.names {
|
||||
printSingleNameIP(buf, name, m.nameToIP[name])
|
||||
}
|
||||
return buf.String()
|
||||
}
|
||||
|
||||
func (m *Map) PrettyDiffFrom(old *Map) string {
|
||||
var (
|
||||
oldNameToIP map[string]netaddr.IP
|
||||
newNameToIP map[string]netaddr.IP
|
||||
oldNames []string
|
||||
newNames []string
|
||||
)
|
||||
if old != nil {
|
||||
oldNameToIP = old.nameToIP
|
||||
oldNames = old.names
|
||||
}
|
||||
if m != nil {
|
||||
newNameToIP = m.nameToIP
|
||||
newNames = m.names
|
||||
}
|
||||
|
||||
buf := new(strings.Builder)
|
||||
space := func() bool {
|
||||
return buf.Len() < (1 << 10)
|
||||
}
|
||||
|
||||
for len(oldNames) > 0 && len(newNames) > 0 {
|
||||
var name string
|
||||
|
||||
newName, oldName := newNames[0], oldNames[0]
|
||||
switch {
|
||||
case oldName < newName:
|
||||
name = oldName
|
||||
oldNames = oldNames[1:]
|
||||
case oldName > newName:
|
||||
name = newName
|
||||
newNames = newNames[1:]
|
||||
case oldNames[0] == newNames[0]:
|
||||
name = oldNames[0]
|
||||
oldNames = oldNames[1:]
|
||||
newNames = newNames[1:]
|
||||
}
|
||||
if !space() {
|
||||
continue
|
||||
}
|
||||
|
||||
ipOld, inOld := oldNameToIP[name]
|
||||
ipNew, inNew := newNameToIP[name]
|
||||
switch {
|
||||
case !inOld:
|
||||
buf.WriteByte('+')
|
||||
printSingleNameIP(buf, name, ipNew)
|
||||
case !inNew:
|
||||
buf.WriteByte('-')
|
||||
printSingleNameIP(buf, name, ipOld)
|
||||
case ipOld != ipNew:
|
||||
buf.WriteByte('-')
|
||||
printSingleNameIP(buf, name, ipOld)
|
||||
buf.WriteByte('+')
|
||||
printSingleNameIP(buf, name, ipNew)
|
||||
}
|
||||
}
|
||||
|
||||
for _, name := range oldNames {
|
||||
if !space() {
|
||||
break
|
||||
}
|
||||
if _, ok := newNameToIP[name]; !ok {
|
||||
buf.WriteByte('-')
|
||||
printSingleNameIP(buf, name, oldNameToIP[name])
|
||||
}
|
||||
}
|
||||
|
||||
for _, name := range newNames {
|
||||
if !space() {
|
||||
break
|
||||
}
|
||||
if _, ok := oldNameToIP[name]; !ok {
|
||||
buf.WriteByte('+')
|
||||
printSingleNameIP(buf, name, newNameToIP[name])
|
||||
}
|
||||
}
|
||||
if !space() {
|
||||
buf.WriteString("... [truncated]\n")
|
||||
}
|
||||
|
||||
return buf.String()
|
||||
}
|
156
net/dns/map_test.go
Normal file
156
net/dns/map_test.go
Normal file
@@ -0,0 +1,156 @@
|
||||
// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package dns
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"inet.af/netaddr"
|
||||
)
|
||||
|
||||
func TestPretty(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
dmap *Map
|
||||
want string
|
||||
}{
|
||||
{"empty", NewMap(nil, nil), ""},
|
||||
{
|
||||
"single",
|
||||
NewMap(map[string]netaddr.IP{
|
||||
"hello.ipn.dev.": netaddr.IPv4(100, 101, 102, 103),
|
||||
}, nil),
|
||||
"hello.ipn.dev.\t100.101.102.103\n",
|
||||
},
|
||||
{
|
||||
"multiple",
|
||||
NewMap(map[string]netaddr.IP{
|
||||
"test1.domain.": netaddr.IPv4(100, 101, 102, 103),
|
||||
"test2.sub.domain.": netaddr.IPv4(100, 99, 9, 1),
|
||||
}, nil),
|
||||
"test1.domain.\t100.101.102.103\ntest2.sub.domain.\t100.99.9.1\n",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := tt.dmap.Pretty()
|
||||
if tt.want != got {
|
||||
t.Errorf("want %v; got %v", tt.want, got)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestPrettyDiffFrom(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
map1 *Map
|
||||
map2 *Map
|
||||
want string
|
||||
}{
|
||||
{
|
||||
"from_empty",
|
||||
nil,
|
||||
NewMap(map[string]netaddr.IP{
|
||||
"test1.ipn.dev.": netaddr.IPv4(100, 101, 102, 103),
|
||||
"test2.ipn.dev.": netaddr.IPv4(100, 103, 102, 101),
|
||||
}, nil),
|
||||
"+test1.ipn.dev.\t100.101.102.103\n+test2.ipn.dev.\t100.103.102.101\n",
|
||||
},
|
||||
{
|
||||
"equal",
|
||||
NewMap(map[string]netaddr.IP{
|
||||
"test1.ipn.dev.": netaddr.IPv4(100, 101, 102, 103),
|
||||
"test2.ipn.dev.": netaddr.IPv4(100, 103, 102, 101),
|
||||
}, nil),
|
||||
NewMap(map[string]netaddr.IP{
|
||||
"test2.ipn.dev.": netaddr.IPv4(100, 103, 102, 101),
|
||||
"test1.ipn.dev.": netaddr.IPv4(100, 101, 102, 103),
|
||||
}, nil),
|
||||
"",
|
||||
},
|
||||
{
|
||||
"changed_ip",
|
||||
NewMap(map[string]netaddr.IP{
|
||||
"test1.ipn.dev.": netaddr.IPv4(100, 101, 102, 103),
|
||||
"test2.ipn.dev.": netaddr.IPv4(100, 103, 102, 101),
|
||||
}, nil),
|
||||
NewMap(map[string]netaddr.IP{
|
||||
"test2.ipn.dev.": netaddr.IPv4(100, 104, 102, 101),
|
||||
"test1.ipn.dev.": netaddr.IPv4(100, 101, 102, 103),
|
||||
}, nil),
|
||||
"-test2.ipn.dev.\t100.103.102.101\n+test2.ipn.dev.\t100.104.102.101\n",
|
||||
},
|
||||
{
|
||||
"new_domain",
|
||||
NewMap(map[string]netaddr.IP{
|
||||
"test1.ipn.dev.": netaddr.IPv4(100, 101, 102, 103),
|
||||
"test2.ipn.dev.": netaddr.IPv4(100, 103, 102, 101),
|
||||
}, nil),
|
||||
NewMap(map[string]netaddr.IP{
|
||||
"test3.ipn.dev.": netaddr.IPv4(100, 105, 106, 107),
|
||||
"test2.ipn.dev.": netaddr.IPv4(100, 103, 102, 101),
|
||||
"test1.ipn.dev.": netaddr.IPv4(100, 101, 102, 103),
|
||||
}, nil),
|
||||
"+test3.ipn.dev.\t100.105.106.107\n",
|
||||
},
|
||||
{
|
||||
"gone_domain",
|
||||
NewMap(map[string]netaddr.IP{
|
||||
"test1.ipn.dev.": netaddr.IPv4(100, 101, 102, 103),
|
||||
"test2.ipn.dev.": netaddr.IPv4(100, 103, 102, 101),
|
||||
}, nil),
|
||||
NewMap(map[string]netaddr.IP{
|
||||
"test1.ipn.dev.": netaddr.IPv4(100, 101, 102, 103),
|
||||
}, nil),
|
||||
"-test2.ipn.dev.\t100.103.102.101\n",
|
||||
},
|
||||
{
|
||||
"mixed",
|
||||
NewMap(map[string]netaddr.IP{
|
||||
"test1.ipn.dev.": netaddr.IPv4(100, 101, 102, 103),
|
||||
"test4.ipn.dev.": netaddr.IPv4(100, 107, 106, 105),
|
||||
"test5.ipn.dev.": netaddr.IPv4(100, 64, 1, 1),
|
||||
"test2.ipn.dev.": netaddr.IPv4(100, 103, 102, 101),
|
||||
}, nil),
|
||||
NewMap(map[string]netaddr.IP{
|
||||
"test2.ipn.dev.": netaddr.IPv4(100, 104, 102, 101),
|
||||
"test1.ipn.dev.": netaddr.IPv4(100, 100, 101, 102),
|
||||
"test3.ipn.dev.": netaddr.IPv4(100, 64, 1, 1),
|
||||
}, nil),
|
||||
"-test1.ipn.dev.\t100.101.102.103\n+test1.ipn.dev.\t100.100.101.102\n" +
|
||||
"-test2.ipn.dev.\t100.103.102.101\n+test2.ipn.dev.\t100.104.102.101\n" +
|
||||
"+test3.ipn.dev.\t100.64.1.1\n-test4.ipn.dev.\t100.107.106.105\n-test5.ipn.dev.\t100.64.1.1\n",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := tt.map2.PrettyDiffFrom(tt.map1)
|
||||
if tt.want != got {
|
||||
t.Errorf("want %v; got %v", tt.want, got)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
t.Run("truncated", func(t *testing.T) {
|
||||
small := NewMap(nil, nil)
|
||||
m := map[string]netaddr.IP{}
|
||||
for i := 0; i < 5000; i++ {
|
||||
m[fmt.Sprintf("host%d.ipn.dev.", i)] = netaddr.IPv4(100, 64, 1, 1)
|
||||
}
|
||||
veryBig := NewMap(m, nil)
|
||||
diff := veryBig.PrettyDiffFrom(small)
|
||||
if len(diff) > 3<<10 {
|
||||
t.Errorf("pretty diff too large: %d bytes", len(diff))
|
||||
}
|
||||
if !strings.Contains(diff, "truncated") {
|
||||
t.Errorf("big diff not truncated")
|
||||
}
|
||||
})
|
||||
}
|
25
net/dns/neterr_darwin.go
Normal file
25
net/dns/neterr_darwin.go
Normal file
@@ -0,0 +1,25 @@
|
||||
// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package dns
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"syscall"
|
||||
)
|
||||
|
||||
// Avoid allocation when calling errors.Is below
|
||||
// by converting syscall.Errno to error here.
|
||||
var (
|
||||
networkDown error = syscall.ENETDOWN
|
||||
networkUnreachable error = syscall.ENETUNREACH
|
||||
)
|
||||
|
||||
func networkIsDown(err error) bool {
|
||||
return errors.Is(err, networkDown)
|
||||
}
|
||||
|
||||
func networkIsUnreachable(err error) bool {
|
||||
return errors.Is(err, networkUnreachable)
|
||||
}
|
10
net/dns/neterr_other.go
Normal file
10
net/dns/neterr_other.go
Normal file
@@ -0,0 +1,10 @@
|
||||
// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
// +build !darwin,!windows
|
||||
|
||||
package dns
|
||||
|
||||
func networkIsDown(err error) bool { return false }
|
||||
func networkIsUnreachable(err error) bool { return false }
|
29
net/dns/neterr_windows.go
Normal file
29
net/dns/neterr_windows.go
Normal file
@@ -0,0 +1,29 @@
|
||||
// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package dns
|
||||
|
||||
import (
|
||||
"net"
|
||||
"os"
|
||||
|
||||
"golang.org/x/sys/windows"
|
||||
)
|
||||
|
||||
func networkIsDown(err error) bool {
|
||||
if oe, ok := err.(*net.OpError); ok && oe.Op == "write" {
|
||||
if se, ok := oe.Err.(*os.SyscallError); ok {
|
||||
if se.Syscall == "wsasendto" && se.Err == windows.WSAENETUNREACH {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func networkIsUnreachable(err error) bool {
|
||||
// TODO(bradfitz,josharian): something here? what is the
|
||||
// difference between down and unreachable? Add comments.
|
||||
return false
|
||||
}
|
662
net/dns/tsdns.go
Normal file
662
net/dns/tsdns.go
Normal file
@@ -0,0 +1,662 @@
|
||||
// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
// Package dns provides a Resolver capable of resolving
|
||||
// domains on a Tailscale network.
|
||||
package dns
|
||||
|
||||
import (
|
||||
"encoding/hex"
|
||||
"errors"
|
||||
"net"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
dns "golang.org/x/net/dns/dnsmessage"
|
||||
"inet.af/netaddr"
|
||||
"tailscale.com/net/interfaces"
|
||||
"tailscale.com/types/logger"
|
||||
"tailscale.com/util/dnsname"
|
||||
"tailscale.com/wgengine/monitor"
|
||||
)
|
||||
|
||||
// maxResponseBytes is the maximum size of a response from a Resolver.
|
||||
const maxResponseBytes = 512
|
||||
|
||||
// queueSize is the maximal number of DNS requests that can await polling.
|
||||
// If EnqueueRequest is called when this many requests are already pending,
|
||||
// the request will be dropped to avoid blocking the caller.
|
||||
const queueSize = 64
|
||||
|
||||
// defaultTTL is the TTL of all responses from Resolver.
|
||||
const defaultTTL = 600 * time.Second
|
||||
|
||||
// ErrClosed indicates that the resolver has been closed and readers should exit.
|
||||
var ErrClosed = errors.New("closed")
|
||||
|
||||
var (
|
||||
errFullQueue = errors.New("request queue full")
|
||||
errMapNotSet = errors.New("domain map not set")
|
||||
errNotForwarding = errors.New("forwarding disabled")
|
||||
errNotImplemented = errors.New("query type not implemented")
|
||||
errNotQuery = errors.New("not a DNS query")
|
||||
errNotOurName = errors.New("not a Tailscale DNS name")
|
||||
)
|
||||
|
||||
// Packet represents a DNS payload together with the address of its origin.
|
||||
type Packet struct {
|
||||
// Payload is the application layer DNS payload.
|
||||
// Resolver assumes ownership of the request payload when it is enqueued
|
||||
// and cedes ownership of the response payload when it is returned from NextResponse.
|
||||
Payload []byte
|
||||
// Addr is the source address for a request and the destination address for a response.
|
||||
Addr netaddr.IPPort
|
||||
}
|
||||
|
||||
// Resolver is a DNS resolver for nodes on the Tailscale network,
|
||||
// associating them with domain names of the form <mynode>.<mydomain>.<root>.
|
||||
// If it is asked to resolve a domain that is not of that form,
|
||||
// it delegates to upstream nameservers if any are set.
|
||||
type Resolver struct {
|
||||
logf logger.Logf
|
||||
linkMon *monitor.Mon // or nil
|
||||
unregLinkMon func() // or nil
|
||||
// forwarder forwards requests to upstream nameservers.
|
||||
forwarder *forwarder
|
||||
|
||||
// queue is a buffered channel holding DNS requests queued for resolution.
|
||||
queue chan Packet
|
||||
// responses is an unbuffered channel to which responses are returned.
|
||||
responses chan Packet
|
||||
// errors is an unbuffered channel to which errors are returned.
|
||||
errors chan error
|
||||
// closed signals all goroutines to stop.
|
||||
closed chan struct{}
|
||||
// wg signals when all goroutines have stopped.
|
||||
wg sync.WaitGroup
|
||||
|
||||
// mu guards the following fields from being updated while used.
|
||||
mu sync.Mutex
|
||||
// dnsMap is the map most recently received from the control server.
|
||||
dnsMap *Map
|
||||
}
|
||||
|
||||
// ResolverConfig is the set of configuration options for a Resolver.
|
||||
type ResolverConfig struct {
|
||||
// Logf is the logger to use throughout the Resolver.
|
||||
Logf logger.Logf
|
||||
// Forward determines whether the resolver will forward packets to
|
||||
// nameservers set with SetUpstreams if the domain name is not of a Tailscale node.
|
||||
Forward bool
|
||||
// LinkMonitor optionally provides a link monitor to use to rebind
|
||||
// connections on link changes.
|
||||
// If nil, rebinds are not performend.
|
||||
LinkMonitor *monitor.Mon
|
||||
}
|
||||
|
||||
// NewResolver constructs a resolver associated with the given root domain.
|
||||
// The root domain must be in canonical form (with a trailing period).
|
||||
func NewResolver(config ResolverConfig) *Resolver {
|
||||
r := &Resolver{
|
||||
logf: logger.WithPrefix(config.Logf, "dns: "),
|
||||
linkMon: config.LinkMonitor,
|
||||
queue: make(chan Packet, queueSize),
|
||||
responses: make(chan Packet),
|
||||
errors: make(chan error),
|
||||
closed: make(chan struct{}),
|
||||
}
|
||||
|
||||
if config.Forward {
|
||||
r.forwarder = newForwarder(r.logf, r.responses)
|
||||
}
|
||||
if r.linkMon != nil {
|
||||
r.unregLinkMon = r.linkMon.RegisterChangeCallback(r.onLinkMonitorChange)
|
||||
}
|
||||
|
||||
return r
|
||||
}
|
||||
|
||||
func (r *Resolver) Start() error {
|
||||
if r.forwarder != nil {
|
||||
if err := r.forwarder.Start(); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
r.wg.Add(1)
|
||||
go r.poll()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Close shuts down the resolver and ensures poll goroutines have exited.
|
||||
// The Resolver cannot be used again after Close is called.
|
||||
func (r *Resolver) Close() {
|
||||
select {
|
||||
case <-r.closed:
|
||||
return
|
||||
default:
|
||||
// continue
|
||||
}
|
||||
close(r.closed)
|
||||
|
||||
if r.unregLinkMon != nil {
|
||||
r.unregLinkMon()
|
||||
}
|
||||
|
||||
if r.forwarder != nil {
|
||||
r.forwarder.Close()
|
||||
}
|
||||
|
||||
r.wg.Wait()
|
||||
}
|
||||
|
||||
func (r *Resolver) onLinkMonitorChange(changed bool, state *interfaces.State) {
|
||||
if !changed {
|
||||
return
|
||||
}
|
||||
if r.forwarder != nil {
|
||||
r.forwarder.rebindFromNetworkChange()
|
||||
}
|
||||
}
|
||||
|
||||
// SetMap sets the resolver's DNS map, taking ownership of it.
|
||||
func (r *Resolver) SetMap(m *Map) {
|
||||
r.mu.Lock()
|
||||
oldMap := r.dnsMap
|
||||
r.dnsMap = m
|
||||
r.mu.Unlock()
|
||||
r.logf("map diff:\n%s", m.PrettyDiffFrom(oldMap))
|
||||
}
|
||||
|
||||
// SetUpstreams sets the addresses of the resolver's
|
||||
// upstream nameservers, taking ownership of the argument.
|
||||
func (r *Resolver) SetUpstreams(upstreams []net.Addr) {
|
||||
if r.forwarder != nil {
|
||||
r.forwarder.setUpstreams(upstreams)
|
||||
}
|
||||
r.logf("set upstreams: %v", upstreams)
|
||||
}
|
||||
|
||||
// EnqueueRequest places the given DNS request in the resolver's queue.
|
||||
// It takes ownership of the payload and does not block.
|
||||
// If the queue is full, the request will be dropped and an error will be returned.
|
||||
func (r *Resolver) EnqueueRequest(request Packet) error {
|
||||
select {
|
||||
case <-r.closed:
|
||||
return ErrClosed
|
||||
case r.queue <- request:
|
||||
return nil
|
||||
default:
|
||||
return errFullQueue
|
||||
}
|
||||
}
|
||||
|
||||
// NextResponse returns a DNS response to a previously enqueued request.
|
||||
// It blocks until a response is available and gives up ownership of the response payload.
|
||||
func (r *Resolver) NextResponse() (Packet, error) {
|
||||
select {
|
||||
case <-r.closed:
|
||||
return Packet{}, ErrClosed
|
||||
case resp := <-r.responses:
|
||||
return resp, nil
|
||||
case err := <-r.errors:
|
||||
return Packet{}, err
|
||||
}
|
||||
}
|
||||
|
||||
// Resolve maps a given domain name to the IP address of the host that owns it,
|
||||
// if the IP address conforms to the DNS resource type given by tp (one of A, AAAA, ALL).
|
||||
// The domain name must be in canonical form (with a trailing period).
|
||||
func (r *Resolver) Resolve(domain string, tp dns.Type) (netaddr.IP, dns.RCode, error) {
|
||||
r.mu.Lock()
|
||||
dnsMap := r.dnsMap
|
||||
r.mu.Unlock()
|
||||
|
||||
if dnsMap == nil {
|
||||
return netaddr.IP{}, dns.RCodeServerFailure, errMapNotSet
|
||||
}
|
||||
|
||||
// Reject .onion domains per RFC 7686.
|
||||
if dnsname.HasSuffix(domain, ".onion") {
|
||||
return netaddr.IP{}, dns.RCodeNameError, nil
|
||||
}
|
||||
|
||||
anyHasSuffix := false
|
||||
for _, suffix := range dnsMap.rootDomains {
|
||||
if dnsname.HasSuffix(domain, suffix) {
|
||||
anyHasSuffix = true
|
||||
break
|
||||
}
|
||||
}
|
||||
addr, found := dnsMap.nameToIP[domain]
|
||||
if !found {
|
||||
if !anyHasSuffix {
|
||||
return netaddr.IP{}, dns.RCodeRefused, nil
|
||||
}
|
||||
return netaddr.IP{}, dns.RCodeNameError, nil
|
||||
}
|
||||
|
||||
// Refactoring note: this must happen after we check suffixes,
|
||||
// otherwise we will respond with NOTIMP to requests that should be forwarded.
|
||||
switch tp {
|
||||
case dns.TypeA:
|
||||
if !addr.Is4() {
|
||||
return netaddr.IP{}, dns.RCodeSuccess, nil
|
||||
}
|
||||
return addr, dns.RCodeSuccess, nil
|
||||
case dns.TypeAAAA:
|
||||
if !addr.Is6() {
|
||||
return netaddr.IP{}, dns.RCodeSuccess, nil
|
||||
}
|
||||
return addr, dns.RCodeSuccess, nil
|
||||
case dns.TypeALL:
|
||||
// Answer with whatever we've got.
|
||||
// It could be IPv4, IPv6, or a zero addr.
|
||||
// TODO: Return all available resolutions (A and AAAA, if we have them).
|
||||
return addr, dns.RCodeSuccess, nil
|
||||
|
||||
// Leave some some record types explicitly unimplemented.
|
||||
// These types relate to recursive resolution or special
|
||||
// DNS sematics and might be implemented in the future.
|
||||
case dns.TypeNS, dns.TypeSOA, dns.TypeAXFR, dns.TypeHINFO:
|
||||
return netaddr.IP{}, dns.RCodeNotImplemented, errNotImplemented
|
||||
|
||||
// For everything except for the few types above that are explictly not implemented, return no records.
|
||||
// This is what other DNS systems do: always return NOERROR
|
||||
// without any records whenever the requested record type is unknown.
|
||||
// You can try this with:
|
||||
// dig -t TYPE9824 example.com
|
||||
// and note that NOERROR is returned, despite that record type being made up.
|
||||
default:
|
||||
// no records exist of this type
|
||||
return netaddr.IP{}, dns.RCodeSuccess, nil
|
||||
}
|
||||
}
|
||||
|
||||
// ResolveReverse returns the unique domain name that maps to the given address.
|
||||
// The returned domain name is in canonical form (with a trailing period).
|
||||
func (r *Resolver) ResolveReverse(ip netaddr.IP) (string, dns.RCode, error) {
|
||||
r.mu.Lock()
|
||||
dnsMap := r.dnsMap
|
||||
r.mu.Unlock()
|
||||
|
||||
if dnsMap == nil {
|
||||
return "", dns.RCodeServerFailure, errMapNotSet
|
||||
}
|
||||
name, found := dnsMap.ipToName[ip]
|
||||
if !found {
|
||||
return "", dns.RCodeNameError, nil
|
||||
}
|
||||
return name, dns.RCodeSuccess, nil
|
||||
}
|
||||
|
||||
func (r *Resolver) poll() {
|
||||
defer r.wg.Done()
|
||||
|
||||
var packet Packet
|
||||
for {
|
||||
select {
|
||||
case <-r.closed:
|
||||
return
|
||||
case packet = <-r.queue:
|
||||
// continue
|
||||
}
|
||||
|
||||
out, err := r.respond(packet.Payload)
|
||||
|
||||
if err == errNotOurName {
|
||||
if r.forwarder != nil {
|
||||
err = r.forwarder.forward(packet)
|
||||
if err == nil {
|
||||
// forward will send response into r.responses, nothing to do.
|
||||
continue
|
||||
}
|
||||
} else {
|
||||
err = errNotForwarding
|
||||
}
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
select {
|
||||
case <-r.closed:
|
||||
return
|
||||
case r.errors <- err:
|
||||
// continue
|
||||
}
|
||||
} else {
|
||||
packet.Payload = out
|
||||
select {
|
||||
case <-r.closed:
|
||||
return
|
||||
case r.responses <- packet:
|
||||
// continue
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
type response struct {
|
||||
Header dns.Header
|
||||
Question dns.Question
|
||||
// Name is the response to a PTR query.
|
||||
Name string
|
||||
// IP is the response to an A, AAAA, or ALL query.
|
||||
IP netaddr.IP
|
||||
}
|
||||
|
||||
// parseQuery parses the query in given packet into a response struct.
|
||||
func parseQuery(query []byte, resp *response) error {
|
||||
var parser dns.Parser
|
||||
var err error
|
||||
|
||||
resp.Header, err = parser.Start(query)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if resp.Header.Response {
|
||||
return errNotQuery
|
||||
}
|
||||
|
||||
resp.Question, err = parser.Question()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// marshalARecord serializes an A record into an active builder.
|
||||
// The caller may continue using the builder following the call.
|
||||
func marshalARecord(name dns.Name, ip netaddr.IP, builder *dns.Builder) error {
|
||||
var answer dns.AResource
|
||||
|
||||
answerHeader := dns.ResourceHeader{
|
||||
Name: name,
|
||||
Type: dns.TypeA,
|
||||
Class: dns.ClassINET,
|
||||
TTL: uint32(defaultTTL / time.Second),
|
||||
}
|
||||
ipbytes := ip.As4()
|
||||
copy(answer.A[:], ipbytes[:])
|
||||
return builder.AResource(answerHeader, answer)
|
||||
}
|
||||
|
||||
// marshalAAAARecord serializes an AAAA record into an active builder.
|
||||
// The caller may continue using the builder following the call.
|
||||
func marshalAAAARecord(name dns.Name, ip netaddr.IP, builder *dns.Builder) error {
|
||||
var answer dns.AAAAResource
|
||||
|
||||
answerHeader := dns.ResourceHeader{
|
||||
Name: name,
|
||||
Type: dns.TypeAAAA,
|
||||
Class: dns.ClassINET,
|
||||
TTL: uint32(defaultTTL / time.Second),
|
||||
}
|
||||
ipbytes := ip.As16()
|
||||
copy(answer.AAAA[:], ipbytes[:])
|
||||
return builder.AAAAResource(answerHeader, answer)
|
||||
}
|
||||
|
||||
// marshalPTRRecord serializes a PTR record into an active builder.
|
||||
// The caller may continue using the builder following the call.
|
||||
func marshalPTRRecord(queryName dns.Name, name string, builder *dns.Builder) error {
|
||||
var answer dns.PTRResource
|
||||
var err error
|
||||
|
||||
answerHeader := dns.ResourceHeader{
|
||||
Name: queryName,
|
||||
Type: dns.TypePTR,
|
||||
Class: dns.ClassINET,
|
||||
TTL: uint32(defaultTTL / time.Second),
|
||||
}
|
||||
answer.PTR, err = dns.NewName(name)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return builder.PTRResource(answerHeader, answer)
|
||||
}
|
||||
|
||||
// marshalResponse serializes the DNS response into a new buffer.
|
||||
func marshalResponse(resp *response) ([]byte, error) {
|
||||
resp.Header.Response = true
|
||||
resp.Header.Authoritative = true
|
||||
if resp.Header.RecursionDesired {
|
||||
resp.Header.RecursionAvailable = true
|
||||
}
|
||||
|
||||
builder := dns.NewBuilder(nil, resp.Header)
|
||||
|
||||
isSuccess := resp.Header.RCode == dns.RCodeSuccess
|
||||
|
||||
if resp.Question.Type != 0 || isSuccess {
|
||||
err := builder.StartQuestions()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
err = builder.Question(resp.Question)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
// Only successful responses contain answers.
|
||||
if !isSuccess {
|
||||
return builder.Finish()
|
||||
}
|
||||
|
||||
err := builder.StartAnswers()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
switch resp.Question.Type {
|
||||
case dns.TypeA, dns.TypeAAAA, dns.TypeALL:
|
||||
if resp.IP.Is4() {
|
||||
err = marshalARecord(resp.Question.Name, resp.IP, &builder)
|
||||
} else if resp.IP.Is6() {
|
||||
err = marshalAAAARecord(resp.Question.Name, resp.IP, &builder)
|
||||
}
|
||||
case dns.TypePTR:
|
||||
err = marshalPTRRecord(resp.Question.Name, resp.Name, &builder)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return builder.Finish()
|
||||
}
|
||||
|
||||
const (
|
||||
rdnsv4Suffix = ".in-addr.arpa."
|
||||
rdnsv6Suffix = ".ip6.arpa."
|
||||
)
|
||||
|
||||
// hasRDNSBonjourPrefix reports whether name has a Bonjour Service Prefix..
|
||||
//
|
||||
// https://tools.ietf.org/html/rfc6763 lists
|
||||
// "five special RR names" for Bonjour service discovery:
|
||||
//
|
||||
// b._dns-sd._udp.<domain>.
|
||||
// db._dns-sd._udp.<domain>.
|
||||
// r._dns-sd._udp.<domain>.
|
||||
// dr._dns-sd._udp.<domain>.
|
||||
// lb._dns-sd._udp.<domain>.
|
||||
func hasRDNSBonjourPrefix(s string) bool {
|
||||
// Even the shortest name containing a Bonjour prefix is long,
|
||||
// so check length (cheap) and bail early if possible.
|
||||
if len(s) < len("*._dns-sd._udp.0.0.0.0.in-addr.arpa.") {
|
||||
return false
|
||||
}
|
||||
dot := strings.IndexByte(s, '.')
|
||||
if dot == -1 {
|
||||
return false // shouldn't happen
|
||||
}
|
||||
switch s[:dot] {
|
||||
case "b", "db", "r", "dr", "lb":
|
||||
default:
|
||||
return false
|
||||
}
|
||||
|
||||
return strings.HasPrefix(s[dot:], "._dns-sd._udp.")
|
||||
}
|
||||
|
||||
// rawNameToLower converts a raw DNS name to a string, lowercasing it.
|
||||
func rawNameToLower(name []byte) string {
|
||||
var sb strings.Builder
|
||||
sb.Grow(len(name))
|
||||
|
||||
for _, b := range name {
|
||||
if 'A' <= b && b <= 'Z' {
|
||||
b = b - 'A' + 'a'
|
||||
}
|
||||
sb.WriteByte(b)
|
||||
}
|
||||
|
||||
return sb.String()
|
||||
}
|
||||
|
||||
// ptrNameToIPv4 transforms a PTR name representing an IPv4 address to said address.
|
||||
// Such names are IPv4 labels in reverse order followed by .in-addr.arpa.
|
||||
// For example,
|
||||
// 4.3.2.1.in-addr.arpa
|
||||
// is transformed to
|
||||
// 1.2.3.4
|
||||
func rdnsNameToIPv4(name string) (ip netaddr.IP, ok bool) {
|
||||
name = strings.TrimSuffix(name, rdnsv4Suffix)
|
||||
ip, err := netaddr.ParseIP(string(name))
|
||||
if err != nil {
|
||||
return netaddr.IP{}, false
|
||||
}
|
||||
if !ip.Is4() {
|
||||
return netaddr.IP{}, false
|
||||
}
|
||||
b := ip.As4()
|
||||
return netaddr.IPv4(b[3], b[2], b[1], b[0]), true
|
||||
}
|
||||
|
||||
// ptrNameToIPv6 transforms a PTR name representing an IPv6 address to said address.
|
||||
// Such names are dot-separated nibbles in reverse order followed by .ip6.arpa.
|
||||
// For example,
|
||||
// b.a.9.8.7.6.5.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.8.b.d.0.1.0.0.2.ip6.arpa.
|
||||
// is transformed to
|
||||
// 2001:db8::567:89ab
|
||||
func rdnsNameToIPv6(name string) (ip netaddr.IP, ok bool) {
|
||||
var b [32]byte
|
||||
var ipb [16]byte
|
||||
|
||||
name = strings.TrimSuffix(name, rdnsv6Suffix)
|
||||
// 32 nibbles and 31 dots between them.
|
||||
if len(name) != 63 {
|
||||
return netaddr.IP{}, false
|
||||
}
|
||||
|
||||
// Dots and hex digits alternate.
|
||||
prevDot := true
|
||||
// i ranges over name backward; j ranges over b forward.
|
||||
for i, j := len(name)-1, 0; i >= 0; i-- {
|
||||
thisDot := (name[i] == '.')
|
||||
if prevDot == thisDot {
|
||||
return netaddr.IP{}, false
|
||||
}
|
||||
prevDot = thisDot
|
||||
|
||||
if !thisDot {
|
||||
// This is safe assuming alternation.
|
||||
// We do not check that non-dots are hex digits: hex.Decode below will do that.
|
||||
b[j] = name[i]
|
||||
j++
|
||||
}
|
||||
}
|
||||
|
||||
_, err := hex.Decode(ipb[:], b[:])
|
||||
if err != nil {
|
||||
return netaddr.IP{}, false
|
||||
}
|
||||
|
||||
return netaddr.IPFrom16(ipb), true
|
||||
}
|
||||
|
||||
// respondReverse returns a DNS response to a PTR query.
|
||||
// It is assumed that resp.Question is populated by respond before this is called.
|
||||
func (r *Resolver) respondReverse(query []byte, name string, resp *response) ([]byte, error) {
|
||||
if hasRDNSBonjourPrefix(name) {
|
||||
return nil, errNotOurName
|
||||
}
|
||||
|
||||
var ip netaddr.IP
|
||||
var ok bool
|
||||
switch {
|
||||
case strings.HasSuffix(name, rdnsv4Suffix):
|
||||
ip, ok = rdnsNameToIPv4(name)
|
||||
case strings.HasSuffix(name, rdnsv6Suffix):
|
||||
ip, ok = rdnsNameToIPv6(name)
|
||||
default:
|
||||
return nil, errNotOurName
|
||||
}
|
||||
|
||||
// It is more likely that we failed in parsing the name than that it is actually malformed.
|
||||
// To avoid frustrating users, just log and delegate.
|
||||
if !ok {
|
||||
r.logf("parsing rdns: malformed name: %s", name)
|
||||
return nil, errNotOurName
|
||||
}
|
||||
|
||||
var err error
|
||||
resp.Name, resp.Header.RCode, err = r.ResolveReverse(ip)
|
||||
if err != nil {
|
||||
r.logf("resolving rdns: %v", ip, err)
|
||||
}
|
||||
if resp.Header.RCode == dns.RCodeNameError {
|
||||
return nil, errNotOurName
|
||||
}
|
||||
|
||||
return marshalResponse(resp)
|
||||
}
|
||||
|
||||
// respond returns a DNS response to query if it can be resolved locally.
|
||||
// Otherwise, it returns errNotOurName.
|
||||
func (r *Resolver) respond(query []byte) ([]byte, error) {
|
||||
resp := new(response)
|
||||
|
||||
// ParseQuery is sufficiently fast to run on every DNS packet.
|
||||
// This is considerably simpler than extracting the name by hand
|
||||
// to shave off microseconds in case of delegation.
|
||||
err := parseQuery(query, resp)
|
||||
// We will not return this error: it is the sender's fault.
|
||||
if err != nil {
|
||||
if errors.Is(err, dns.ErrSectionDone) {
|
||||
r.logf("parseQuery(%02x): no DNS questions", query)
|
||||
} else {
|
||||
r.logf("parseQuery(%02x): %v", query, err)
|
||||
}
|
||||
resp.Header.RCode = dns.RCodeFormatError
|
||||
return marshalResponse(resp)
|
||||
}
|
||||
rawName := resp.Question.Name.Data[:resp.Question.Name.Length]
|
||||
name := rawNameToLower(rawName)
|
||||
|
||||
// Always try to handle reverse lookups; delegate inside when not found.
|
||||
// This way, queries for existent nodes do not leak,
|
||||
// but we behave gracefully if non-Tailscale nodes exist in CGNATRange.
|
||||
if resp.Question.Type == dns.TypePTR {
|
||||
return r.respondReverse(query, name, resp)
|
||||
}
|
||||
|
||||
resp.IP, resp.Header.RCode, err = r.Resolve(name, resp.Question.Type)
|
||||
// This return code is special: it requests forwarding.
|
||||
if resp.Header.RCode == dns.RCodeRefused {
|
||||
return nil, errNotOurName
|
||||
}
|
||||
|
||||
// We will not return this error: it is the sender's fault.
|
||||
if err != nil {
|
||||
r.logf("resolving: %v", err)
|
||||
}
|
||||
|
||||
return marshalResponse(resp)
|
||||
}
|
95
net/dns/tsdns_server_test.go
Normal file
95
net/dns/tsdns_server_test.go
Normal file
@@ -0,0 +1,95 @@
|
||||
// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package dns
|
||||
|
||||
import (
|
||||
"log"
|
||||
"testing"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
"inet.af/netaddr"
|
||||
)
|
||||
|
||||
// This file exists to isolate the test infrastructure
|
||||
// that depends on github.com/miekg/dns
|
||||
// from the rest, which only depends on dnsmessage.
|
||||
|
||||
var dnsHandleFunc = dns.HandleFunc
|
||||
|
||||
// resolveToIP returns a handler function which responds
|
||||
// to queries of type A it receives with an A record containing ipv4,
|
||||
// to queries of type AAAA with an AAAA record containing ipv6,
|
||||
// to queries of type NS with an NS record containg name.
|
||||
func resolveToIP(ipv4, ipv6 netaddr.IP, ns string) dns.HandlerFunc {
|
||||
return func(w dns.ResponseWriter, req *dns.Msg) {
|
||||
m := new(dns.Msg)
|
||||
m.SetReply(req)
|
||||
|
||||
if len(req.Question) != 1 {
|
||||
panic("not a single-question request")
|
||||
}
|
||||
question := req.Question[0]
|
||||
|
||||
var ans dns.RR
|
||||
switch question.Qtype {
|
||||
case dns.TypeA:
|
||||
ans = &dns.A{
|
||||
Hdr: dns.RR_Header{
|
||||
Name: question.Name,
|
||||
Rrtype: dns.TypeA,
|
||||
Class: dns.ClassINET,
|
||||
},
|
||||
A: ipv4.IPAddr().IP,
|
||||
}
|
||||
case dns.TypeAAAA:
|
||||
ans = &dns.AAAA{
|
||||
Hdr: dns.RR_Header{
|
||||
Name: question.Name,
|
||||
Rrtype: dns.TypeAAAA,
|
||||
Class: dns.ClassINET,
|
||||
},
|
||||
AAAA: ipv6.IPAddr().IP,
|
||||
}
|
||||
case dns.TypeNS:
|
||||
ans = &dns.NS{
|
||||
Hdr: dns.RR_Header{
|
||||
Name: question.Name,
|
||||
Rrtype: dns.TypeNS,
|
||||
Class: dns.ClassINET,
|
||||
},
|
||||
Ns: ns,
|
||||
}
|
||||
}
|
||||
|
||||
m.Answer = append(m.Answer, ans)
|
||||
w.WriteMsg(m)
|
||||
}
|
||||
}
|
||||
|
||||
func resolveToNXDOMAIN(w dns.ResponseWriter, req *dns.Msg) {
|
||||
m := new(dns.Msg)
|
||||
m.SetRcode(req, dns.RcodeNameError)
|
||||
w.WriteMsg(m)
|
||||
}
|
||||
|
||||
func serveDNS(tb testing.TB, addr string) (*dns.Server, chan error) {
|
||||
server := &dns.Server{Addr: addr, Net: "udp"}
|
||||
|
||||
waitch := make(chan struct{})
|
||||
server.NotifyStartedFunc = func() { close(waitch) }
|
||||
|
||||
errch := make(chan error, 1)
|
||||
go func() {
|
||||
err := server.ListenAndServe()
|
||||
if err != nil {
|
||||
log.Printf("ListenAndServe(%q): %v", addr, err)
|
||||
}
|
||||
errch <- err
|
||||
close(errch)
|
||||
}()
|
||||
|
||||
<-waitch
|
||||
return server, errch
|
||||
}
|
816
net/dns/tsdns_test.go
Normal file
816
net/dns/tsdns_test.go
Normal file
@@ -0,0 +1,816 @@
|
||||
// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package dns
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"net"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
dns "golang.org/x/net/dns/dnsmessage"
|
||||
"inet.af/netaddr"
|
||||
"tailscale.com/tstest"
|
||||
)
|
||||
|
||||
var testipv4 = netaddr.IPv4(1, 2, 3, 4)
|
||||
var testipv6 = netaddr.IPv6Raw([16]byte{
|
||||
0x00, 0x01, 0x02, 0x03,
|
||||
0x04, 0x05, 0x06, 0x07,
|
||||
0x08, 0x09, 0x0a, 0x0b,
|
||||
0x0c, 0x0d, 0x0e, 0x0f,
|
||||
})
|
||||
|
||||
var dnsMap = NewMap(
|
||||
map[string]netaddr.IP{
|
||||
"test1.ipn.dev.": testipv4,
|
||||
"test2.ipn.dev.": testipv6,
|
||||
},
|
||||
[]string{"ipn.dev."},
|
||||
)
|
||||
|
||||
func dnspacket(domain string, tp dns.Type) []byte {
|
||||
var dnsHeader dns.Header
|
||||
question := dns.Question{
|
||||
Name: dns.MustNewName(domain),
|
||||
Type: tp,
|
||||
Class: dns.ClassINET,
|
||||
}
|
||||
|
||||
builder := dns.NewBuilder(nil, dnsHeader)
|
||||
builder.StartQuestions()
|
||||
builder.Question(question)
|
||||
payload, _ := builder.Finish()
|
||||
|
||||
return payload
|
||||
}
|
||||
|
||||
type dnsResponse struct {
|
||||
ip netaddr.IP
|
||||
name string
|
||||
rcode dns.RCode
|
||||
}
|
||||
|
||||
func unpackResponse(payload []byte) (dnsResponse, error) {
|
||||
var response dnsResponse
|
||||
var parser dns.Parser
|
||||
|
||||
h, err := parser.Start(payload)
|
||||
if err != nil {
|
||||
return response, err
|
||||
}
|
||||
|
||||
if !h.Response {
|
||||
return response, errors.New("not a response")
|
||||
}
|
||||
|
||||
response.rcode = h.RCode
|
||||
if response.rcode != dns.RCodeSuccess {
|
||||
return response, nil
|
||||
}
|
||||
|
||||
err = parser.SkipAllQuestions()
|
||||
if err != nil {
|
||||
return response, err
|
||||
}
|
||||
|
||||
ah, err := parser.AnswerHeader()
|
||||
if err != nil {
|
||||
return response, err
|
||||
}
|
||||
|
||||
switch ah.Type {
|
||||
case dns.TypeA:
|
||||
res, err := parser.AResource()
|
||||
if err != nil {
|
||||
return response, err
|
||||
}
|
||||
response.ip = netaddr.IPv4(res.A[0], res.A[1], res.A[2], res.A[3])
|
||||
case dns.TypeAAAA:
|
||||
res, err := parser.AAAAResource()
|
||||
if err != nil {
|
||||
return response, err
|
||||
}
|
||||
response.ip = netaddr.IPv6Raw(res.AAAA)
|
||||
case dns.TypeNS:
|
||||
res, err := parser.NSResource()
|
||||
if err != nil {
|
||||
return response, err
|
||||
}
|
||||
response.name = res.NS.String()
|
||||
default:
|
||||
return response, errors.New("type not in {A, AAAA, NS}")
|
||||
}
|
||||
|
||||
return response, nil
|
||||
}
|
||||
|
||||
func syncRespond(r *Resolver, query []byte) ([]byte, error) {
|
||||
request := Packet{Payload: query}
|
||||
r.EnqueueRequest(request)
|
||||
resp, err := r.NextResponse()
|
||||
return resp.Payload, err
|
||||
}
|
||||
|
||||
func mustIP(str string) netaddr.IP {
|
||||
ip, err := netaddr.ParseIP(str)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return ip
|
||||
}
|
||||
|
||||
func TestRDNSNameToIPv4(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
wantIP netaddr.IP
|
||||
wantOK bool
|
||||
}{
|
||||
{"valid", "4.123.24.1.in-addr.arpa.", netaddr.IPv4(1, 24, 123, 4), true},
|
||||
{"double_dot", "1..2.3.in-addr.arpa.", netaddr.IP{}, false},
|
||||
{"overflow", "1.256.3.4.in-addr.arpa.", netaddr.IP{}, false},
|
||||
{"not_ip", "sub.do.ma.in.in-addr.arpa.", netaddr.IP{}, false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
ip, ok := rdnsNameToIPv4(tt.input)
|
||||
if ok != tt.wantOK {
|
||||
t.Errorf("ok = %v; want %v", ok, tt.wantOK)
|
||||
} else if ok && ip != tt.wantIP {
|
||||
t.Errorf("ip = %v; want %v", ip, tt.wantIP)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRDNSNameToIPv6(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
wantIP netaddr.IP
|
||||
wantOK bool
|
||||
}{
|
||||
{
|
||||
"valid",
|
||||
"b.a.9.8.7.6.5.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.8.b.d.0.1.0.0.2.ip6.arpa.",
|
||||
mustIP("2001:db8::567:89ab"),
|
||||
true,
|
||||
},
|
||||
{
|
||||
"double_dot",
|
||||
"b..9.8.7.6.5.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.8.b.d.0.1.0.0.2.ip6.arpa.",
|
||||
netaddr.IP{},
|
||||
false,
|
||||
},
|
||||
{
|
||||
"double_hex",
|
||||
"b.a.98.0.7.6.5.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.8.b.d.0.1.0.0.2.ip6.arpa.",
|
||||
netaddr.IP{},
|
||||
false,
|
||||
},
|
||||
{
|
||||
"not_hex",
|
||||
"b.a.g.0.7.6.5.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.8.b.d.0.1.0.0.2.ip6.arpa.",
|
||||
netaddr.IP{},
|
||||
false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
ip, ok := rdnsNameToIPv6(tt.input)
|
||||
if ok != tt.wantOK {
|
||||
t.Errorf("ok = %v; want %v", ok, tt.wantOK)
|
||||
} else if ok && ip != tt.wantIP {
|
||||
t.Errorf("ip = %v; want %v", ip, tt.wantIP)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolve(t *testing.T) {
|
||||
r := NewResolver(ResolverConfig{Logf: t.Logf, Forward: false})
|
||||
r.SetMap(dnsMap)
|
||||
|
||||
if err := r.Start(); err != nil {
|
||||
t.Fatalf("start: %v", err)
|
||||
}
|
||||
defer r.Close()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
qname string
|
||||
qtype dns.Type
|
||||
ip netaddr.IP
|
||||
code dns.RCode
|
||||
}{
|
||||
{"ipv4", "test1.ipn.dev.", dns.TypeA, testipv4, dns.RCodeSuccess},
|
||||
{"ipv6", "test2.ipn.dev.", dns.TypeAAAA, testipv6, dns.RCodeSuccess},
|
||||
{"no-ipv6", "test1.ipn.dev.", dns.TypeAAAA, netaddr.IP{}, dns.RCodeSuccess},
|
||||
{"nxdomain", "test3.ipn.dev.", dns.TypeA, netaddr.IP{}, dns.RCodeNameError},
|
||||
{"foreign domain", "google.com.", dns.TypeA, netaddr.IP{}, dns.RCodeRefused},
|
||||
{"all", "test1.ipn.dev.", dns.TypeA, testipv4, dns.RCodeSuccess},
|
||||
{"mx-ipv4", "test1.ipn.dev.", dns.TypeMX, netaddr.IP{}, dns.RCodeSuccess},
|
||||
{"mx-ipv6", "test2.ipn.dev.", dns.TypeMX, netaddr.IP{}, dns.RCodeSuccess},
|
||||
{"mx-nxdomain", "test3.ipn.dev.", dns.TypeMX, netaddr.IP{}, dns.RCodeNameError},
|
||||
{"ns-nxdomain", "test3.ipn.dev.", dns.TypeNS, netaddr.IP{}, dns.RCodeNameError},
|
||||
{"onion-domain", "footest.onion.", dns.TypeA, netaddr.IP{}, dns.RCodeNameError},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
ip, code, err := r.Resolve(tt.qname, tt.qtype)
|
||||
if err != nil {
|
||||
t.Errorf("err = %v; want nil", err)
|
||||
}
|
||||
if code != tt.code {
|
||||
t.Errorf("code = %v; want %v", code, tt.code)
|
||||
}
|
||||
// Only check ip for non-err
|
||||
if ip != tt.ip {
|
||||
t.Errorf("ip = %v; want %v", ip, tt.ip)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveReverse(t *testing.T) {
|
||||
r := NewResolver(ResolverConfig{Logf: t.Logf, Forward: false})
|
||||
r.SetMap(dnsMap)
|
||||
|
||||
if err := r.Start(); err != nil {
|
||||
t.Fatalf("start: %v", err)
|
||||
}
|
||||
defer r.Close()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
ip netaddr.IP
|
||||
want string
|
||||
code dns.RCode
|
||||
}{
|
||||
{"ipv4", testipv4, "test1.ipn.dev.", dns.RCodeSuccess},
|
||||
{"ipv6", testipv6, "test2.ipn.dev.", dns.RCodeSuccess},
|
||||
{"nxdomain", netaddr.IPv4(4, 3, 2, 1), "", dns.RCodeNameError},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
name, code, err := r.ResolveReverse(tt.ip)
|
||||
if err != nil {
|
||||
t.Errorf("err = %v; want nil", err)
|
||||
}
|
||||
if code != tt.code {
|
||||
t.Errorf("code = %v; want %v", code, tt.code)
|
||||
}
|
||||
if name != tt.want {
|
||||
t.Errorf("ip = %v; want %v", name, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func ipv6Works() bool {
|
||||
c, err := net.Listen("tcp", "[::1]:0")
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
c.Close()
|
||||
return true
|
||||
}
|
||||
|
||||
func TestDelegate(t *testing.T) {
|
||||
tstest.ResourceCheck(t)
|
||||
|
||||
if !ipv6Works() {
|
||||
t.Skip("skipping test that requires localhost IPv6")
|
||||
}
|
||||
|
||||
dnsHandleFunc("test.site.", resolveToIP(testipv4, testipv6, "dns.test.site."))
|
||||
dnsHandleFunc("nxdomain.site.", resolveToNXDOMAIN)
|
||||
|
||||
v4server, v4errch := serveDNS(t, "127.0.0.1:0")
|
||||
v6server, v6errch := serveDNS(t, "[::1]:0")
|
||||
|
||||
defer func() {
|
||||
if err := <-v4errch; err != nil {
|
||||
t.Errorf("v4 server error: %v", err)
|
||||
}
|
||||
if err := <-v6errch; err != nil {
|
||||
t.Errorf("v6 server error: %v", err)
|
||||
}
|
||||
}()
|
||||
if v4server != nil {
|
||||
defer v4server.Shutdown()
|
||||
}
|
||||
if v6server != nil {
|
||||
defer v6server.Shutdown()
|
||||
}
|
||||
|
||||
if v4server == nil || v6server == nil {
|
||||
// There is an error in at least one of the channels
|
||||
// and we cannot proceed; return to see it.
|
||||
return
|
||||
}
|
||||
|
||||
r := NewResolver(ResolverConfig{Logf: t.Logf, Forward: true})
|
||||
r.SetMap(dnsMap)
|
||||
r.SetUpstreams([]net.Addr{
|
||||
v4server.PacketConn.LocalAddr(),
|
||||
v6server.PacketConn.LocalAddr(),
|
||||
})
|
||||
|
||||
if err := r.Start(); err != nil {
|
||||
t.Fatalf("start: %v", err)
|
||||
}
|
||||
defer r.Close()
|
||||
|
||||
tests := []struct {
|
||||
title string
|
||||
query []byte
|
||||
response dnsResponse
|
||||
}{
|
||||
{
|
||||
"ipv4",
|
||||
dnspacket("test.site.", dns.TypeA),
|
||||
dnsResponse{ip: testipv4, rcode: dns.RCodeSuccess},
|
||||
},
|
||||
{
|
||||
"ipv6",
|
||||
dnspacket("test.site.", dns.TypeAAAA),
|
||||
dnsResponse{ip: testipv6, rcode: dns.RCodeSuccess},
|
||||
},
|
||||
{
|
||||
"ns",
|
||||
dnspacket("test.site.", dns.TypeNS),
|
||||
dnsResponse{name: "dns.test.site.", rcode: dns.RCodeSuccess},
|
||||
},
|
||||
{
|
||||
"nxdomain",
|
||||
dnspacket("nxdomain.site.", dns.TypeA),
|
||||
dnsResponse{rcode: dns.RCodeNameError},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.title, func(t *testing.T) {
|
||||
payload, err := syncRespond(r, tt.query)
|
||||
if err != nil {
|
||||
t.Errorf("err = %v; want nil", err)
|
||||
return
|
||||
}
|
||||
response, err := unpackResponse(payload)
|
||||
if err != nil {
|
||||
t.Errorf("extract: err = %v; want nil (in %x)", err, payload)
|
||||
return
|
||||
}
|
||||
if response.rcode != tt.response.rcode {
|
||||
t.Errorf("rcode = %v; want %v", response.rcode, tt.response.rcode)
|
||||
}
|
||||
if response.ip != tt.response.ip {
|
||||
t.Errorf("ip = %v; want %v", response.ip, tt.response.ip)
|
||||
}
|
||||
if response.name != tt.response.name {
|
||||
t.Errorf("name = %v; want %v", response.name, tt.response.name)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDelegateCollision(t *testing.T) {
|
||||
dnsHandleFunc("test.site.", resolveToIP(testipv4, testipv6, "dns.test.site."))
|
||||
|
||||
server, errch := serveDNS(t, "127.0.0.1:0")
|
||||
defer func() {
|
||||
if err := <-errch; err != nil {
|
||||
t.Errorf("server error: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
if server == nil {
|
||||
return
|
||||
}
|
||||
defer server.Shutdown()
|
||||
|
||||
r := NewResolver(ResolverConfig{Logf: t.Logf, Forward: true})
|
||||
r.SetMap(dnsMap)
|
||||
r.SetUpstreams([]net.Addr{server.PacketConn.LocalAddr()})
|
||||
|
||||
if err := r.Start(); err != nil {
|
||||
t.Fatalf("start: %v", err)
|
||||
}
|
||||
defer r.Close()
|
||||
|
||||
packets := []struct {
|
||||
qname string
|
||||
qtype dns.Type
|
||||
addr netaddr.IPPort
|
||||
}{
|
||||
{"test.site.", dns.TypeA, netaddr.IPPort{IP: netaddr.IPv4(1, 1, 1, 1), Port: 1001}},
|
||||
{"test.site.", dns.TypeAAAA, netaddr.IPPort{IP: netaddr.IPv4(1, 1, 1, 1), Port: 1002}},
|
||||
}
|
||||
|
||||
// packets will have the same dns txid.
|
||||
for _, p := range packets {
|
||||
payload := dnspacket(p.qname, p.qtype)
|
||||
req := Packet{Payload: payload, Addr: p.addr}
|
||||
err := r.EnqueueRequest(req)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
}
|
||||
|
||||
// Despite the txid collision, the answer(s) should still match the query.
|
||||
resp, err := r.NextResponse()
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
var p dns.Parser
|
||||
_, err = p.Start(resp.Payload)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
err = p.SkipAllQuestions()
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
ans, err := p.AllAnswers()
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
var wantType dns.Type
|
||||
switch ans[0].Body.(type) {
|
||||
case *dns.AResource:
|
||||
wantType = dns.TypeA
|
||||
case *dns.AAAAResource:
|
||||
wantType = dns.TypeAAAA
|
||||
default:
|
||||
t.Errorf("unexpected answer type: %T", ans[0].Body)
|
||||
}
|
||||
|
||||
for _, p := range packets {
|
||||
if p.qtype == wantType && p.addr != resp.Addr {
|
||||
t.Errorf("addr = %v; want %v", resp.Addr, p.addr)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestConcurrentSetMap(t *testing.T) {
|
||||
r := NewResolver(ResolverConfig{Logf: t.Logf, Forward: false})
|
||||
|
||||
if err := r.Start(); err != nil {
|
||||
t.Fatalf("start: %v", err)
|
||||
}
|
||||
defer r.Close()
|
||||
|
||||
// This is purely to ensure that Resolve does not race with SetMap.
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(2)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
r.SetMap(dnsMap)
|
||||
}()
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
r.Resolve("test1.ipn.dev", dns.TypeA)
|
||||
}()
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
func TestConcurrentSetUpstreams(t *testing.T) {
|
||||
dnsHandleFunc("test.site.", resolveToIP(testipv4, testipv6, "dns.test.site."))
|
||||
|
||||
server, errch := serveDNS(t, "127.0.0.1:0")
|
||||
defer func() {
|
||||
if err := <-errch; err != nil {
|
||||
t.Errorf("server error: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
if server == nil {
|
||||
return
|
||||
}
|
||||
defer server.Shutdown()
|
||||
|
||||
r := NewResolver(ResolverConfig{Logf: t.Logf, Forward: true})
|
||||
r.SetMap(dnsMap)
|
||||
|
||||
if err := r.Start(); err != nil {
|
||||
t.Fatalf("start: %v", err)
|
||||
}
|
||||
defer r.Close()
|
||||
|
||||
packet := dnspacket("test.site.", dns.TypeA)
|
||||
// This is purely to ensure that delegation does not race with SetUpstreams.
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(2)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
r.SetUpstreams([]net.Addr{server.PacketConn.LocalAddr()})
|
||||
}()
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
syncRespond(r, packet)
|
||||
}()
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
var allResponse = []byte{
|
||||
0x00, 0x00, // transaction id: 0
|
||||
0x84, 0x00, // flags: response, authoritative, no error
|
||||
0x00, 0x01, // one question
|
||||
0x00, 0x01, // one answer
|
||||
0x00, 0x00, 0x00, 0x00, // no authority or additional RRs
|
||||
// Question:
|
||||
0x05, 0x74, 0x65, 0x73, 0x74, 0x31, 0x03, 0x69, 0x70, 0x6e, 0x03, 0x64, 0x65, 0x76, 0x00, // name
|
||||
0x00, 0xff, 0x00, 0x01, // type ALL, class IN
|
||||
// Answer:
|
||||
0x05, 0x74, 0x65, 0x73, 0x74, 0x31, 0x03, 0x69, 0x70, 0x6e, 0x03, 0x64, 0x65, 0x76, 0x00, // name
|
||||
0x00, 0x01, 0x00, 0x01, // type A, class IN
|
||||
0x00, 0x00, 0x02, 0x58, // TTL: 600
|
||||
0x00, 0x04, // length: 4 bytes
|
||||
0x01, 0x02, 0x03, 0x04, // A: 1.2.3.4
|
||||
}
|
||||
|
||||
var ipv4Response = []byte{
|
||||
0x00, 0x00, // transaction id: 0
|
||||
0x84, 0x00, // flags: response, authoritative, no error
|
||||
0x00, 0x01, // one question
|
||||
0x00, 0x01, // one answer
|
||||
0x00, 0x00, 0x00, 0x00, // no authority or additional RRs
|
||||
// Question:
|
||||
0x05, 0x74, 0x65, 0x73, 0x74, 0x31, 0x03, 0x69, 0x70, 0x6e, 0x03, 0x64, 0x65, 0x76, 0x00, // name
|
||||
0x00, 0x01, 0x00, 0x01, // type A, class IN
|
||||
// Answer:
|
||||
0x05, 0x74, 0x65, 0x73, 0x74, 0x31, 0x03, 0x69, 0x70, 0x6e, 0x03, 0x64, 0x65, 0x76, 0x00, // name
|
||||
0x00, 0x01, 0x00, 0x01, // type A, class IN
|
||||
0x00, 0x00, 0x02, 0x58, // TTL: 600
|
||||
0x00, 0x04, // length: 4 bytes
|
||||
0x01, 0x02, 0x03, 0x04, // A: 1.2.3.4
|
||||
}
|
||||
|
||||
var ipv6Response = []byte{
|
||||
0x00, 0x00, // transaction id: 0
|
||||
0x84, 0x00, // flags: response, authoritative, no error
|
||||
0x00, 0x01, // one question
|
||||
0x00, 0x01, // one answer
|
||||
0x00, 0x00, 0x00, 0x00, // no authority or additional RRs
|
||||
// Question:
|
||||
0x05, 0x74, 0x65, 0x73, 0x74, 0x32, 0x03, 0x69, 0x70, 0x6e, 0x03, 0x64, 0x65, 0x76, 0x00, // name
|
||||
0x00, 0x1c, 0x00, 0x01, // type AAAA, class IN
|
||||
// Answer:
|
||||
0x05, 0x74, 0x65, 0x73, 0x74, 0x32, 0x03, 0x69, 0x70, 0x6e, 0x03, 0x64, 0x65, 0x76, 0x00, // name
|
||||
0x00, 0x1c, 0x00, 0x01, // type AAAA, class IN
|
||||
0x00, 0x00, 0x02, 0x58, // TTL: 600
|
||||
0x00, 0x10, // length: 16 bytes
|
||||
// AAAA: 0001:0203:0405:0607:0809:0A0B:0C0D:0E0F
|
||||
0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0a, 0xb, 0xc, 0xd, 0xe, 0xf,
|
||||
}
|
||||
|
||||
var ipv4UppercaseResponse = []byte{
|
||||
0x00, 0x00, // transaction id: 0
|
||||
0x84, 0x00, // flags: response, authoritative, no error
|
||||
0x00, 0x01, // one question
|
||||
0x00, 0x01, // one answer
|
||||
0x00, 0x00, 0x00, 0x00, // no authority or additional RRs
|
||||
// Question:
|
||||
0x05, 0x54, 0x45, 0x53, 0x54, 0x31, 0x03, 0x49, 0x50, 0x4e, 0x03, 0x44, 0x45, 0x56, 0x00, // name
|
||||
0x00, 0x01, 0x00, 0x01, // type A, class IN
|
||||
// Answer:
|
||||
0x05, 0x54, 0x45, 0x53, 0x54, 0x31, 0x03, 0x49, 0x50, 0x4e, 0x03, 0x44, 0x45, 0x56, 0x00, // name
|
||||
0x00, 0x01, 0x00, 0x01, // type A, class IN
|
||||
0x00, 0x00, 0x02, 0x58, // TTL: 600
|
||||
0x00, 0x04, // length: 4 bytes
|
||||
0x01, 0x02, 0x03, 0x04, // A: 1.2.3.4
|
||||
}
|
||||
|
||||
var ptrResponse = []byte{
|
||||
0x00, 0x00, // transaction id: 0
|
||||
0x84, 0x00, // flags: response, authoritative, no error
|
||||
0x00, 0x01, // one question
|
||||
0x00, 0x01, // one answer
|
||||
0x00, 0x00, 0x00, 0x00, // no authority or additional RRs
|
||||
// Question: 4.3.2.1.in-addr.arpa
|
||||
0x01, 0x34, 0x01, 0x33, 0x01, 0x32, 0x01, 0x31, 0x07,
|
||||
0x69, 0x6e, 0x2d, 0x61, 0x64, 0x64, 0x72, 0x04, 0x61, 0x72, 0x70, 0x61, 0x00,
|
||||
0x00, 0x0c, 0x00, 0x01, // type PTR, class IN
|
||||
// Answer: 4.3.2.1.in-addr.arpa
|
||||
0x01, 0x34, 0x01, 0x33, 0x01, 0x32, 0x01, 0x31, 0x07,
|
||||
0x69, 0x6e, 0x2d, 0x61, 0x64, 0x64, 0x72, 0x04, 0x61, 0x72, 0x70, 0x61, 0x00,
|
||||
0x00, 0x0c, 0x00, 0x01, // type PTR, class IN
|
||||
0x00, 0x00, 0x02, 0x58, // TTL: 600
|
||||
0x00, 0x0f, // length: 15 bytes
|
||||
// PTR: test1.ipn.dev
|
||||
0x05, 0x74, 0x65, 0x73, 0x74, 0x31, 0x03, 0x69, 0x70, 0x6e, 0x03, 0x64, 0x65, 0x76, 0x00,
|
||||
}
|
||||
|
||||
var ptrResponse6 = []byte{
|
||||
0x00, 0x00, // transaction id: 0
|
||||
0x84, 0x00, // flags: response, authoritative, no error
|
||||
0x00, 0x01, // one question
|
||||
0x00, 0x01, // one answer
|
||||
0x00, 0x00, 0x00, 0x00, // no authority or additional RRs
|
||||
// Question: f.0.e.0.d.0.c.0.b.0.a.0.9.0.8.0.7.0.6.0.5.0.4.0.3.0.2.0.1.0.0.0.ip6.arpa
|
||||
0x01, 0x66, 0x01, 0x30, 0x01, 0x65, 0x01, 0x30,
|
||||
0x01, 0x64, 0x01, 0x30, 0x01, 0x63, 0x01, 0x30,
|
||||
0x01, 0x62, 0x01, 0x30, 0x01, 0x61, 0x01, 0x30,
|
||||
0x01, 0x39, 0x01, 0x30, 0x01, 0x38, 0x01, 0x30,
|
||||
0x01, 0x37, 0x01, 0x30, 0x01, 0x36, 0x01, 0x30,
|
||||
0x01, 0x35, 0x01, 0x30, 0x01, 0x34, 0x01, 0x30,
|
||||
0x01, 0x33, 0x01, 0x30, 0x01, 0x32, 0x01, 0x30,
|
||||
0x01, 0x31, 0x01, 0x30, 0x01, 0x30, 0x01, 0x30,
|
||||
0x03, 0x69, 0x70, 0x36,
|
||||
0x04, 0x61, 0x72, 0x70, 0x61, 0x00,
|
||||
0x00, 0x0c, 0x00, 0x01, // type PTR, class IN6
|
||||
// Answer: f.0.e.0.d.0.c.0.b.0.a.0.9.0.8.0.7.0.6.0.5.0.4.0.3.0.2.0.1.0.0.0.ip6.arpa
|
||||
0x01, 0x66, 0x01, 0x30, 0x01, 0x65, 0x01, 0x30,
|
||||
0x01, 0x64, 0x01, 0x30, 0x01, 0x63, 0x01, 0x30,
|
||||
0x01, 0x62, 0x01, 0x30, 0x01, 0x61, 0x01, 0x30,
|
||||
0x01, 0x39, 0x01, 0x30, 0x01, 0x38, 0x01, 0x30,
|
||||
0x01, 0x37, 0x01, 0x30, 0x01, 0x36, 0x01, 0x30,
|
||||
0x01, 0x35, 0x01, 0x30, 0x01, 0x34, 0x01, 0x30,
|
||||
0x01, 0x33, 0x01, 0x30, 0x01, 0x32, 0x01, 0x30,
|
||||
0x01, 0x31, 0x01, 0x30, 0x01, 0x30, 0x01, 0x30,
|
||||
0x03, 0x69, 0x70, 0x36,
|
||||
0x04, 0x61, 0x72, 0x70, 0x61, 0x00,
|
||||
0x00, 0x0c, 0x00, 0x01, // type PTR, class IN
|
||||
0x00, 0x00, 0x02, 0x58, // TTL: 600
|
||||
0x00, 0x0f, // length: 15 bytes
|
||||
// PTR: test2.ipn.dev
|
||||
0x05, 0x74, 0x65, 0x73, 0x74, 0x32, 0x03, 0x69, 0x70, 0x6e, 0x03, 0x64, 0x65, 0x76, 0x00,
|
||||
}
|
||||
|
||||
var nxdomainResponse = []byte{
|
||||
0x00, 0x00, // transaction id: 0
|
||||
0x84, 0x03, // flags: response, authoritative, error: nxdomain
|
||||
0x00, 0x01, // one question
|
||||
0x00, 0x00, // no answers
|
||||
0x00, 0x00, 0x00, 0x00, // no authority or additional RRs
|
||||
// Question:
|
||||
0x05, 0x74, 0x65, 0x73, 0x74, 0x33, 0x03, 0x69, 0x70, 0x6e, 0x03, 0x64, 0x65, 0x76, 0x00, // name
|
||||
0x00, 0x01, 0x00, 0x01, // type A, class IN
|
||||
}
|
||||
|
||||
var emptyResponse = []byte{
|
||||
0x00, 0x00, // transaction id: 0
|
||||
0x84, 0x00, // flags: response, authoritative, no error
|
||||
0x00, 0x01, // one question
|
||||
0x00, 0x00, // no answers
|
||||
0x00, 0x00, 0x00, 0x00, // no authority or additional RRs
|
||||
// Question:
|
||||
0x05, 0x74, 0x65, 0x73, 0x74, 0x31, 0x03, 0x69, 0x70, 0x6e, 0x03, 0x64, 0x65, 0x76, 0x00, // name
|
||||
0x00, 0x1c, 0x00, 0x01, // type AAAA, class IN
|
||||
}
|
||||
|
||||
func TestFull(t *testing.T) {
|
||||
r := NewResolver(ResolverConfig{Logf: t.Logf, Forward: false})
|
||||
r.SetMap(dnsMap)
|
||||
|
||||
if err := r.Start(); err != nil {
|
||||
t.Fatalf("start: %v", err)
|
||||
}
|
||||
defer r.Close()
|
||||
|
||||
// One full packet and one error packet
|
||||
tests := []struct {
|
||||
name string
|
||||
request []byte
|
||||
response []byte
|
||||
}{
|
||||
{"all", dnspacket("test1.ipn.dev.", dns.TypeALL), allResponse},
|
||||
{"ipv4", dnspacket("test1.ipn.dev.", dns.TypeA), ipv4Response},
|
||||
{"ipv6", dnspacket("test2.ipn.dev.", dns.TypeAAAA), ipv6Response},
|
||||
{"no-ipv6", dnspacket("test1.ipn.dev.", dns.TypeAAAA), emptyResponse},
|
||||
{"upper", dnspacket("TEST1.IPN.DEV.", dns.TypeA), ipv4UppercaseResponse},
|
||||
{"ptr", dnspacket("4.3.2.1.in-addr.arpa.", dns.TypePTR), ptrResponse},
|
||||
{"ptr", dnspacket("f.0.e.0.d.0.c.0.b.0.a.0.9.0.8.0.7.0.6.0.5.0.4.0.3.0.2.0.1.0.0.0.ip6.arpa.",
|
||||
dns.TypePTR), ptrResponse6},
|
||||
{"nxdomain", dnspacket("test3.ipn.dev.", dns.TypeA), nxdomainResponse},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
response, err := syncRespond(r, tt.request)
|
||||
if err != nil {
|
||||
t.Errorf("err = %v; want nil", err)
|
||||
}
|
||||
if !bytes.Equal(response, tt.response) {
|
||||
t.Errorf("response = %x; want %x", response, tt.response)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAllocs(t *testing.T) {
|
||||
r := NewResolver(ResolverConfig{Logf: t.Logf, Forward: false})
|
||||
r.SetMap(dnsMap)
|
||||
|
||||
if err := r.Start(); err != nil {
|
||||
t.Fatalf("start: %v", err)
|
||||
}
|
||||
defer r.Close()
|
||||
|
||||
// It is seemingly pointless to test allocs in the delegate path,
|
||||
// as dialer.Dial -> Read -> Write alone comprise 12 allocs.
|
||||
tests := []struct {
|
||||
name string
|
||||
query []byte
|
||||
want int
|
||||
}{
|
||||
// Name lowercasing and response slice created by dns.NewBuilder.
|
||||
{"forward", dnspacket("test1.ipn.dev.", dns.TypeA), 2},
|
||||
// 3 extra allocs in rdnsNameToIPv4 and one in marshalPTRRecord (dns.NewName).
|
||||
{"reverse", dnspacket("4.3.2.1.in-addr.arpa.", dns.TypePTR), 5},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
allocs := testing.AllocsPerRun(100, func() {
|
||||
syncRespond(r, tt.query)
|
||||
})
|
||||
if int(allocs) > tt.want {
|
||||
t.Errorf("%s: allocs = %v; want %v", tt.name, allocs, tt.want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestTrimRDNSBonjourPrefix(t *testing.T) {
|
||||
tests := []struct {
|
||||
in string
|
||||
want bool
|
||||
}{
|
||||
{"b._dns-sd._udp.0.10.20.172.in-addr.arpa.", true},
|
||||
{"db._dns-sd._udp.0.10.20.172.in-addr.arpa.", true},
|
||||
{"r._dns-sd._udp.0.10.20.172.in-addr.arpa.", true},
|
||||
{"dr._dns-sd._udp.0.10.20.172.in-addr.arpa.", true},
|
||||
{"lb._dns-sd._udp.0.10.20.172.in-addr.arpa.", true},
|
||||
{"qq._dns-sd._udp.0.10.20.172.in-addr.arpa.", false},
|
||||
{"0.10.20.172.in-addr.arpa.", false},
|
||||
{"i-have-no-dot", false},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
got := hasRDNSBonjourPrefix(test.in)
|
||||
if got != test.want {
|
||||
t.Errorf("trimRDNSBonjourPrefix(%q) = %v, want %v", test.in, got, test.want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkFull(b *testing.B) {
|
||||
dnsHandleFunc("test.site.", resolveToIP(testipv4, testipv6, "dns.test.site."))
|
||||
|
||||
server, errch := serveDNS(b, "127.0.0.1:0")
|
||||
defer func() {
|
||||
if err := <-errch; err != nil {
|
||||
b.Errorf("server error: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
if server == nil {
|
||||
return
|
||||
}
|
||||
defer server.Shutdown()
|
||||
|
||||
r := NewResolver(ResolverConfig{Logf: b.Logf, Forward: true})
|
||||
r.SetMap(dnsMap)
|
||||
r.SetUpstreams([]net.Addr{server.PacketConn.LocalAddr()})
|
||||
|
||||
if err := r.Start(); err != nil {
|
||||
b.Fatalf("start: %v", err)
|
||||
}
|
||||
defer r.Close()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
request []byte
|
||||
}{
|
||||
{"forward", dnspacket("test1.ipn.dev.", dns.TypeA)},
|
||||
{"reverse", dnspacket("4.3.2.1.in-addr.arpa.", dns.TypePTR)},
|
||||
{"delegated", dnspacket("test.site.", dns.TypeA)},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
b.Run(tt.name, func(b *testing.B) {
|
||||
for i := 0; i < b.N; i++ {
|
||||
syncRespond(r, tt.request)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestMarshalResponseFormatError(t *testing.T) {
|
||||
resp := new(response)
|
||||
resp.Header.RCode = dns.RCodeFormatError
|
||||
v, err := marshalResponse(resp)
|
||||
if err != nil {
|
||||
t.Errorf("marshal error: %v", err)
|
||||
}
|
||||
t.Logf("response: %q", v)
|
||||
}
|
Reference in New Issue
Block a user