mirror of
https://github.com/tailscale/tailscale.git
synced 2025-12-08 14:09:41 +00:00
net/dns/resolver: teach the forwarder to do per-domain routing.
Given a DNS route map, the forwarder selects the right set of upstreams for a given name. Signed-off-by: David Anderson <danderson@tailscale.com>
This commit is contained in:
@@ -17,10 +17,12 @@ import (
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
dns "golang.org/x/net/dns/dnsmessage"
|
||||
"inet.af/netaddr"
|
||||
"tailscale.com/logtail/backoff"
|
||||
"tailscale.com/net/netns"
|
||||
"tailscale.com/types/logger"
|
||||
"tailscale.com/util/dnsname"
|
||||
)
|
||||
|
||||
// headerBytes is the number of bytes in a DNS message header.
|
||||
@@ -100,6 +102,11 @@ func getTxID(packet []byte) txid {
|
||||
return (txid(hash) << 32) | txid(dnsid)
|
||||
}
|
||||
|
||||
type route struct {
|
||||
suffix string
|
||||
resolvers []netaddr.IPPort
|
||||
}
|
||||
|
||||
// forwarder forwards DNS packets to a number of upstream nameservers.
|
||||
type forwarder struct {
|
||||
logf logger.Logf
|
||||
@@ -116,10 +123,9 @@ type forwarder struct {
|
||||
conns []*fwdConn
|
||||
|
||||
mu sync.Mutex
|
||||
// upstreams are the nameserver addresses that should be used for forwarding.
|
||||
upstreams []net.Addr
|
||||
// txMap maps DNS txids to active forwarding records.
|
||||
txMap map[txid]forwardingRecord
|
||||
// routes are per-suffix resolvers to use.
|
||||
routes []route // most specific routes first
|
||||
txMap map[txid]forwardingRecord // txids to in-flight requests
|
||||
}
|
||||
|
||||
func init() {
|
||||
@@ -127,24 +133,22 @@ func init() {
|
||||
}
|
||||
|
||||
func newForwarder(logf logger.Logf, responses chan packet) *forwarder {
|
||||
return &forwarder{
|
||||
ret := &forwarder{
|
||||
logf: logger.WithPrefix(logf, "forward: "),
|
||||
responses: responses,
|
||||
closed: make(chan struct{}),
|
||||
conns: make([]*fwdConn, connCount),
|
||||
txMap: make(map[txid]forwardingRecord),
|
||||
}
|
||||
}
|
||||
|
||||
func (f *forwarder) Start() error {
|
||||
f.wg.Add(connCount + 1)
|
||||
for idx := range f.conns {
|
||||
f.conns[idx] = newFwdConn(f.logf, idx)
|
||||
go f.recv(f.conns[idx])
|
||||
ret.wg.Add(connCount + 1)
|
||||
for idx := range ret.conns {
|
||||
ret.conns[idx] = newFwdConn(ret.logf, idx)
|
||||
go ret.recv(ret.conns[idx])
|
||||
}
|
||||
go f.cleanMap()
|
||||
go ret.cleanMap()
|
||||
|
||||
return nil
|
||||
return ret
|
||||
}
|
||||
|
||||
func (f *forwarder) Close() {
|
||||
@@ -171,14 +175,15 @@ func (f *forwarder) rebindFromNetworkChange() {
|
||||
}
|
||||
}
|
||||
|
||||
func (f *forwarder) setUpstreams(upstreams []net.Addr) {
|
||||
func (f *forwarder) setRoutes(routes []route) {
|
||||
fmt.Println(routes)
|
||||
f.mu.Lock()
|
||||
f.upstreams = upstreams
|
||||
f.routes = routes
|
||||
f.mu.Unlock()
|
||||
}
|
||||
|
||||
// send sends packet to dst. It is best effort.
|
||||
func (f *forwarder) send(packet []byte, dst net.Addr) {
|
||||
func (f *forwarder) send(packet []byte, dst netaddr.IPPort) {
|
||||
connIdx := rand.Intn(connCount)
|
||||
conn := f.conns[connIdx]
|
||||
conn.send(packet, dst)
|
||||
@@ -256,24 +261,38 @@ func (f *forwarder) cleanMap() {
|
||||
|
||||
// forward forwards the query to all upstream nameservers and returns the first response.
|
||||
func (f *forwarder) forward(query packet) error {
|
||||
domain, err := nameFromQuery(query.bs)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
txid := getTxID(query.bs)
|
||||
|
||||
f.mu.Lock()
|
||||
routes := f.routes
|
||||
f.mu.Unlock()
|
||||
|
||||
upstreams := f.upstreams
|
||||
if len(upstreams) == 0 {
|
||||
f.mu.Unlock()
|
||||
var resolvers []netaddr.IPPort
|
||||
for _, route := range routes {
|
||||
if route.suffix != "." && !dnsname.HasSuffix(domain, route.suffix) {
|
||||
continue
|
||||
}
|
||||
resolvers = route.resolvers
|
||||
break
|
||||
}
|
||||
if len(resolvers) == 0 {
|
||||
return errNoUpstreams
|
||||
}
|
||||
|
||||
f.mu.Lock()
|
||||
f.txMap[txid] = forwardingRecord{
|
||||
src: query.addr,
|
||||
createdAt: time.Now(),
|
||||
}
|
||||
|
||||
f.mu.Unlock()
|
||||
|
||||
for _, upstream := range upstreams {
|
||||
f.send(query.bs, upstream)
|
||||
for _, resolver := range resolvers {
|
||||
f.send(query.bs, resolver)
|
||||
}
|
||||
|
||||
return nil
|
||||
@@ -309,7 +328,7 @@ func newFwdConn(logf logger.Logf, idx int) *fwdConn {
|
||||
|
||||
// send sends packet to dst using c's connection.
|
||||
// It is best effort. It is UDP, after all. Failures are logged.
|
||||
func (c *fwdConn) send(packet []byte, dst net.Addr) {
|
||||
func (c *fwdConn) send(packet []byte, dst netaddr.IPPort) {
|
||||
var b *backoff.Backoff // lazily initialized, since it is not needed in the common case
|
||||
backOff := func(err error) {
|
||||
if b == nil {
|
||||
@@ -335,8 +354,9 @@ func (c *fwdConn) send(packet []byte, dst net.Addr) {
|
||||
}
|
||||
c.mu.Unlock()
|
||||
|
||||
a := dst.UDPAddr()
|
||||
c.wg.Add(1)
|
||||
_, err := conn.WriteTo(packet, dst)
|
||||
_, err := conn.WriteTo(packet, a)
|
||||
c.wg.Done()
|
||||
if err == nil {
|
||||
// Success
|
||||
@@ -469,3 +489,24 @@ func (c *fwdConn) close() {
|
||||
// Unblock any remaining readers.
|
||||
c.change.Broadcast()
|
||||
}
|
||||
|
||||
// nameFromQuery extracts the normalized query name from bs.
|
||||
func nameFromQuery(bs []byte) (string, error) {
|
||||
var parser dns.Parser
|
||||
|
||||
hdr, err := parser.Start(bs)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if hdr.Response {
|
||||
return "", errNotQuery
|
||||
}
|
||||
|
||||
q, err := parser.Question()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
n := q.Name.Data[:q.Name.Length]
|
||||
return rawNameToLower(n), nil
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user