net/dns/resolver: teach the forwarder to do per-domain routing.

Given a DNS route map, the forwarder selects the right set of
upstreams for a given name.

Signed-off-by: David Anderson <danderson@tailscale.com>
This commit is contained in:
David Anderson
2021-04-01 19:31:55 -07:00
parent 4ed111281b
commit 9f105d3968
5 changed files with 180 additions and 122 deletions

View File

@@ -17,10 +17,12 @@ import (
"sync"
"time"
dns "golang.org/x/net/dns/dnsmessage"
"inet.af/netaddr"
"tailscale.com/logtail/backoff"
"tailscale.com/net/netns"
"tailscale.com/types/logger"
"tailscale.com/util/dnsname"
)
// headerBytes is the number of bytes in a DNS message header.
@@ -100,6 +102,11 @@ func getTxID(packet []byte) txid {
return (txid(hash) << 32) | txid(dnsid)
}
type route struct {
suffix string
resolvers []netaddr.IPPort
}
// forwarder forwards DNS packets to a number of upstream nameservers.
type forwarder struct {
logf logger.Logf
@@ -116,10 +123,9 @@ type forwarder struct {
conns []*fwdConn
mu sync.Mutex
// upstreams are the nameserver addresses that should be used for forwarding.
upstreams []net.Addr
// txMap maps DNS txids to active forwarding records.
txMap map[txid]forwardingRecord
// routes are per-suffix resolvers to use.
routes []route // most specific routes first
txMap map[txid]forwardingRecord // txids to in-flight requests
}
func init() {
@@ -127,24 +133,22 @@ func init() {
}
func newForwarder(logf logger.Logf, responses chan packet) *forwarder {
return &forwarder{
ret := &forwarder{
logf: logger.WithPrefix(logf, "forward: "),
responses: responses,
closed: make(chan struct{}),
conns: make([]*fwdConn, connCount),
txMap: make(map[txid]forwardingRecord),
}
}
func (f *forwarder) Start() error {
f.wg.Add(connCount + 1)
for idx := range f.conns {
f.conns[idx] = newFwdConn(f.logf, idx)
go f.recv(f.conns[idx])
ret.wg.Add(connCount + 1)
for idx := range ret.conns {
ret.conns[idx] = newFwdConn(ret.logf, idx)
go ret.recv(ret.conns[idx])
}
go f.cleanMap()
go ret.cleanMap()
return nil
return ret
}
func (f *forwarder) Close() {
@@ -171,14 +175,15 @@ func (f *forwarder) rebindFromNetworkChange() {
}
}
func (f *forwarder) setUpstreams(upstreams []net.Addr) {
func (f *forwarder) setRoutes(routes []route) {
fmt.Println(routes)
f.mu.Lock()
f.upstreams = upstreams
f.routes = routes
f.mu.Unlock()
}
// send sends packet to dst. It is best effort.
func (f *forwarder) send(packet []byte, dst net.Addr) {
func (f *forwarder) send(packet []byte, dst netaddr.IPPort) {
connIdx := rand.Intn(connCount)
conn := f.conns[connIdx]
conn.send(packet, dst)
@@ -256,24 +261,38 @@ func (f *forwarder) cleanMap() {
// forward forwards the query to all upstream nameservers and returns the first response.
func (f *forwarder) forward(query packet) error {
domain, err := nameFromQuery(query.bs)
if err != nil {
return err
}
txid := getTxID(query.bs)
f.mu.Lock()
routes := f.routes
f.mu.Unlock()
upstreams := f.upstreams
if len(upstreams) == 0 {
f.mu.Unlock()
var resolvers []netaddr.IPPort
for _, route := range routes {
if route.suffix != "." && !dnsname.HasSuffix(domain, route.suffix) {
continue
}
resolvers = route.resolvers
break
}
if len(resolvers) == 0 {
return errNoUpstreams
}
f.mu.Lock()
f.txMap[txid] = forwardingRecord{
src: query.addr,
createdAt: time.Now(),
}
f.mu.Unlock()
for _, upstream := range upstreams {
f.send(query.bs, upstream)
for _, resolver := range resolvers {
f.send(query.bs, resolver)
}
return nil
@@ -309,7 +328,7 @@ func newFwdConn(logf logger.Logf, idx int) *fwdConn {
// send sends packet to dst using c's connection.
// It is best effort. It is UDP, after all. Failures are logged.
func (c *fwdConn) send(packet []byte, dst net.Addr) {
func (c *fwdConn) send(packet []byte, dst netaddr.IPPort) {
var b *backoff.Backoff // lazily initialized, since it is not needed in the common case
backOff := func(err error) {
if b == nil {
@@ -335,8 +354,9 @@ func (c *fwdConn) send(packet []byte, dst net.Addr) {
}
c.mu.Unlock()
a := dst.UDPAddr()
c.wg.Add(1)
_, err := conn.WriteTo(packet, dst)
_, err := conn.WriteTo(packet, a)
c.wg.Done()
if err == nil {
// Success
@@ -469,3 +489,24 @@ func (c *fwdConn) close() {
// Unblock any remaining readers.
c.change.Broadcast()
}
// nameFromQuery extracts the normalized query name from bs.
func nameFromQuery(bs []byte) (string, error) {
var parser dns.Parser
hdr, err := parser.Start(bs)
if err != nil {
return "", err
}
if hdr.Response {
return "", errNotQuery
}
q, err := parser.Question()
if err != nil {
return "", err
}
n := q.Name.Data[:q.Name.Length]
return rawNameToLower(n), nil
}