mirror of
https://github.com/tailscale/tailscale.git
synced 2025-02-22 12:58:37 +00:00
tsdns: delegate requests asynchronously (#687)
Signed-Off-By: Dmytro Shynkevych <dmytro@tailscale.com>
This commit is contained in:
parent
a583e498b0
commit
1af70e2468
325
wgengine/tsdns/forwarder.go
Normal file
325
wgengine/tsdns/forwarder.go
Normal file
@ -0,0 +1,325 @@
|
|||||||
|
// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved.
|
||||||
|
// Use of this source code is governed by a BSD-style
|
||||||
|
// license that can be found in the LICENSE file.
|
||||||
|
|
||||||
|
package tsdns
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"encoding/binary"
|
||||||
|
"errors"
|
||||||
|
"hash/crc32"
|
||||||
|
"math/rand"
|
||||||
|
"net"
|
||||||
|
"os"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"inet.af/netaddr"
|
||||||
|
"tailscale.com/types/logger"
|
||||||
|
)
|
||||||
|
|
||||||
|
// headerBytes is the number of bytes in a DNS message header.
|
||||||
|
const headerBytes = 12
|
||||||
|
|
||||||
|
// forwardQueueSize is the maximal number of requests that can be pending delegation.
|
||||||
|
// Note that this is distinct from the number of requests that are pending a response,
|
||||||
|
// which is not limited (except by txid collisions).
|
||||||
|
const forwardQueueSize = 64
|
||||||
|
|
||||||
|
// connCount is the number of UDP connections to use for forwarding.
|
||||||
|
const connCount = 32
|
||||||
|
|
||||||
|
const (
|
||||||
|
// cleanupInterval is the interval between purged of timed-out entries from txMap.
|
||||||
|
cleanupInterval = 30 * time.Second
|
||||||
|
// responseTimeout is the maximal amount of time to wait for a DNS response.
|
||||||
|
responseTimeout = 5 * time.Second
|
||||||
|
)
|
||||||
|
|
||||||
|
var errNoUpstreams = errors.New("upstream nameservers not set")
|
||||||
|
|
||||||
|
var aLongTimeAgo = time.Unix(0, 1)
|
||||||
|
|
||||||
|
type forwardedPacket struct {
|
||||||
|
payload []byte
|
||||||
|
dst net.Addr
|
||||||
|
}
|
||||||
|
|
||||||
|
type forwardingRecord struct {
|
||||||
|
src netaddr.IPPort
|
||||||
|
createdAt time.Time
|
||||||
|
}
|
||||||
|
|
||||||
|
// txid identifies a DNS transaction.
|
||||||
|
//
|
||||||
|
// As the standard DNS Request ID is only 16 bits, we extend it:
|
||||||
|
// the lower 32 bits are the zero-extended bits of the DNS Request ID;
|
||||||
|
// the upper 32 bits are the CRC32 checksum of the first question in the request.
|
||||||
|
// This makes probability of txid collision negligible.
|
||||||
|
type txid uint64
|
||||||
|
|
||||||
|
// getTxID computes the txid of the given DNS packet.
|
||||||
|
func getTxID(packet []byte) txid {
|
||||||
|
if len(packet) < headerBytes {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
|
||||||
|
dnsid := binary.BigEndian.Uint16(packet[0:2])
|
||||||
|
qcount := binary.BigEndian.Uint16(packet[4:6])
|
||||||
|
if qcount == 0 {
|
||||||
|
return txid(dnsid)
|
||||||
|
}
|
||||||
|
|
||||||
|
offset := headerBytes
|
||||||
|
for i := uint16(0); i < qcount; i++ {
|
||||||
|
// Note: this relies on the fact that names are not compressed in questions,
|
||||||
|
// so they are guaranteed to end with a NUL byte.
|
||||||
|
//
|
||||||
|
// Justification:
|
||||||
|
// RFC 1035 doesn't seem to explicitly prohibit compressing names in questions,
|
||||||
|
// but this is exceedingly unlikely to be done in practice. A DNS request
|
||||||
|
// with multiple questions is ill-defined (which questions do the header flags apply to?)
|
||||||
|
// and a single question would have to contain a pointer to an *answer*,
|
||||||
|
// which would be excessively smart, pointless (an answer can just as well refer to the question)
|
||||||
|
// and perhaps even prohibited: a draft RFC (draft-ietf-dnsind-local-compression-05) states:
|
||||||
|
//
|
||||||
|
// > It is important that these pointers always point backwards.
|
||||||
|
//
|
||||||
|
// This is said in summarizing RFC 1035, although that phrase does not appear in the original RFC.
|
||||||
|
// Additionally, (https://cr.yp.to/djbdns/notes.html) states:
|
||||||
|
//
|
||||||
|
// > The precise rule is that a name can be compressed if it is a response owner name,
|
||||||
|
// > the name in NS data, the name in CNAME data, the name in PTR data, the name in MX data,
|
||||||
|
// > or one of the names in SOA data.
|
||||||
|
namebytes := bytes.IndexByte(packet[offset:], 0)
|
||||||
|
// ... | name | NUL | type | class
|
||||||
|
// ?? 1 2 2
|
||||||
|
offset = offset + namebytes + 5
|
||||||
|
if len(packet) < offset {
|
||||||
|
// Corrupt packet; don't crash.
|
||||||
|
return txid(dnsid)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
hash := crc32.ChecksumIEEE(packet[headerBytes:offset])
|
||||||
|
return (txid(hash) << 32) | txid(dnsid)
|
||||||
|
}
|
||||||
|
|
||||||
|
// forwarder forwards DNS packets to a number of upstream nameservers.
|
||||||
|
type forwarder struct {
|
||||||
|
logf logger.Logf
|
||||||
|
|
||||||
|
// queue is the queue for delegated packets.
|
||||||
|
queue chan forwardedPacket
|
||||||
|
// responses is a channel by which responses are returned.
|
||||||
|
responses chan Packet
|
||||||
|
// closed signals all goroutines to stop.
|
||||||
|
closed chan struct{}
|
||||||
|
// wg signals when all goroutines have stopped.
|
||||||
|
wg sync.WaitGroup
|
||||||
|
|
||||||
|
// conns are the UDP connections used for forwarding.
|
||||||
|
// A random one is selected for each request, regardless of the target upstream.
|
||||||
|
conns []*net.UDPConn
|
||||||
|
|
||||||
|
mu sync.Mutex
|
||||||
|
// upstreams are the nameserver addresses that should be used for forwarding.
|
||||||
|
upstreams []net.Addr
|
||||||
|
// txMap maps DNS txids to active forwarding records.
|
||||||
|
txMap map[txid]forwardingRecord
|
||||||
|
}
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
rand.Seed(time.Now().UnixNano())
|
||||||
|
}
|
||||||
|
|
||||||
|
func newForwarder(logf logger.Logf, responses chan Packet) *forwarder {
|
||||||
|
return &forwarder{
|
||||||
|
logf: logger.WithPrefix(logf, "forward: "),
|
||||||
|
responses: responses,
|
||||||
|
queue: make(chan forwardedPacket, forwardQueueSize),
|
||||||
|
closed: make(chan struct{}),
|
||||||
|
conns: make([]*net.UDPConn, connCount),
|
||||||
|
txMap: make(map[txid]forwardingRecord),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *forwarder) Start() error {
|
||||||
|
var err error
|
||||||
|
|
||||||
|
for i := range f.conns {
|
||||||
|
f.conns[i], err = net.ListenUDP("udp", nil)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
f.wg.Add(connCount + 2)
|
||||||
|
for idx, conn := range f.conns {
|
||||||
|
go f.recv(uint16(idx), conn)
|
||||||
|
}
|
||||||
|
go f.send()
|
||||||
|
go f.cleanMap()
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *forwarder) Close() {
|
||||||
|
select {
|
||||||
|
case <-f.closed:
|
||||||
|
return
|
||||||
|
default:
|
||||||
|
// continue
|
||||||
|
}
|
||||||
|
close(f.closed)
|
||||||
|
|
||||||
|
for _, conn := range f.conns {
|
||||||
|
conn.SetDeadline(aLongTimeAgo)
|
||||||
|
}
|
||||||
|
|
||||||
|
f.wg.Wait()
|
||||||
|
|
||||||
|
for _, conn := range f.conns {
|
||||||
|
conn.Close()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *forwarder) setUpstreams(upstreams []net.Addr) {
|
||||||
|
f.mu.Lock()
|
||||||
|
f.upstreams = upstreams
|
||||||
|
f.mu.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *forwarder) send() {
|
||||||
|
defer f.wg.Done()
|
||||||
|
|
||||||
|
var packet forwardedPacket
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-f.closed:
|
||||||
|
return
|
||||||
|
case packet = <-f.queue:
|
||||||
|
// continue
|
||||||
|
}
|
||||||
|
|
||||||
|
connIdx := rand.Intn(connCount)
|
||||||
|
conn := f.conns[connIdx]
|
||||||
|
_, err := conn.WriteTo(packet.payload, packet.dst)
|
||||||
|
if err != nil {
|
||||||
|
// Do not log errors due to expired deadline.
|
||||||
|
if !errors.Is(err, os.ErrDeadlineExceeded) {
|
||||||
|
f.logf("send: %v", err)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *forwarder) recv(connIdx uint16, conn *net.UDPConn) {
|
||||||
|
defer f.wg.Done()
|
||||||
|
|
||||||
|
for {
|
||||||
|
out := make([]byte, maxResponseBytes)
|
||||||
|
n, err := conn.Read(out)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
// Do not log errors due to expired deadline.
|
||||||
|
if !errors.Is(err, os.ErrDeadlineExceeded) {
|
||||||
|
f.logf("recv: %v", err)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if n < headerBytes {
|
||||||
|
f.logf("recv: packet too small (%d bytes)", n)
|
||||||
|
}
|
||||||
|
|
||||||
|
out = out[:n]
|
||||||
|
txid := getTxID(out)
|
||||||
|
|
||||||
|
f.mu.Lock()
|
||||||
|
|
||||||
|
record, found := f.txMap[txid]
|
||||||
|
// At most one nameserver will return a response:
|
||||||
|
// the first one to do so will delete txid from the map.
|
||||||
|
if !found {
|
||||||
|
f.mu.Unlock()
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
delete(f.txMap, txid)
|
||||||
|
|
||||||
|
f.mu.Unlock()
|
||||||
|
|
||||||
|
packet := Packet{
|
||||||
|
Payload: out,
|
||||||
|
Addr: record.src,
|
||||||
|
}
|
||||||
|
select {
|
||||||
|
case <-f.closed:
|
||||||
|
return
|
||||||
|
case f.responses <- packet:
|
||||||
|
// continue
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// cleanMap periodically deletes timed-out forwarding records from f.txMap to bound growth.
|
||||||
|
func (f *forwarder) cleanMap() {
|
||||||
|
defer f.wg.Done()
|
||||||
|
|
||||||
|
t := time.NewTicker(cleanupInterval)
|
||||||
|
defer t.Stop()
|
||||||
|
|
||||||
|
var now time.Time
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-f.closed:
|
||||||
|
return
|
||||||
|
case now = <-t.C:
|
||||||
|
// continue
|
||||||
|
}
|
||||||
|
|
||||||
|
f.mu.Lock()
|
||||||
|
for k, v := range f.txMap {
|
||||||
|
if now.Sub(v.createdAt) > responseTimeout {
|
||||||
|
delete(f.txMap, k)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
f.mu.Unlock()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// forward forwards the query to all upstream nameservers and returns the first response.
|
||||||
|
func (f *forwarder) forward(query Packet) error {
|
||||||
|
txid := getTxID(query.Payload)
|
||||||
|
|
||||||
|
f.mu.Lock()
|
||||||
|
|
||||||
|
upstreams := f.upstreams
|
||||||
|
if len(upstreams) == 0 {
|
||||||
|
f.mu.Unlock()
|
||||||
|
return errNoUpstreams
|
||||||
|
}
|
||||||
|
f.txMap[txid] = forwardingRecord{
|
||||||
|
src: query.Addr,
|
||||||
|
createdAt: time.Now(),
|
||||||
|
}
|
||||||
|
|
||||||
|
f.mu.Unlock()
|
||||||
|
|
||||||
|
packet := forwardedPacket{
|
||||||
|
payload: query.Payload,
|
||||||
|
}
|
||||||
|
for _, upstream := range upstreams {
|
||||||
|
packet.dst = upstream
|
||||||
|
select {
|
||||||
|
case <-f.closed:
|
||||||
|
return ErrClosed
|
||||||
|
case f.queue <- packet:
|
||||||
|
// continue
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
@ -8,29 +8,24 @@ package tsdns
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"context"
|
|
||||||
"encoding/hex"
|
"encoding/hex"
|
||||||
"errors"
|
"errors"
|
||||||
|
"net"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
dns "golang.org/x/net/dns/dnsmessage"
|
dns "golang.org/x/net/dns/dnsmessage"
|
||||||
"inet.af/netaddr"
|
"inet.af/netaddr"
|
||||||
"tailscale.com/net/netns"
|
|
||||||
"tailscale.com/types/logger"
|
"tailscale.com/types/logger"
|
||||||
)
|
)
|
||||||
|
|
||||||
// maxResponseSize is the maximum size of a response from a Resolver.
|
// maxResponseBytes is the maximum size of a response from a Resolver.
|
||||||
const maxResponseSize = 512
|
const maxResponseBytes = 512
|
||||||
|
|
||||||
// queueSize is the maximal number of DNS requests that can be pending at a time.
|
// pendingQueueSize is the maximal number of DNS requests that can await polling.
|
||||||
// If EnqueueRequest is called when this many requests are already pending,
|
// If EnqueueRequest is called when this many requests are already pending,
|
||||||
// the request will be dropped to avoid blocking the caller.
|
// the request will be dropped to avoid blocking the caller.
|
||||||
const queueSize = 8
|
const pendingQueueSize = 64
|
||||||
|
|
||||||
// delegateTimeout is the maximal amount of time Resolver will wait
|
|
||||||
// for upstream nameservers to process a query.
|
|
||||||
const delegateTimeout = 5 * time.Second
|
|
||||||
|
|
||||||
// defaultTTL is the TTL of all responses from Resolver.
|
// defaultTTL is the TTL of all responses from Resolver.
|
||||||
const defaultTTL = 600 * time.Second
|
const defaultTTL = 600 * time.Second
|
||||||
@ -39,12 +34,12 @@ const defaultTTL = 600 * time.Second
|
|||||||
var ErrClosed = errors.New("closed")
|
var ErrClosed = errors.New("closed")
|
||||||
|
|
||||||
var (
|
var (
|
||||||
errAllFailed = errors.New("all upstream nameservers failed")
|
|
||||||
errFullQueue = errors.New("request queue full")
|
errFullQueue = errors.New("request queue full")
|
||||||
errNoNameservers = errors.New("no upstream nameservers set")
|
|
||||||
errMapNotSet = errors.New("domain map not set")
|
errMapNotSet = errors.New("domain map not set")
|
||||||
|
errNotForwarding = errors.New("forwarding disabled")
|
||||||
errNotImplemented = errors.New("query type not implemented")
|
errNotImplemented = errors.New("query type not implemented")
|
||||||
errNotQuery = errors.New("not a DNS query")
|
errNotQuery = errors.New("not a DNS query")
|
||||||
|
errNotOurName = errors.New("not a Tailscale DNS name")
|
||||||
)
|
)
|
||||||
|
|
||||||
// Packet represents a DNS payload together with the address of its origin.
|
// Packet represents a DNS payload together with the address of its origin.
|
||||||
@ -63,57 +58,69 @@ type Packet struct {
|
|||||||
// it delegates to upstream nameservers if any are set.
|
// it delegates to upstream nameservers if any are set.
|
||||||
type Resolver struct {
|
type Resolver struct {
|
||||||
logf logger.Logf
|
logf logger.Logf
|
||||||
|
// rootDomain is <root> in <mynode>.<mydomain>.<root>.
|
||||||
// The asynchronous interface is due to the fact that resolution may potentially
|
rootDomain []byte
|
||||||
// block for a long time (if the upstream nameserver is slow to reach).
|
// forwarder is
|
||||||
|
forwarder *forwarder
|
||||||
|
|
||||||
// queue is a buffered channel holding DNS requests queued for resolution.
|
// queue is a buffered channel holding DNS requests queued for resolution.
|
||||||
queue chan Packet
|
queue chan Packet
|
||||||
// responses is an unbuffered channel to which responses are sent.
|
// responses is an unbuffered channel to which responses are returned.
|
||||||
responses chan Packet
|
responses chan Packet
|
||||||
// errors is an unbuffered channel to which errors are sent.
|
// errors is an unbuffered channel to which errors are returned.
|
||||||
errors chan error
|
errors chan error
|
||||||
// closed notifies the poll goroutines to stop.
|
// closed signals all goroutines to stop.
|
||||||
closed chan struct{}
|
closed chan struct{}
|
||||||
// pollGroup signals when all poll goroutines have stopped.
|
// wg signals when all goroutines have stopped.
|
||||||
pollGroup sync.WaitGroup
|
wg sync.WaitGroup
|
||||||
|
|
||||||
// rootDomain is <root> in <mynode>.<mydomain>.<root>.
|
|
||||||
rootDomain []byte
|
|
||||||
|
|
||||||
// dialer is the netns.Dialer used for delegation.
|
|
||||||
dialer netns.Dialer
|
|
||||||
|
|
||||||
// mu guards the following fields from being updated while used.
|
// mu guards the following fields from being updated while used.
|
||||||
mu sync.Mutex
|
mu sync.Mutex
|
||||||
// dnsMap is the map most recently received from the control server.
|
// dnsMap is the map most recently received from the control server.
|
||||||
dnsMap *Map
|
dnsMap *Map
|
||||||
// nameservers is the list of nameserver addresses that should be used
|
}
|
||||||
// if the received query is not for a Tailscale node.
|
|
||||||
// The addresses are strings of the form ip:port, as expected by Dial.
|
// ResolverConfig is the set of configuration options for a Resolver.
|
||||||
nameservers []string
|
type ResolverConfig struct {
|
||||||
|
// Logf is the logger to use throughout the Resolver.
|
||||||
|
Logf logger.Logf
|
||||||
|
// RootDomain is the domain whose subdomains will be resolved locally as Tailscale nodes.
|
||||||
|
RootDomain string
|
||||||
|
// Forward determines whether the resolver will forward packets to
|
||||||
|
// nameservers set with SetUpstreams if the domain name is not of a Tailscale node.
|
||||||
|
Forward bool
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewResolver constructs a resolver associated with the given root domain.
|
// NewResolver constructs a resolver associated with the given root domain.
|
||||||
// The root domain must be in canonical form (with a trailing period).
|
// The root domain must be in canonical form (with a trailing period).
|
||||||
func NewResolver(logf logger.Logf, rootDomain string) *Resolver {
|
func NewResolver(config ResolverConfig) *Resolver {
|
||||||
r := &Resolver{
|
r := &Resolver{
|
||||||
logf: logger.WithPrefix(logf, "tsdns: "),
|
logf: logger.WithPrefix(config.Logf, "tsdns: "),
|
||||||
queue: make(chan Packet, queueSize),
|
queue: make(chan Packet, pendingQueueSize),
|
||||||
responses: make(chan Packet),
|
responses: make(chan Packet),
|
||||||
errors: make(chan error),
|
errors: make(chan error),
|
||||||
closed: make(chan struct{}),
|
closed: make(chan struct{}),
|
||||||
rootDomain: []byte(rootDomain),
|
rootDomain: []byte(config.RootDomain),
|
||||||
dialer: netns.NewDialer(),
|
}
|
||||||
|
|
||||||
|
if config.Forward {
|
||||||
|
r.forwarder = newForwarder(r.logf, r.responses)
|
||||||
}
|
}
|
||||||
|
|
||||||
return r
|
return r
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *Resolver) Start() {
|
func (r *Resolver) Start() error {
|
||||||
// TODO(dmytro): spawn more than one goroutine? They block on delegation.
|
if r.forwarder != nil {
|
||||||
r.pollGroup.Add(1)
|
if err := r.forwarder.Start(); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
r.wg.Add(1)
|
||||||
go r.poll()
|
go r.poll()
|
||||||
|
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Close shuts down the resolver and ensures poll goroutines have exited.
|
// Close shuts down the resolver and ensures poll goroutines have exited.
|
||||||
@ -126,7 +133,12 @@ func (r *Resolver) Close() {
|
|||||||
// continue
|
// continue
|
||||||
}
|
}
|
||||||
close(r.closed)
|
close(r.closed)
|
||||||
r.pollGroup.Wait()
|
|
||||||
|
if r.forwarder != nil {
|
||||||
|
r.forwarder.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
r.wg.Wait()
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetMap sets the resolver's DNS map, taking ownership of it.
|
// SetMap sets the resolver's DNS map, taking ownership of it.
|
||||||
@ -138,14 +150,12 @@ func (r *Resolver) SetMap(m *Map) {
|
|||||||
r.logf("map diff:\n%s", m.PrettyDiffFrom(oldMap))
|
r.logf("map diff:\n%s", m.PrettyDiffFrom(oldMap))
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetUpstreamNameservers sets the addresses of the resolver's
|
// SetUpstreams sets the addresses of the resolver's
|
||||||
// upstream nameservers, taking ownership of the argument.
|
// upstream nameservers, taking ownership of the argument.
|
||||||
// The addresses should be strings of the form ip:port,
|
func (r *Resolver) SetUpstreams(upstreams []net.Addr) {
|
||||||
// matching what Dial("udp", addr) expects as addr.
|
if r.forwarder != nil {
|
||||||
func (r *Resolver) SetNameservers(nameservers []string) {
|
r.forwarder.setUpstreams(upstreams)
|
||||||
r.mu.Lock()
|
}
|
||||||
r.nameservers = nameservers
|
|
||||||
r.mu.Unlock()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// EnqueueRequest places the given DNS request in the resolver's queue.
|
// EnqueueRequest places the given DNS request in the resolver's queue.
|
||||||
@ -153,6 +163,8 @@ func (r *Resolver) SetNameservers(nameservers []string) {
|
|||||||
// If the queue is full, the request will be dropped and an error will be returned.
|
// If the queue is full, the request will be dropped and an error will be returned.
|
||||||
func (r *Resolver) EnqueueRequest(request Packet) error {
|
func (r *Resolver) EnqueueRequest(request Packet) error {
|
||||||
select {
|
select {
|
||||||
|
case <-r.closed:
|
||||||
|
return ErrClosed
|
||||||
case r.queue <- request:
|
case r.queue <- request:
|
||||||
return nil
|
return nil
|
||||||
default:
|
default:
|
||||||
@ -164,12 +176,12 @@ func (r *Resolver) EnqueueRequest(request Packet) error {
|
|||||||
// It blocks until a response is available and gives up ownership of the response payload.
|
// It blocks until a response is available and gives up ownership of the response payload.
|
||||||
func (r *Resolver) NextResponse() (Packet, error) {
|
func (r *Resolver) NextResponse() (Packet, error) {
|
||||||
select {
|
select {
|
||||||
|
case <-r.closed:
|
||||||
|
return Packet{}, ErrClosed
|
||||||
case resp := <-r.responses:
|
case resp := <-r.responses:
|
||||||
return resp, nil
|
return resp, nil
|
||||||
case err := <-r.errors:
|
case err := <-r.errors:
|
||||||
return Packet{}, err
|
return Packet{}, err
|
||||||
case <-r.closed:
|
|
||||||
return Packet{}, ErrClosed
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -209,114 +221,50 @@ func (r *Resolver) ResolveReverse(ip netaddr.IP) (string, dns.RCode, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (r *Resolver) poll() {
|
func (r *Resolver) poll() {
|
||||||
defer r.pollGroup.Done()
|
defer r.wg.Done()
|
||||||
|
|
||||||
var (
|
var packet Packet
|
||||||
packet Packet
|
|
||||||
err error
|
|
||||||
)
|
|
||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
case packet = <-r.queue:
|
|
||||||
// continue
|
|
||||||
case <-r.closed:
|
case <-r.closed:
|
||||||
return
|
return
|
||||||
|
case packet = <-r.queue:
|
||||||
|
// continue
|
||||||
|
}
|
||||||
|
|
||||||
|
out, err := r.respond(packet.Payload)
|
||||||
|
|
||||||
|
if err == errNotOurName {
|
||||||
|
if r.forwarder != nil {
|
||||||
|
err = r.forwarder.forward(packet)
|
||||||
|
if err == nil {
|
||||||
|
// forward will send response into r.responses, nothing to do.
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
err = errNotForwarding
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
packet.Payload, err = r.respond(packet.Payload)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
select {
|
select {
|
||||||
|
case <-r.closed:
|
||||||
|
return
|
||||||
case r.errors <- err:
|
case r.errors <- err:
|
||||||
// continue
|
// continue
|
||||||
case <-r.closed:
|
|
||||||
return
|
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
|
packet.Payload = out
|
||||||
select {
|
select {
|
||||||
case r.responses <- packet:
|
|
||||||
// continue
|
|
||||||
case <-r.closed:
|
case <-r.closed:
|
||||||
return
|
return
|
||||||
|
case r.responses <- packet:
|
||||||
|
// continue
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// queryServer obtains a DNS response by querying the given server.
|
|
||||||
func (r *Resolver) queryServer(ctx context.Context, server string, query []byte) ([]byte, error) {
|
|
||||||
conn, err := r.dialer.DialContext(ctx, "udp", server)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
defer conn.Close()
|
|
||||||
|
|
||||||
// Interrupt the current operation when the context is cancelled.
|
|
||||||
go func() {
|
|
||||||
<-ctx.Done()
|
|
||||||
conn.SetDeadline(time.Unix(1, 0))
|
|
||||||
}()
|
|
||||||
|
|
||||||
_, err = conn.Write(query)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
out := make([]byte, maxResponseSize)
|
|
||||||
n, err := conn.Read(out)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return out[:n], nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// delegate forwards the query to all upstream nameservers and returns the first response.
|
|
||||||
func (r *Resolver) delegate(query []byte) ([]byte, error) {
|
|
||||||
r.mu.Lock()
|
|
||||||
nameservers := r.nameservers
|
|
||||||
r.mu.Unlock()
|
|
||||||
|
|
||||||
if len(nameservers) == 0 {
|
|
||||||
return nil, errNoNameservers
|
|
||||||
}
|
|
||||||
|
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), delegateTimeout)
|
|
||||||
defer cancel()
|
|
||||||
|
|
||||||
// Common case, don't spawn goroutines.
|
|
||||||
if len(nameservers) == 1 {
|
|
||||||
return r.queryServer(ctx, nameservers[0], query)
|
|
||||||
}
|
|
||||||
|
|
||||||
datach := make(chan []byte)
|
|
||||||
for _, server := range nameservers {
|
|
||||||
go func(s string) {
|
|
||||||
resp, err := r.queryServer(ctx, s, query)
|
|
||||||
// Only print errors not due to cancelation after first response.
|
|
||||||
if err != nil && ctx.Err() != context.Canceled {
|
|
||||||
r.logf("querying %s: %v", s, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
datach <- resp
|
|
||||||
}(server)
|
|
||||||
}
|
|
||||||
|
|
||||||
var response []byte
|
|
||||||
for range nameservers {
|
|
||||||
cur := <-datach
|
|
||||||
if cur != nil && response == nil {
|
|
||||||
// Received first successful response
|
|
||||||
response = cur
|
|
||||||
cancel()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if response == nil {
|
|
||||||
return nil, errAllFailed
|
|
||||||
}
|
|
||||||
return response, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
type response struct {
|
type response struct {
|
||||||
Header dns.Header
|
Header dns.Header
|
||||||
Question dns.Question
|
Question dns.Question
|
||||||
@ -517,8 +465,6 @@ func rdnsNameToIPv6(name []byte) (ip netaddr.IP, ok bool) {
|
|||||||
func (r *Resolver) respondReverse(query []byte, resp *response) ([]byte, error) {
|
func (r *Resolver) respondReverse(query []byte, resp *response) ([]byte, error) {
|
||||||
name := resp.Question.Name.Data[:resp.Question.Name.Length]
|
name := resp.Question.Name.Data[:resp.Question.Name.Length]
|
||||||
|
|
||||||
shouldDelegate := false
|
|
||||||
|
|
||||||
var ip netaddr.IP
|
var ip netaddr.IP
|
||||||
var ok bool
|
var ok bool
|
||||||
var err error
|
var err error
|
||||||
@ -528,7 +474,7 @@ func (r *Resolver) respondReverse(query []byte, resp *response) ([]byte, error)
|
|||||||
case bytes.HasSuffix(name, rdnsv6Suffix):
|
case bytes.HasSuffix(name, rdnsv6Suffix):
|
||||||
ip, ok = rdnsNameToIPv6(name)
|
ip, ok = rdnsNameToIPv6(name)
|
||||||
default:
|
default:
|
||||||
shouldDelegate = true
|
return nil, errNotOurName
|
||||||
}
|
}
|
||||||
|
|
||||||
// It is more likely that we failed in parsing the name than that it is actually malformed.
|
// It is more likely that we failed in parsing the name than that it is actually malformed.
|
||||||
@ -536,25 +482,15 @@ func (r *Resolver) respondReverse(query []byte, resp *response) ([]byte, error)
|
|||||||
if !ok {
|
if !ok {
|
||||||
// Without this conversion, escape analysis rules that resp escapes.
|
// Without this conversion, escape analysis rules that resp escapes.
|
||||||
r.logf("parsing rdns: malformed name: %s", resp.Question.Name.String())
|
r.logf("parsing rdns: malformed name: %s", resp.Question.Name.String())
|
||||||
shouldDelegate = true
|
return nil, errNotOurName
|
||||||
}
|
}
|
||||||
|
|
||||||
if !shouldDelegate {
|
resp.Name, resp.Header.RCode, err = r.ResolveReverse(ip)
|
||||||
resp.Name, resp.Header.RCode, err = r.ResolveReverse(ip)
|
if err != nil {
|
||||||
if err != nil {
|
r.logf("resolving rdns: %v", ip, err)
|
||||||
r.logf("resolving rdns: %v", ip, err)
|
|
||||||
}
|
|
||||||
shouldDelegate = (resp.Header.RCode == dns.RCodeNameError)
|
|
||||||
}
|
}
|
||||||
|
if resp.Header.RCode == dns.RCodeNameError {
|
||||||
if shouldDelegate {
|
return nil, errNotOurName
|
||||||
out, err := r.delegate(query)
|
|
||||||
if err != nil {
|
|
||||||
r.logf("delegating rdns: %v", err)
|
|
||||||
resp.Header.RCode = dns.RCodeServerFailure
|
|
||||||
return marshalResponse(resp)
|
|
||||||
}
|
|
||||||
return out, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return marshalResponse(resp)
|
return marshalResponse(resp)
|
||||||
@ -586,13 +522,7 @@ func (r *Resolver) respond(query []byte) ([]byte, error) {
|
|||||||
// We do this on bytes because Name.String() allocates.
|
// We do this on bytes because Name.String() allocates.
|
||||||
rawName := resp.Question.Name.Data[:resp.Question.Name.Length]
|
rawName := resp.Question.Name.Data[:resp.Question.Name.Length]
|
||||||
if !bytes.HasSuffix(rawName, r.rootDomain) {
|
if !bytes.HasSuffix(rawName, r.rootDomain) {
|
||||||
out, err := r.delegate(query)
|
return nil, errNotOurName
|
||||||
if err != nil {
|
|
||||||
r.logf("delegating: %v", err)
|
|
||||||
resp.Header.RCode = dns.RCodeServerFailure
|
|
||||||
return marshalResponse(resp)
|
|
||||||
}
|
|
||||||
return out, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
switch resp.Question.Type {
|
switch resp.Question.Type {
|
||||||
|
@ -7,11 +7,13 @@ package tsdns
|
|||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"errors"
|
"errors"
|
||||||
|
"net"
|
||||||
"sync"
|
"sync"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
dns "golang.org/x/net/dns/dnsmessage"
|
dns "golang.org/x/net/dns/dnsmessage"
|
||||||
"inet.af/netaddr"
|
"inet.af/netaddr"
|
||||||
|
"tailscale.com/tstest"
|
||||||
)
|
)
|
||||||
|
|
||||||
var testipv4 = netaddr.IPv4(1, 2, 3, 4)
|
var testipv4 = netaddr.IPv4(1, 2, 3, 4)
|
||||||
@ -178,9 +180,13 @@ func TestRDNSNameToIPv6(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestResolve(t *testing.T) {
|
func TestResolve(t *testing.T) {
|
||||||
r := NewResolver(t.Logf, "ipn.dev")
|
r := NewResolver(ResolverConfig{Logf: t.Logf, RootDomain: "ipn.dev.", Forward: false})
|
||||||
r.SetMap(dnsMap)
|
r.SetMap(dnsMap)
|
||||||
r.Start()
|
|
||||||
|
if err := r.Start(); err != nil {
|
||||||
|
t.Fatalf("start: %v", err)
|
||||||
|
}
|
||||||
|
defer r.Close()
|
||||||
|
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
@ -212,9 +218,13 @@ func TestResolve(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestResolveReverse(t *testing.T) {
|
func TestResolveReverse(t *testing.T) {
|
||||||
r := NewResolver(t.Logf, "ipn.dev")
|
r := NewResolver(ResolverConfig{Logf: t.Logf, RootDomain: "ipn.dev.", Forward: false})
|
||||||
r.SetMap(dnsMap)
|
r.SetMap(dnsMap)
|
||||||
r.Start()
|
|
||||||
|
if err := r.Start(); err != nil {
|
||||||
|
t.Fatalf("start: %v", err)
|
||||||
|
}
|
||||||
|
defer r.Close()
|
||||||
|
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
@ -244,6 +254,9 @@ func TestResolveReverse(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestDelegate(t *testing.T) {
|
func TestDelegate(t *testing.T) {
|
||||||
|
rc := tstest.NewResourceCheck()
|
||||||
|
defer rc.Assert(t)
|
||||||
|
|
||||||
dnsHandleFunc("test.site.", resolveToIP(testipv4, testipv6))
|
dnsHandleFunc("test.site.", resolveToIP(testipv4, testipv6))
|
||||||
dnsHandleFunc("nxdomain.site.", resolveToNXDOMAIN)
|
dnsHandleFunc("nxdomain.site.", resolveToNXDOMAIN)
|
||||||
|
|
||||||
@ -271,12 +284,16 @@ func TestDelegate(t *testing.T) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
r := NewResolver(t.Logf, "ipn.dev")
|
r := NewResolver(ResolverConfig{Logf: t.Logf, RootDomain: "ipn.dev.", Forward: true})
|
||||||
r.SetNameservers([]string{
|
r.SetUpstreams([]net.Addr{
|
||||||
v4server.PacketConn.LocalAddr().String(),
|
v4server.PacketConn.LocalAddr(),
|
||||||
v6server.PacketConn.LocalAddr().String(),
|
v6server.PacketConn.LocalAddr(),
|
||||||
})
|
})
|
||||||
r.Start()
|
|
||||||
|
if err := r.Start(); err != nil {
|
||||||
|
t.Fatalf("start: %v", err)
|
||||||
|
}
|
||||||
|
defer r.Close()
|
||||||
|
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
@ -311,9 +328,92 @@ func TestDelegate(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestDelegateCollision(t *testing.T) {
|
||||||
|
dnsHandleFunc("test.site.", resolveToIP(testipv4, testipv6))
|
||||||
|
|
||||||
|
server, errch := serveDNS("127.0.0.1:0")
|
||||||
|
defer func() {
|
||||||
|
if err := <-errch; err != nil {
|
||||||
|
t.Errorf("server error: %v", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
if server == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer server.Shutdown()
|
||||||
|
|
||||||
|
r := NewResolver(ResolverConfig{Logf: t.Logf, RootDomain: "ipn.dev.", Forward: true})
|
||||||
|
r.SetUpstreams([]net.Addr{server.PacketConn.LocalAddr()})
|
||||||
|
|
||||||
|
if err := r.Start(); err != nil {
|
||||||
|
t.Fatalf("start: %v", err)
|
||||||
|
}
|
||||||
|
defer r.Close()
|
||||||
|
|
||||||
|
packets := []struct {
|
||||||
|
qname string
|
||||||
|
qtype dns.Type
|
||||||
|
addr netaddr.IPPort
|
||||||
|
}{
|
||||||
|
{"test.site.", dns.TypeA, netaddr.IPPort{IP: netaddr.IPv4(1, 1, 1, 1), Port: 1001}},
|
||||||
|
{"test.site.", dns.TypeAAAA, netaddr.IPPort{IP: netaddr.IPv4(1, 1, 1, 1), Port: 1002}},
|
||||||
|
}
|
||||||
|
|
||||||
|
// packets will have the same dns txid.
|
||||||
|
for _, p := range packets {
|
||||||
|
payload := dnspacket(p.qname, p.qtype)
|
||||||
|
req := Packet{Payload: payload, Addr: p.addr}
|
||||||
|
err := r.EnqueueRequest(req)
|
||||||
|
if err != nil {
|
||||||
|
t.Error(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Despite the txid collision, the answer(s) should still match the query.
|
||||||
|
resp, err := r.NextResponse()
|
||||||
|
if err != nil {
|
||||||
|
t.Error(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var p dns.Parser
|
||||||
|
_, err = p.Start(resp.Payload)
|
||||||
|
if err != nil {
|
||||||
|
t.Error(err)
|
||||||
|
}
|
||||||
|
err = p.SkipAllQuestions()
|
||||||
|
if err != nil {
|
||||||
|
t.Error(err)
|
||||||
|
}
|
||||||
|
ans, err := p.AllAnswers()
|
||||||
|
if err != nil {
|
||||||
|
t.Error(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var wantType dns.Type
|
||||||
|
switch ans[0].Body.(type) {
|
||||||
|
case *dns.AResource:
|
||||||
|
wantType = dns.TypeA
|
||||||
|
case *dns.AAAAResource:
|
||||||
|
wantType = dns.TypeAAAA
|
||||||
|
default:
|
||||||
|
t.Errorf("unexpected answer type: %T", ans[0].Body)
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, p := range packets {
|
||||||
|
if p.qtype == wantType && p.addr != resp.Addr {
|
||||||
|
t.Errorf("addr = %v; want %v", resp.Addr, p.addr)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestConcurrentSetMap(t *testing.T) {
|
func TestConcurrentSetMap(t *testing.T) {
|
||||||
r := NewResolver(t.Logf, "ipn.dev.")
|
r := NewResolver(ResolverConfig{Logf: t.Logf, RootDomain: "ipn.dev.", Forward: false})
|
||||||
r.Start()
|
|
||||||
|
if err := r.Start(); err != nil {
|
||||||
|
t.Fatalf("start: %v", err)
|
||||||
|
}
|
||||||
|
defer r.Close()
|
||||||
|
|
||||||
// This is purely to ensure that Resolve does not race with SetMap.
|
// This is purely to ensure that Resolve does not race with SetMap.
|
||||||
var wg sync.WaitGroup
|
var wg sync.WaitGroup
|
||||||
@ -329,17 +429,36 @@ func TestConcurrentSetMap(t *testing.T) {
|
|||||||
wg.Wait()
|
wg.Wait()
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestConcurrentSetNameservers(t *testing.T) {
|
func TestConcurrentSetUpstreams(t *testing.T) {
|
||||||
r := NewResolver(t.Logf, "ipn.dev.")
|
dnsHandleFunc("test.site.", resolveToIP(testipv4, testipv6))
|
||||||
r.Start()
|
|
||||||
packet := dnspacket("google.com.", dns.TypeA)
|
|
||||||
|
|
||||||
// This is purely to ensure that delegation does not race with SetNameservers.
|
server, errch := serveDNS("127.0.0.1:0")
|
||||||
|
defer func() {
|
||||||
|
if err := <-errch; err != nil {
|
||||||
|
t.Errorf("server error: %v", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
if server == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer server.Shutdown()
|
||||||
|
|
||||||
|
r := NewResolver(ResolverConfig{Logf: t.Logf, RootDomain: "ipn.dev.", Forward: true})
|
||||||
|
r.SetMap(dnsMap)
|
||||||
|
|
||||||
|
if err := r.Start(); err != nil {
|
||||||
|
t.Fatalf("start: %v", err)
|
||||||
|
}
|
||||||
|
defer r.Close()
|
||||||
|
|
||||||
|
packet := dnspacket("test.site.", dns.TypeA)
|
||||||
|
// This is purely to ensure that delegation does not race with SetUpstreams.
|
||||||
var wg sync.WaitGroup
|
var wg sync.WaitGroup
|
||||||
wg.Add(2)
|
wg.Add(2)
|
||||||
go func() {
|
go func() {
|
||||||
defer wg.Done()
|
defer wg.Done()
|
||||||
r.SetNameservers([]string{"9.9.9.9:53"})
|
r.SetUpstreams([]net.Addr{server.PacketConn.LocalAddr()})
|
||||||
}()
|
}()
|
||||||
go func() {
|
go func() {
|
||||||
defer wg.Done()
|
defer wg.Done()
|
||||||
@ -415,9 +534,13 @@ var nxdomainResponse = []byte{
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestFull(t *testing.T) {
|
func TestFull(t *testing.T) {
|
||||||
r := NewResolver(t.Logf, "ipn.dev.")
|
r := NewResolver(ResolverConfig{Logf: t.Logf, RootDomain: "ipn.dev.", Forward: false})
|
||||||
r.SetMap(dnsMap)
|
r.SetMap(dnsMap)
|
||||||
r.Start()
|
|
||||||
|
if err := r.Start(); err != nil {
|
||||||
|
t.Fatalf("start: %v", err)
|
||||||
|
}
|
||||||
|
defer r.Close()
|
||||||
|
|
||||||
// One full packet and one error packet
|
// One full packet and one error packet
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
@ -445,9 +568,13 @@ func TestFull(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestAllocs(t *testing.T) {
|
func TestAllocs(t *testing.T) {
|
||||||
r := NewResolver(t.Logf, "ipn.dev.")
|
r := NewResolver(ResolverConfig{Logf: t.Logf, RootDomain: "ipn.dev.", Forward: false})
|
||||||
r.SetMap(dnsMap)
|
r.SetMap(dnsMap)
|
||||||
r.Start()
|
|
||||||
|
if err := r.Start(); err != nil {
|
||||||
|
t.Fatalf("start: %v", err)
|
||||||
|
}
|
||||||
|
defer r.Close()
|
||||||
|
|
||||||
// It is seemingly pointless to test allocs in the delegate path,
|
// It is seemingly pointless to test allocs in the delegate path,
|
||||||
// as dialer.Dial -> Read -> Write alone comprise 12 allocs.
|
// as dialer.Dial -> Read -> Write alone comprise 12 allocs.
|
||||||
@ -473,9 +600,28 @@ func TestAllocs(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func BenchmarkFull(b *testing.B) {
|
func BenchmarkFull(b *testing.B) {
|
||||||
r := NewResolver(b.Logf, "ipn.dev.")
|
dnsHandleFunc("test.site.", resolveToIP(testipv4, testipv6))
|
||||||
|
|
||||||
|
server, errch := serveDNS("127.0.0.1:0")
|
||||||
|
defer func() {
|
||||||
|
if err := <-errch; err != nil {
|
||||||
|
b.Errorf("server error: %v", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
if server == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer server.Shutdown()
|
||||||
|
|
||||||
|
r := NewResolver(ResolverConfig{Logf: b.Logf, RootDomain: "ipn.dev.", Forward: true})
|
||||||
r.SetMap(dnsMap)
|
r.SetMap(dnsMap)
|
||||||
r.Start()
|
r.SetUpstreams([]net.Addr{server.PacketConn.LocalAddr()})
|
||||||
|
|
||||||
|
if err := r.Start(); err != nil {
|
||||||
|
b.Fatalf("start: %v", err)
|
||||||
|
}
|
||||||
|
defer r.Close()
|
||||||
|
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
@ -483,7 +629,7 @@ func BenchmarkFull(b *testing.B) {
|
|||||||
}{
|
}{
|
||||||
{"forward", dnspacket("test1.ipn.dev.", dns.TypeA)},
|
{"forward", dnspacket("test1.ipn.dev.", dns.TypeA)},
|
||||||
{"reverse", dnspacket("4.3.2.1.in-addr.arpa.", dns.TypePTR)},
|
{"reverse", dnspacket("4.3.2.1.in-addr.arpa.", dns.TypePTR)},
|
||||||
{"nxdomain", dnspacket("test3.ipn.dev.", dns.TypeA)},
|
{"delegated", dnspacket("test.site.", dns.TypeA)},
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
|
@ -201,13 +201,18 @@ func NewUserspaceEngineAdvanced(conf EngineConfig) (Engine, error) {
|
|||||||
func newUserspaceEngineAdvanced(conf EngineConfig) (_ Engine, reterr error) {
|
func newUserspaceEngineAdvanced(conf EngineConfig) (_ Engine, reterr error) {
|
||||||
logf := conf.Logf
|
logf := conf.Logf
|
||||||
|
|
||||||
|
rconf := tsdns.ResolverConfig{
|
||||||
|
Logf: conf.Logf,
|
||||||
|
RootDomain: magicDNSDomain,
|
||||||
|
Forward: true,
|
||||||
|
}
|
||||||
e := &userspaceEngine{
|
e := &userspaceEngine{
|
||||||
timeNow: time.Now,
|
timeNow: time.Now,
|
||||||
logf: logf,
|
logf: logf,
|
||||||
reqCh: make(chan struct{}, 1),
|
reqCh: make(chan struct{}, 1),
|
||||||
waitCh: make(chan struct{}),
|
waitCh: make(chan struct{}),
|
||||||
tundev: tstun.WrapTUN(logf, conf.TUN),
|
tundev: tstun.WrapTUN(logf, conf.TUN),
|
||||||
resolver: tsdns.NewResolver(logf, magicDNSDomain),
|
resolver: tsdns.NewResolver(rconf),
|
||||||
pingers: make(map[wgcfg.Key]*pinger),
|
pingers: make(map[wgcfg.Key]*pinger),
|
||||||
}
|
}
|
||||||
e.localAddrs.Store(map[packet.IP]bool{})
|
e.localAddrs.Store(map[packet.IP]bool{})
|
||||||
@ -849,11 +854,16 @@ func (e *userspaceEngine) Reconfig(cfg *wgcfg.Config, routerCfg *router.Config)
|
|||||||
if routerChanged {
|
if routerChanged {
|
||||||
if routerCfg.DNS.Proxied {
|
if routerCfg.DNS.Proxied {
|
||||||
ips := routerCfg.DNS.Nameservers
|
ips := routerCfg.DNS.Nameservers
|
||||||
nameservers := make([]string, len(ips))
|
upstreams := make([]net.Addr, len(ips))
|
||||||
for i, ip := range ips {
|
for i, ip := range ips {
|
||||||
nameservers[i] = net.JoinHostPort(ip.String(), "53")
|
stdIP := ip.IPAddr()
|
||||||
|
upstreams[i] = &net.UDPAddr{
|
||||||
|
IP: stdIP.IP,
|
||||||
|
Port: 53,
|
||||||
|
Zone: stdIP.Zone,
|
||||||
|
}
|
||||||
}
|
}
|
||||||
e.resolver.SetNameservers(nameservers)
|
e.resolver.SetUpstreams(upstreams)
|
||||||
routerCfg.DNS.Nameservers = []netaddr.IP{tsaddr.TailscaleServiceIP()}
|
routerCfg.DNS.Nameservers = []netaddr.IP{tsaddr.TailscaleServiceIP()}
|
||||||
}
|
}
|
||||||
e.logf("wgengine: Reconfig: configuring router")
|
e.logf("wgengine: Reconfig: configuring router")
|
||||||
|
Loading…
x
Reference in New Issue
Block a user