tsdns: delegate requests asynchronously (#687)

Signed-Off-By: Dmytro Shynkevych <dmytro@tailscale.com>
This commit is contained in:
Dmytro Shynkevych
2020-08-19 15:39:25 -04:00
committed by GitHub
parent a583e498b0
commit 1af70e2468
4 changed files with 600 additions and 189 deletions

View File

@@ -8,29 +8,24 @@ package tsdns
import (
"bytes"
"context"
"encoding/hex"
"errors"
"net"
"sync"
"time"
dns "golang.org/x/net/dns/dnsmessage"
"inet.af/netaddr"
"tailscale.com/net/netns"
"tailscale.com/types/logger"
)
// maxResponseSize is the maximum size of a response from a Resolver.
const maxResponseSize = 512
// maxResponseBytes is the maximum size of a response from a Resolver.
const maxResponseBytes = 512
// queueSize is the maximal number of DNS requests that can be pending at a time.
// pendingQueueSize is the maximal number of DNS requests that can await polling.
// If EnqueueRequest is called when this many requests are already pending,
// the request will be dropped to avoid blocking the caller.
const queueSize = 8
// delegateTimeout is the maximal amount of time Resolver will wait
// for upstream nameservers to process a query.
const delegateTimeout = 5 * time.Second
const pendingQueueSize = 64
// defaultTTL is the TTL of all responses from Resolver.
const defaultTTL = 600 * time.Second
@@ -39,12 +34,12 @@ const defaultTTL = 600 * time.Second
var ErrClosed = errors.New("closed")
var (
errAllFailed = errors.New("all upstream nameservers failed")
errFullQueue = errors.New("request queue full")
errNoNameservers = errors.New("no upstream nameservers set")
errMapNotSet = errors.New("domain map not set")
errNotForwarding = errors.New("forwarding disabled")
errNotImplemented = errors.New("query type not implemented")
errNotQuery = errors.New("not a DNS query")
errNotOurName = errors.New("not a Tailscale DNS name")
)
// Packet represents a DNS payload together with the address of its origin.
@@ -63,57 +58,69 @@ type Packet struct {
// it delegates to upstream nameservers if any are set.
type Resolver struct {
logf logger.Logf
// The asynchronous interface is due to the fact that resolution may potentially
// block for a long time (if the upstream nameserver is slow to reach).
// rootDomain is <root> in <mynode>.<mydomain>.<root>.
rootDomain []byte
// forwarder is
forwarder *forwarder
// queue is a buffered channel holding DNS requests queued for resolution.
queue chan Packet
// responses is an unbuffered channel to which responses are sent.
// responses is an unbuffered channel to which responses are returned.
responses chan Packet
// errors is an unbuffered channel to which errors are sent.
// errors is an unbuffered channel to which errors are returned.
errors chan error
// closed notifies the poll goroutines to stop.
// closed signals all goroutines to stop.
closed chan struct{}
// pollGroup signals when all poll goroutines have stopped.
pollGroup sync.WaitGroup
// rootDomain is <root> in <mynode>.<mydomain>.<root>.
rootDomain []byte
// dialer is the netns.Dialer used for delegation.
dialer netns.Dialer
// wg signals when all goroutines have stopped.
wg sync.WaitGroup
// mu guards the following fields from being updated while used.
mu sync.Mutex
// dnsMap is the map most recently received from the control server.
dnsMap *Map
// nameservers is the list of nameserver addresses that should be used
// if the received query is not for a Tailscale node.
// The addresses are strings of the form ip:port, as expected by Dial.
nameservers []string
}
// ResolverConfig is the set of configuration options for a Resolver.
type ResolverConfig struct {
// Logf is the logger to use throughout the Resolver.
Logf logger.Logf
// RootDomain is the domain whose subdomains will be resolved locally as Tailscale nodes.
RootDomain string
// Forward determines whether the resolver will forward packets to
// nameservers set with SetUpstreams if the domain name is not of a Tailscale node.
Forward bool
}
// NewResolver constructs a resolver associated with the given root domain.
// The root domain must be in canonical form (with a trailing period).
func NewResolver(logf logger.Logf, rootDomain string) *Resolver {
func NewResolver(config ResolverConfig) *Resolver {
r := &Resolver{
logf: logger.WithPrefix(logf, "tsdns: "),
queue: make(chan Packet, queueSize),
logf: logger.WithPrefix(config.Logf, "tsdns: "),
queue: make(chan Packet, pendingQueueSize),
responses: make(chan Packet),
errors: make(chan error),
closed: make(chan struct{}),
rootDomain: []byte(rootDomain),
dialer: netns.NewDialer(),
rootDomain: []byte(config.RootDomain),
}
if config.Forward {
r.forwarder = newForwarder(r.logf, r.responses)
}
return r
}
func (r *Resolver) Start() {
// TODO(dmytro): spawn more than one goroutine? They block on delegation.
r.pollGroup.Add(1)
func (r *Resolver) Start() error {
if r.forwarder != nil {
if err := r.forwarder.Start(); err != nil {
return err
}
}
r.wg.Add(1)
go r.poll()
return nil
}
// Close shuts down the resolver and ensures poll goroutines have exited.
@@ -126,7 +133,12 @@ func (r *Resolver) Close() {
// continue
}
close(r.closed)
r.pollGroup.Wait()
if r.forwarder != nil {
r.forwarder.Close()
}
r.wg.Wait()
}
// SetMap sets the resolver's DNS map, taking ownership of it.
@@ -138,14 +150,12 @@ func (r *Resolver) SetMap(m *Map) {
r.logf("map diff:\n%s", m.PrettyDiffFrom(oldMap))
}
// SetUpstreamNameservers sets the addresses of the resolver's
// SetUpstreams sets the addresses of the resolver's
// upstream nameservers, taking ownership of the argument.
// The addresses should be strings of the form ip:port,
// matching what Dial("udp", addr) expects as addr.
func (r *Resolver) SetNameservers(nameservers []string) {
r.mu.Lock()
r.nameservers = nameservers
r.mu.Unlock()
func (r *Resolver) SetUpstreams(upstreams []net.Addr) {
if r.forwarder != nil {
r.forwarder.setUpstreams(upstreams)
}
}
// EnqueueRequest places the given DNS request in the resolver's queue.
@@ -153,6 +163,8 @@ func (r *Resolver) SetNameservers(nameservers []string) {
// If the queue is full, the request will be dropped and an error will be returned.
func (r *Resolver) EnqueueRequest(request Packet) error {
select {
case <-r.closed:
return ErrClosed
case r.queue <- request:
return nil
default:
@@ -164,12 +176,12 @@ func (r *Resolver) EnqueueRequest(request Packet) error {
// It blocks until a response is available and gives up ownership of the response payload.
func (r *Resolver) NextResponse() (Packet, error) {
select {
case <-r.closed:
return Packet{}, ErrClosed
case resp := <-r.responses:
return resp, nil
case err := <-r.errors:
return Packet{}, err
case <-r.closed:
return Packet{}, ErrClosed
}
}
@@ -209,114 +221,50 @@ func (r *Resolver) ResolveReverse(ip netaddr.IP) (string, dns.RCode, error) {
}
func (r *Resolver) poll() {
defer r.pollGroup.Done()
defer r.wg.Done()
var (
packet Packet
err error
)
var packet Packet
for {
select {
case packet = <-r.queue:
// continue
case <-r.closed:
return
case packet = <-r.queue:
// continue
}
out, err := r.respond(packet.Payload)
if err == errNotOurName {
if r.forwarder != nil {
err = r.forwarder.forward(packet)
if err == nil {
// forward will send response into r.responses, nothing to do.
continue
}
} else {
err = errNotForwarding
}
}
packet.Payload, err = r.respond(packet.Payload)
if err != nil {
select {
case <-r.closed:
return
case r.errors <- err:
// continue
case <-r.closed:
return
}
} else {
packet.Payload = out
select {
case r.responses <- packet:
// continue
case <-r.closed:
return
case r.responses <- packet:
// continue
}
}
}
}
// queryServer obtains a DNS response by querying the given server.
func (r *Resolver) queryServer(ctx context.Context, server string, query []byte) ([]byte, error) {
conn, err := r.dialer.DialContext(ctx, "udp", server)
if err != nil {
return nil, err
}
defer conn.Close()
// Interrupt the current operation when the context is cancelled.
go func() {
<-ctx.Done()
conn.SetDeadline(time.Unix(1, 0))
}()
_, err = conn.Write(query)
if err != nil {
return nil, err
}
out := make([]byte, maxResponseSize)
n, err := conn.Read(out)
if err != nil {
return nil, err
}
return out[:n], nil
}
// delegate forwards the query to all upstream nameservers and returns the first response.
func (r *Resolver) delegate(query []byte) ([]byte, error) {
r.mu.Lock()
nameservers := r.nameservers
r.mu.Unlock()
if len(nameservers) == 0 {
return nil, errNoNameservers
}
ctx, cancel := context.WithTimeout(context.Background(), delegateTimeout)
defer cancel()
// Common case, don't spawn goroutines.
if len(nameservers) == 1 {
return r.queryServer(ctx, nameservers[0], query)
}
datach := make(chan []byte)
for _, server := range nameservers {
go func(s string) {
resp, err := r.queryServer(ctx, s, query)
// Only print errors not due to cancelation after first response.
if err != nil && ctx.Err() != context.Canceled {
r.logf("querying %s: %v", s, err)
}
datach <- resp
}(server)
}
var response []byte
for range nameservers {
cur := <-datach
if cur != nil && response == nil {
// Received first successful response
response = cur
cancel()
}
}
if response == nil {
return nil, errAllFailed
}
return response, nil
}
type response struct {
Header dns.Header
Question dns.Question
@@ -517,8 +465,6 @@ func rdnsNameToIPv6(name []byte) (ip netaddr.IP, ok bool) {
func (r *Resolver) respondReverse(query []byte, resp *response) ([]byte, error) {
name := resp.Question.Name.Data[:resp.Question.Name.Length]
shouldDelegate := false
var ip netaddr.IP
var ok bool
var err error
@@ -528,7 +474,7 @@ func (r *Resolver) respondReverse(query []byte, resp *response) ([]byte, error)
case bytes.HasSuffix(name, rdnsv6Suffix):
ip, ok = rdnsNameToIPv6(name)
default:
shouldDelegate = true
return nil, errNotOurName
}
// It is more likely that we failed in parsing the name than that it is actually malformed.
@@ -536,25 +482,15 @@ func (r *Resolver) respondReverse(query []byte, resp *response) ([]byte, error)
if !ok {
// Without this conversion, escape analysis rules that resp escapes.
r.logf("parsing rdns: malformed name: %s", resp.Question.Name.String())
shouldDelegate = true
return nil, errNotOurName
}
if !shouldDelegate {
resp.Name, resp.Header.RCode, err = r.ResolveReverse(ip)
if err != nil {
r.logf("resolving rdns: %v", ip, err)
}
shouldDelegate = (resp.Header.RCode == dns.RCodeNameError)
resp.Name, resp.Header.RCode, err = r.ResolveReverse(ip)
if err != nil {
r.logf("resolving rdns: %v", ip, err)
}
if shouldDelegate {
out, err := r.delegate(query)
if err != nil {
r.logf("delegating rdns: %v", err)
resp.Header.RCode = dns.RCodeServerFailure
return marshalResponse(resp)
}
return out, nil
if resp.Header.RCode == dns.RCodeNameError {
return nil, errNotOurName
}
return marshalResponse(resp)
@@ -586,13 +522,7 @@ func (r *Resolver) respond(query []byte) ([]byte, error) {
// We do this on bytes because Name.String() allocates.
rawName := resp.Question.Name.Data[:resp.Question.Name.Length]
if !bytes.HasSuffix(rawName, r.rootDomain) {
out, err := r.delegate(query)
if err != nil {
r.logf("delegating: %v", err)
resp.Header.RCode = dns.RCodeServerFailure
return marshalResponse(resp)
}
return out, nil
return nil, errNotOurName
}
switch resp.Question.Type {