net/dns{,/resolver}: refactor DNS forwarder, send out of right link on macOS/iOS

Fixes #2224
Fixes tailscale/corp#2045

Signed-off-by: Brad Fitzpatrick <bradfitz@tailscale.com>
This commit is contained in:
Brad Fitzpatrick 2021-06-22 21:53:43 -07:00 committed by Brad Fitzpatrick
parent 597fa3d3c3
commit 45e64f2e1a
10 changed files with 530 additions and 500 deletions

View File

@ -9,14 +9,12 @@
import (
"errors"
"fmt"
"log"
"net"
"strings"
"syscall"
"golang.org/x/sys/unix"
"inet.af/netaddr"
"tailscale.com/net/interfaces"
"tailscale.com/net/netns"
)
func init() {
@ -32,29 +30,7 @@ func initListenConfigNetworkExtension(nc *net.ListenConfig, ip netaddr.IP, st *i
if !ok {
return fmt.Errorf("no interface with name %q", tunIfName)
}
nc.Control = func(network, address string, c syscall.RawConn) error {
var sockErr error
err := c.Control(func(fd uintptr) {
sockErr = bindIf(fd, network, address, tunIf.Index)
log.Printf("peerapi: bind(%q, %q) on index %v = %v", network, address, tunIf.Index, sockErr)
})
if err != nil {
return err
}
return sockErr
}
return nil
}
func bindIf(fd uintptr, network, address string, ifIndex int) error {
v6 := strings.Contains(address, "]:") || strings.HasSuffix(network, "6") // hacky test for v6
proto := unix.IPPROTO_IP
opt := unix.IP_BOUND_IF
if v6 {
proto = unix.IPPROTO_IPV6
opt = unix.IPV6_BOUND_IF
}
return unix.SetsockoptInt(int(fd), proto, opt, ifIndex)
return netns.SetListenConfigInterfaceIndex(nc, tunIf.Index)
}
func peerDialControlFuncNetworkExtension(b *LocalBackend) func(network, address string, c syscall.RawConn) error {
@ -68,17 +44,12 @@ func peerDialControlFuncNetworkExtension(b *LocalBackend) func(network, address
index = tunIf.Index
}
}
var lc net.ListenConfig
netns.SetListenConfigInterfaceIndex(&lc, index)
return func(network, address string, c syscall.RawConn) error {
if index == -1 {
return errors.New("failed to find TUN interface to bind to")
}
var sockErr error
err := c.Control(func(fd uintptr) {
sockErr = bindIf(fd, network, address, index)
})
if err != nil {
return err
}
return sockErr
return lc.Control(network, address, c)
}
}

View File

@ -38,11 +38,11 @@ type Manager struct {
}
// NewManagers created a new manager from the given config.
func NewManager(logf logger.Logf, oscfg OSConfigurator, linkMon *monitor.Mon) *Manager {
func NewManager(logf logger.Logf, oscfg OSConfigurator, linkMon *monitor.Mon, linkSel resolver.ForwardLinkSelector) *Manager {
logf = logger.WithPrefix(logf, "dns: ")
m := &Manager{
logf: logf,
resolver: resolver.New(logf, linkMon),
resolver: resolver.New(logf, linkMon, linkSel),
os: oscfg,
}
m.logf("using %T", m.os)
@ -207,7 +207,7 @@ func Cleanup(logf logger.Logf, interfaceName string) {
logf("creating dns cleanup: %v", err)
return
}
dns := NewManager(logf, oscfg, nil)
dns := NewManager(logf, oscfg, nil, nil)
if err := dns.Down(); err != nil {
logf("dns down: %v", err)
}

View File

@ -376,7 +376,7 @@ func TestManager(t *testing.T) {
SplitDNS: test.split,
BaseConfig: test.bs,
}
m := NewManager(t.Logf, &f, nil)
m := NewManager(t.Logf, &f, nil, nil)
m.resolver.TestOnlySetHook(f.SetResolver)
if err := m.Set(test.in); err != nil {

View File

@ -9,41 +9,30 @@
"context"
"encoding/binary"
"errors"
"fmt"
"hash/crc32"
"io"
"math/rand"
"net"
"sync"
"syscall"
"time"
dns "golang.org/x/net/dns/dnsmessage"
"inet.af/netaddr"
"tailscale.com/logtail/backoff"
"tailscale.com/types/logger"
"tailscale.com/util/dnsname"
"tailscale.com/wgengine/monitor"
)
// headerBytes is the number of bytes in a DNS message header.
const headerBytes = 12
// connCount is the number of UDP connections to use for forwarding.
const connCount = 32
const (
// cleanupInterval is the interval between purged of timed-out entries from txMap.
cleanupInterval = 30 * time.Second
// responseTimeout is the maximal amount of time to wait for a DNS response.
responseTimeout = 5 * time.Second
)
var errNoUpstreams = errors.New("upstream nameservers not set")
type forwardingRecord struct {
src netaddr.IPPort
createdAt time.Time
}
// txid identifies a DNS transaction.
//
// As the standard DNS Request ID is only 16 bits, we extend it:
@ -100,178 +89,164 @@ func getTxID(packet []byte) txid {
}
type route struct {
suffix dnsname.FQDN
resolvers []netaddr.IPPort
Suffix dnsname.FQDN
Resolvers []netaddr.IPPort
}
// forwarder forwards DNS packets to a number of upstream nameservers.
type forwarder struct {
logf logger.Logf
logf logger.Logf
linkMon *monitor.Mon
linkSel ForwardLinkSelector
ctx context.Context // good until Close
ctxCancel context.CancelFunc // closes ctx
// responses is a channel by which responses are returned.
responses chan packet
// closed signals all goroutines to stop.
closed chan struct{}
// wg signals when all goroutines have stopped.
wg sync.WaitGroup
// conns are the UDP connections used for forwarding.
// A random one is selected for each request, regardless of the target upstream.
conns []*fwdConn
mu sync.Mutex // guards following
mu sync.Mutex
// routes are per-suffix resolvers to use.
routes []route // most specific routes first
txMap map[txid]forwardingRecord // txids to in-flight requests
// routes are per-suffix resolvers to use, with
// the most specific routes first.
routes []route
}
func init() {
rand.Seed(time.Now().UnixNano())
}
func newForwarder(logf logger.Logf, responses chan packet) *forwarder {
ret := &forwarder{
func newForwarder(logf logger.Logf, responses chan packet, linkMon *monitor.Mon, linkSel ForwardLinkSelector) *forwarder {
f := &forwarder{
logf: logger.WithPrefix(logf, "forward: "),
linkMon: linkMon,
linkSel: linkSel,
responses: responses,
closed: make(chan struct{}),
conns: make([]*fwdConn, connCount),
txMap: make(map[txid]forwardingRecord),
}
ret.wg.Add(connCount + 1)
for idx := range ret.conns {
ret.conns[idx] = newFwdConn(ret.logf, idx)
go ret.recv(ret.conns[idx])
}
go ret.cleanMap()
return ret
f.ctx, f.ctxCancel = context.WithCancel(context.Background())
return f
}
func (f *forwarder) Close() {
select {
case <-f.closed:
return
default:
// continue
}
close(f.closed)
for _, conn := range f.conns {
conn.close()
}
f.wg.Wait()
}
func (f *forwarder) rebindFromNetworkChange() {
for _, c := range f.conns {
c.mu.Lock()
c.reconnectLocked()
c.mu.Unlock()
}
func (f *forwarder) Close() error {
f.ctxCancel()
return nil
}
func (f *forwarder) setRoutes(routes []route) {
f.mu.Lock()
defer f.mu.Unlock()
f.routes = routes
f.mu.Unlock()
}
var stdNetPacketListener packetListener = new(net.ListenConfig)
type packetListener interface {
ListenPacket(ctx context.Context, network, address string) (net.PacketConn, error)
}
func (f *forwarder) packetListener(ip netaddr.IP) (packetListener, error) {
if f.linkSel == nil || initListenConfig == nil {
return stdNetPacketListener, nil
}
linkName := f.linkSel.PickLink(ip)
if linkName == "" {
return stdNetPacketListener, nil
}
lc := new(net.ListenConfig)
if err := initListenConfig(lc, f.linkMon, linkName); err != nil {
return nil, err
}
return lc, nil
}
// send sends packet to dst. It is best effort.
func (f *forwarder) send(packet []byte, dst netaddr.IPPort) {
connIdx := rand.Intn(connCount)
conn := f.conns[connIdx]
conn.send(packet, dst)
}
//
// send expects the reply to have the same txid as txidOut.
//
// The provided closeOnCtxDone lets send register values to Close if
// the caller's ctx expires. This avoids send from allocating its own
// waiting goroutine to interrupt the ReadFrom, as memory is tight on
// iOS and we want the number of pending DNS lookups to be bursty
// without too much associated goroutine/memory cost.
func (f *forwarder) send(ctx context.Context, txidOut txid, closeOnCtxDone *closePool, packet []byte, dst netaddr.IPPort) ([]byte, error) {
// TODO(bradfitz): if dst.IP is 8.8.8.8 or 8.8.4.4 or 1.1.1.1, etc, or
// something dynamically probed earlier to support DoH or DoT,
// do that here instead.
func (f *forwarder) recv(conn *fwdConn) {
defer f.wg.Done()
ln, err := f.packetListener(dst.IP())
if err != nil {
return nil, err
}
conn, err := ln.ListenPacket(ctx, "udp", ":0")
if err != nil {
f.logf("ListenPacket failed: %v", err)
return nil, err
}
defer conn.Close()
for {
select {
case <-f.closed:
return
default:
closeOnCtxDone.Add(conn)
defer closeOnCtxDone.Remove(conn)
if _, err := conn.WriteTo(packet, dst.UDPAddr()); err != nil {
if err := ctx.Err(); err != nil {
return nil, err
}
// The 1 extra byte is to detect packet truncation.
out := make([]byte, maxResponseBytes+1)
n := conn.read(out)
var truncated bool
if n > maxResponseBytes {
n = maxResponseBytes
truncated = true
return nil, err
}
// The 1 extra byte is to detect packet truncation.
out := make([]byte, maxResponseBytes+1)
n, _, err := conn.ReadFrom(out)
if err != nil {
if err := ctx.Err(); err != nil {
return nil, err
}
if n == 0 {
continue
}
if n < headerBytes {
f.logf("recv: packet too small (%d bytes)", n)
}
out = out[:n]
txid := getTxID(out)
if truncated {
const dnsFlagTruncated = 0x200
flags := binary.BigEndian.Uint16(out[2:4])
flags |= dnsFlagTruncated
binary.BigEndian.PutUint16(out[2:4], flags)
// TODO(#2067): Remove any incomplete records? RFC 1035 section 6.2
// states that truncation should head drop so that the authority
// section can be preserved if possible. However, the UDP read with
// a too-small buffer has already dropped the end, so that's the
// best we can do.
}
f.mu.Lock()
record, found := f.txMap[txid]
// At most one nameserver will return a response:
// the first one to do so will delete txid from the map.
if !found {
f.mu.Unlock()
continue
}
delete(f.txMap, txid)
f.mu.Unlock()
pkt := packet{out, record.src}
select {
case <-f.closed:
return
case f.responses <- pkt:
// continue
if packetWasTruncated(err) {
err = nil
} else {
return nil, err
}
}
truncated := n > maxResponseBytes
if truncated {
n = maxResponseBytes
}
if n < headerBytes {
f.logf("recv: packet too small (%d bytes)", n)
}
out = out[:n]
txid := getTxID(out)
if txid != txidOut {
return nil, errors.New("txid doesn't match")
}
if truncated {
const dnsFlagTruncated = 0x200
flags := binary.BigEndian.Uint16(out[2:4])
flags |= dnsFlagTruncated
binary.BigEndian.PutUint16(out[2:4], flags)
// TODO(#2067): Remove any incomplete records? RFC 1035 section 6.2
// states that truncation should head drop so that the authority
// section can be preserved if possible. However, the UDP read with
// a too-small buffer has already dropped the end, so that's the
// best we can do.
}
return out, nil
}
// cleanMap periodically deletes timed-out forwarding records from f.txMap to bound growth.
func (f *forwarder) cleanMap() {
defer f.wg.Done()
t := time.NewTicker(cleanupInterval)
defer t.Stop()
var now time.Time
for {
select {
case <-f.closed:
return
case now = <-t.C:
// continue
// resolvers returns the resolvers to use for domain.
func (f *forwarder) resolvers(domain dnsname.FQDN) []netaddr.IPPort {
f.mu.Lock()
routes := f.routes
f.mu.Unlock()
for _, route := range routes {
if route.Suffix == "." || route.Suffix.Contains(domain) {
return route.Resolvers
}
f.mu.Lock()
for k, v := range f.txMap {
if now.Sub(v.createdAt) > responseTimeout {
delete(f.txMap, k)
}
}
f.mu.Unlock()
}
return nil
}
// forward forwards the query to all upstream nameservers and returns the first response.
@ -283,225 +258,60 @@ func (f *forwarder) forward(query packet) error {
txid := getTxID(query.bs)
f.mu.Lock()
routes := f.routes
f.mu.Unlock()
var resolvers []netaddr.IPPort
for _, route := range routes {
if route.suffix != "." && !route.suffix.Contains(domain) {
continue
}
resolvers = route.resolvers
break
}
resolvers := f.resolvers(domain)
if len(resolvers) == 0 {
return errNoUpstreams
}
f.mu.Lock()
f.txMap[txid] = forwardingRecord{
src: query.addr,
createdAt: time.Now(),
}
f.mu.Unlock()
closeOnCtxDone := new(closePool)
defer closeOnCtxDone.Close()
// TODO(#2066): EDNS size clamping
ctx, cancel := context.WithTimeout(f.ctx, responseTimeout)
defer cancel()
for _, resolver := range resolvers {
f.send(query.bs, resolver)
}
resc := make(chan []byte, 1)
var (
mu sync.Mutex
firstErr error
)
return nil
}
// A fwdConn manages a single connection used to forward DNS requests.
// Net link changes can cause a *net.UDPConn to become permanently unusable, particularly on macOS.
// fwdConn detects such situations and transparently creates new connections.
type fwdConn struct {
// logf allows a fwdConn to log.
logf logger.Logf
// change allows calls to read to block until a the network connection has been replaced.
change *sync.Cond
// mu protects fields that follow it; it is also change's Locker.
mu sync.Mutex
// closed tracks whether fwdConn has been permanently closed.
closed bool
// conn is the current active connection.
conn net.PacketConn
}
func newFwdConn(logf logger.Logf, idx int) *fwdConn {
c := new(fwdConn)
c.logf = logger.WithPrefix(logf, fmt.Sprintf("fwdConn %d: ", idx))
c.change = sync.NewCond(&c.mu)
// c.conn is created lazily in send
return c
}
// send sends packet to dst using c's connection.
// It is best effort. It is UDP, after all. Failures are logged.
func (c *fwdConn) send(packet []byte, dst netaddr.IPPort) {
var b *backoff.Backoff // lazily initialized, since it is not needed in the common case
backOff := func(err error) {
if b == nil {
b = backoff.NewBackoff("dns-fwdConn-send", c.logf, 30*time.Second)
}
b.BackOff(context.Background(), err)
}
for {
// Gather the current connection.
// We can't hold the lock while we call WriteTo.
c.mu.Lock()
conn := c.conn
closed := c.closed
if closed {
c.mu.Unlock()
return
}
if conn == nil {
c.reconnectLocked()
c.mu.Unlock()
continue
}
c.mu.Unlock()
_, err := conn.WriteTo(packet, dst.UDPAddr())
if err == nil {
// Success
return
}
if errors.Is(err, net.ErrClosed) {
// We intentionally closed this connection.
// It has been replaced by a new connection. Try again.
continue
}
// Something else went wrong.
// We have three choices here: try again, give up, or create a new connection.
var opErr *net.OpError
if !errors.As(err, &opErr) {
// Weird. All errors from the net package should be *net.OpError. Bail.
c.logf("send: non-*net.OpErr %v (%T)", err, err)
return
}
if opErr.Temporary() || opErr.Timeout() {
// I doubt that either of these can happen (this is UDP),
// but go ahead and try again.
backOff(err)
continue
}
if errors.Is(err, syscall.EHOSTUNREACH) {
// "No route to host." The network stack is fine, but
// can't talk to this destination. Not much we can do
// about that, don't spam logs.
return
}
if networkIsDown(err) {
// Fail.
c.logf("send: network is down")
return
}
if networkIsUnreachable(err) {
// This can be caused by a link change.
// Replace the existing connection with a new one.
c.mu.Lock()
// It's possible that multiple senders discovered simultaneously
// that the network is unreachable. Avoid reconnecting multiple times:
// Only reconnect if the current connection is the one that we
// discovered to be problematic.
if c.conn == conn {
backOff(err)
c.reconnectLocked()
for _, ipp := range resolvers {
go func(ipp netaddr.IPPort) {
resb, err := f.send(ctx, txid, closeOnCtxDone, query.bs, ipp)
if err != nil {
mu.Lock()
defer mu.Unlock()
if firstErr == nil {
firstErr = err
}
return
}
c.mu.Unlock()
// Try again with our new network connection.
continue
select {
case resc <- resb:
default:
}
}(ipp)
}
select {
case v := <-resc:
select {
case <-ctx.Done():
return ctx.Err()
case f.responses <- packet{v, query.addr}:
return nil
}
// Unrecognized error. Fail.
c.logf("send: unrecognized error: %v", err)
return
case <-ctx.Done():
mu.Lock()
defer mu.Unlock()
if firstErr != nil {
return firstErr
}
return ctx.Err()
}
}
// read waits for a response from c's connection.
// It returns the number of bytes read, which may be 0
// in case of an error or a closed connection.
func (c *fwdConn) read(out []byte) int {
for {
// Gather the current connection.
// We can't hold the lock while we call ReadFrom.
c.mu.Lock()
conn := c.conn
closed := c.closed
if closed {
c.mu.Unlock()
return 0
}
if conn == nil {
// There is no current connection.
// Wait for the connection to change, then try again.
c.change.Wait()
c.mu.Unlock()
continue
}
c.mu.Unlock()
n, _, err := conn.ReadFrom(out)
if err == nil || packetWasTruncated(err) {
// Success.
return n
}
if errors.Is(err, net.ErrClosed) {
// We intentionally closed this connection.
// It has been replaced by a new connection. Try again.
continue
}
c.logf("read: unrecognized error: %v", err)
return 0
}
}
// reconnectLocked replaces the current connection with a new one.
// c.mu must be locked.
func (c *fwdConn) reconnectLocked() {
c.closeConnLocked()
// Make a new connection.
conn, err := net.ListenPacket("udp", "")
if err != nil {
c.logf("ListenPacket failed: %v", err)
} else {
c.conn = conn
}
// Broadcast that a new connection is available.
c.change.Broadcast()
}
// closeCurrentConn closes the current connection.
// c.mu must be locked.
func (c *fwdConn) closeConnLocked() {
if c.conn == nil {
return
}
c.conn.Close() // unblocks all readers/writers, they'll pick up the next connection.
c.conn = nil
}
// close permanently closes c.
func (c *fwdConn) close() {
c.mu.Lock()
defer c.mu.Unlock()
if c.closed {
return
}
c.closed = true
c.closeConnLocked()
// Unblock any remaining readers.
c.change.Broadcast()
}
var initListenConfig func(_ *net.ListenConfig, _ *monitor.Mon, tunName string) error
// nameFromQuery extracts the normalized query name from bs.
func nameFromQuery(bs []byte) (dnsname.FQDN, error) {
@ -523,3 +333,48 @@ func nameFromQuery(bs []byte) (dnsname.FQDN, error) {
n := q.Name.Data[:q.Name.Length]
return dnsname.ToFQDN(rawNameToLower(n))
}
// closePool is a dynamic set of io.Closers to close as a group.
// It's intended to be Closed at most once.
//
// The zero value is ready for use.
type closePool struct {
mu sync.Mutex
m map[io.Closer]bool
closed bool
}
func (p *closePool) Add(c io.Closer) {
p.mu.Lock()
defer p.mu.Unlock()
if p.closed {
c.Close()
return
}
if p.m == nil {
p.m = map[io.Closer]bool{}
}
p.m[c] = true
}
func (p *closePool) Remove(c io.Closer) {
p.mu.Lock()
defer p.mu.Unlock()
if p.closed {
return
}
delete(p.m, c)
}
func (p *closePool) Close() error {
p.mu.Lock()
defer p.mu.Unlock()
if p.closed {
return nil
}
p.closed = true
for c := range p.m {
c.Close()
}
return nil
}

View File

@ -0,0 +1,27 @@
// Copyright (c) 2021 Tailscale Inc & AUTHORS All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// +build darwin,ts_macext ios,ts_macext
package resolver
import (
"errors"
"net"
"tailscale.com/net/netns"
"tailscale.com/wgengine/monitor"
)
func init() {
initListenConfig = initListenConfigNetworkExtension
}
func initListenConfigNetworkExtension(nc *net.ListenConfig, mon *monitor.Mon, tunName string) error {
nif, ok := mon.InterfaceState().Interface[tunName]
if !ok {
return errors.New("utun not found")
}
return netns.SetListenConfigInterfaceIndex(nc, nif.Interface.Index)
}

View File

@ -9,14 +9,15 @@
import (
"encoding/hex"
"errors"
"runtime"
"sort"
"strings"
"sync"
"sync/atomic"
"time"
dns "golang.org/x/net/dns/dnsmessage"
"inet.af/netaddr"
"tailscale.com/net/interfaces"
"tailscale.com/types/logger"
"tailscale.com/util/dnsname"
"tailscale.com/wgengine/monitor"
@ -27,10 +28,20 @@
// truncation in a platform-agnostic way.
const maxResponseBytes = 4095
// queueSize is the maximal number of DNS requests that can await polling.
// maxActiveQueries returns the maximal number of DNS requests that be
// can running.
// If EnqueueRequest is called when this many requests are already pending,
// the request will be dropped to avoid blocking the caller.
const queueSize = 64
func maxActiveQueries() int32 {
if runtime.GOOS == "ios" {
// For memory paranoia reasons on iOS, match the
// historical Tailscale 1.x..1.8 behavior for now
// (just before the 1.10 release).
return 64
}
// But for other platforms, allow more burstiness:
return 256
}
// defaultTTL is the TTL of all responses from Resolver.
const defaultTTL = 600 * time.Second
@ -75,13 +86,12 @@ type Config struct {
type Resolver struct {
logf logger.Logf
linkMon *monitor.Mon // or nil
unregLinkMon func() // or nil
saveConfigForTests func(cfg Config) // used in tests to capture resolver config
// forwarder forwards requests to upstream nameservers.
forwarder *forwarder
// queue is a buffered channel holding DNS requests queued for resolution.
queue chan packet
activeQueriesAtomic int32 // number of DNS queries in flight
// responses is an unbuffered channel to which responses are returned.
responses chan packet
// errors is an unbuffered channel to which errors are returned.
@ -98,27 +108,26 @@ type Resolver struct {
ipToHost map[netaddr.IP]dnsname.FQDN
}
type ForwardLinkSelector interface {
// PickLink returns which network device should be used to query
// the DNS server at the given IP.
// The empty string means to use an unspecified default.
PickLink(netaddr.IP) (linkName string)
}
// New returns a new resolver.
// linkMon optionally specifies a link monitor to use for socket rebinding.
func New(logf logger.Logf, linkMon *monitor.Mon) *Resolver {
func New(logf logger.Logf, linkMon *monitor.Mon, linkSel ForwardLinkSelector) *Resolver {
r := &Resolver{
logf: logger.WithPrefix(logf, "dns: "),
linkMon: linkMon,
queue: make(chan packet, queueSize),
responses: make(chan packet),
errors: make(chan error),
closed: make(chan struct{}),
hostToIP: map[dnsname.FQDN][]netaddr.IP{},
ipToHost: map[netaddr.IP]dnsname.FQDN{},
}
r.forwarder = newForwarder(r.logf, r.responses)
if r.linkMon != nil {
r.unregLinkMon = r.linkMon.RegisterChangeCallback(r.onLinkMonitorChange)
}
r.wg.Add(1)
go r.poll()
r.forwarder = newForwarder(r.logf, r.responses, linkMon, linkSel)
return r
}
@ -140,13 +149,13 @@ func (r *Resolver) SetConfig(cfg Config) error {
for suffix, ips := range cfg.Routes {
routes = append(routes, route{
suffix: suffix,
resolvers: ips,
Suffix: suffix,
Resolvers: ips,
})
}
// Sort from longest prefix to shortest.
sort.Slice(routes, func(i, j int) bool {
return routes[i].suffix.NumLabels() > routes[j].suffix.NumLabels()
return routes[i].Suffix.NumLabels() > routes[j].Suffix.NumLabels()
})
r.forwarder.setRoutes(routes)
@ -170,19 +179,7 @@ func (r *Resolver) Close() {
}
close(r.closed)
if r.unregLinkMon != nil {
r.unregLinkMon()
}
r.forwarder.Close()
r.wg.Wait()
}
func (r *Resolver) onLinkMonitorChange(changed bool, state *interfaces.State) {
if !changed {
return
}
r.forwarder.rebindFromNetworkChange()
}
// EnqueueRequest places the given DNS request in the resolver's queue.
@ -192,11 +189,14 @@ func (r *Resolver) EnqueueRequest(bs []byte, from netaddr.IPPort) error {
select {
case <-r.closed:
return ErrClosed
case r.queue <- packet{bs, from}:
return nil
default:
}
if n := atomic.AddInt32(&r.activeQueriesAtomic, 1); n > maxActiveQueries() {
atomic.AddInt32(&r.activeQueriesAtomic, -1)
return errFullQueue
}
go r.handleQuery(packet{bs, from})
return nil
}
// NextResponse returns a DNS response to a previously enqueued request.
@ -291,53 +291,34 @@ func (r *Resolver) resolveLocal(domain dnsname.FQDN, typ dns.Type) (netaddr.IP,
// resolveReverse returns the unique domain name that maps to the given address.
func (r *Resolver) resolveLocalReverse(ip netaddr.IP) (dnsname.FQDN, dns.RCode) {
r.mu.Lock()
ips := r.ipToHost
r.mu.Unlock()
name, found := ips[ip]
if !found {
defer r.mu.Unlock()
name, ok := r.ipToHost[ip]
if !ok {
return "", dns.RCodeNameError
}
return name, dns.RCodeSuccess
}
func (r *Resolver) poll() {
defer r.wg.Done()
func (r *Resolver) handleQuery(pkt packet) {
defer atomic.AddInt32(&r.activeQueriesAtomic, -1)
var pkt packet
for {
out, err := r.respond(pkt.bs)
if err == errNotOurName {
err = r.forwarder.forward(pkt)
if err == nil {
// forward will send response into r.responses, nothing to do.
return
}
}
if err != nil {
select {
case <-r.closed:
return
case pkt = <-r.queue:
// continue
case r.errors <- err:
}
out, err := r.respond(pkt.bs)
if err == errNotOurName {
err = r.forwarder.forward(pkt)
if err == nil {
// forward will send response into r.responses, nothing to do.
continue
}
}
if err != nil {
select {
case <-r.closed:
return
case r.errors <- err:
// continue
}
} else {
pkt.bs = out
select {
case <-r.closed:
return
case r.responses <- pkt:
// continue
}
} else {
select {
case <-r.closed:
case r.responses <- packet{out, pkt.addr}:
}
}
}
@ -351,28 +332,44 @@ type response struct {
IP netaddr.IP
}
// parseQuery parses the query in given packet into a response struct.
// if the parse is successful, resp.Name contains the normalized name being queried.
// TODO: stuffing the query name in resp.Name temporarily is a hack. Clean it up.
func parseQuery(query []byte, resp *response) error {
var parser dns.Parser
var err error
var dnsParserPool = &sync.Pool{
New: func() interface{} {
return new(dnsParser)
},
}
resp.Header, err = parser.Start(query)
// dnsParser parses DNS queries using x/net/dns/dnsmessage.
// These structs are pooled with dnsParserPool.
type dnsParser struct {
Header dns.Header
Question dns.Question
parser dns.Parser
}
func (p *dnsParser) response() *response {
return &response{Header: p.Header, Question: p.Question}
}
// zeroParser clears parser so it doesn't retain its most recently
// parsed DNS query's []byte while it's sitting in a sync.Pool.
// It's not useful to keep anyway: the next Start will do the same.
func (p *dnsParser) zeroParser() { p.parser = dns.Parser{} }
// parseQuery parses the query in given packet into p.Header and
// p.Question.
func (p *dnsParser) parseQuery(query []byte) error {
defer p.zeroParser()
var err error
p.Header, err = p.parser.Start(query)
if err != nil {
return err
}
if resp.Header.Response {
if p.Header.Response {
return errNotQuery
}
resp.Question, err = parser.Question()
if err != nil {
return err
}
return nil
p.Question, err = p.parser.Question()
return err
}
// marshalARecord serializes an A record into an active builder.
@ -624,12 +621,13 @@ func (r *Resolver) respondReverse(query []byte, name dnsname.FQDN, resp *respons
// respond returns a DNS response to query if it can be resolved locally.
// Otherwise, it returns errNotOurName.
func (r *Resolver) respond(query []byte) ([]byte, error) {
resp := new(response)
parser := dnsParserPool.Get().(*dnsParser)
defer dnsParserPool.Put(parser)
// ParseQuery is sufficiently fast to run on every DNS packet.
// This is considerably simpler than extracting the name by hand
// to shave off microseconds in case of delegation.
err := parseQuery(query, resp)
err := parser.parseQuery(query)
// We will not return this error: it is the sender's fault.
if err != nil {
if errors.Is(err, dns.ErrSectionDone) {
@ -637,13 +635,15 @@ func (r *Resolver) respond(query []byte) ([]byte, error) {
} else {
r.logf("parseQuery(%02x): %v", query, err)
}
resp := parser.response()
resp.Header.RCode = dns.RCodeFormatError
return marshalResponse(resp)
}
rawName := resp.Question.Name.Data[:resp.Question.Name.Length]
rawName := parser.Question.Name.Data[:parser.Question.Name.Length]
name, err := dnsname.ToFQDN(rawNameToLower(rawName))
if err != nil {
// DNS packet unexpectedly contains an invalid FQDN.
resp := parser.response()
resp.Header.RCode = dns.RCodeFormatError
return marshalResponse(resp)
}
@ -651,15 +651,17 @@ func (r *Resolver) respond(query []byte) ([]byte, error) {
// Always try to handle reverse lookups; delegate inside when not found.
// This way, queries for existent nodes do not leak,
// but we behave gracefully if non-Tailscale nodes exist in CGNATRange.
if resp.Question.Type == dns.TypePTR {
return r.respondReverse(query, name, resp)
if parser.Question.Type == dns.TypePTR {
return r.respondReverse(query, name, parser.response())
}
resp.IP, resp.Header.RCode = r.resolveLocal(name, resp.Question.Type)
// This return code is special: it requests forwarding.
if resp.Header.RCode == dns.RCodeRefused {
return nil, errNotOurName
ip, rcode := r.resolveLocal(name, parser.Question.Type)
if rcode == dns.RCodeRefused {
return nil, errNotOurName // sentinel error return value: it requests forwarding
}
resp := parser.response()
resp.Header.RCode = rcode
resp.IP = ip
return marshalResponse(resp)
}

View File

@ -8,6 +8,7 @@
"bytes"
"encoding/hex"
"errors"
"fmt"
"math/rand"
"net"
"runtime"
@ -17,6 +18,7 @@
"inet.af/netaddr"
"tailscale.com/tstest"
"tailscale.com/util/dnsname"
"tailscale.com/wgengine/monitor"
)
var testipv4 = netaddr.MustParseIP("1.2.3.4")
@ -128,7 +130,9 @@ func unpackResponse(payload []byte) (dnsResponse, error) {
}
func syncRespond(r *Resolver, query []byte) ([]byte, error) {
r.EnqueueRequest(query, netaddr.IPPort{})
if err := r.EnqueueRequest(query, netaddr.IPPort{}); err != nil {
return nil, fmt.Errorf("EnqueueRequest: %w", err)
}
payload, _, err := r.NextResponse()
return payload, err
}
@ -211,8 +215,12 @@ func TestRDNSNameToIPv6(t *testing.T) {
}
}
func newResolver(t testing.TB) *Resolver {
return New(t.Logf, nil /* no link monitor */, nil /* no link selector */)
}
func TestResolveLocal(t *testing.T) {
r := New(t.Logf, nil)
r := newResolver(t)
defer r.Close()
r.SetConfig(dnsCfg)
@ -252,7 +260,7 @@ func TestResolveLocal(t *testing.T) {
}
func TestResolveLocalReverse(t *testing.T) {
r := New(t.Logf, nil)
r := newResolver(t)
defer r.Close()
r.SetConfig(dnsCfg)
@ -362,7 +370,7 @@ func TestDelegate(t *testing.T) {
"huge.txt.", resolveToTXT(hugeTXT))
defer v6server.Shutdown()
r := New(t.Logf, nil)
r := newResolver(t)
defer r.Close()
cfg := dnsCfg
@ -474,7 +482,7 @@ func TestDelegateSplitRoute(t *testing.T) {
"test.other.", resolveToIP(test4, test6, "dns.other."))
defer server2.Shutdown()
r := New(t.Logf, nil)
r := newResolver(t)
defer r.Close()
cfg := dnsCfg
@ -531,7 +539,7 @@ func TestDelegateCollision(t *testing.T) {
"test.site.", resolveToIP(testipv4, testipv6, "dns.test.site."))
defer server.Shutdown()
r := New(t.Logf, nil)
r := newResolver(t)
defer r.Close()
cfg := dnsCfg
@ -745,7 +753,7 @@ func TestDelegateCollision(t *testing.T) {
}
func TestFull(t *testing.T) {
r := New(t.Logf, nil)
r := newResolver(t)
defer r.Close()
r.SetConfig(dnsCfg)
@ -781,7 +789,7 @@ func TestFull(t *testing.T) {
}
func TestAllocs(t *testing.T) {
r := New(t.Logf, nil)
r := newResolver(t)
defer r.Close()
r.SetConfig(dnsCfg)
@ -835,7 +843,7 @@ func BenchmarkFull(b *testing.B) {
"test.site.", resolveToIP(testipv4, testipv6, "dns.test.site."))
defer server.Shutdown()
r := New(b.Logf, nil)
r := newResolver(b)
defer r.Close()
cfg := dnsCfg
@ -872,3 +880,58 @@ func TestMarshalResponseFormatError(t *testing.T) {
}
t.Logf("response: %q", v)
}
func TestForwardLinkSelection(t *testing.T) {
old := initListenConfig
defer func() { initListenConfig = old }()
configCall := make(chan string, 1)
initListenConfig = func(nc *net.ListenConfig, mon *monitor.Mon, tunName string) error {
select {
case configCall <- tunName:
return nil
default:
t.Error("buffer full")
return errors.New("buffer full")
}
}
// specialIP is some IP we pretend that our link selector
// routes differently.
specialIP := netaddr.IPv4(1, 2, 3, 4)
fwd := newForwarder(t.Logf, nil, nil, linkSelFunc(func(ip netaddr.IP) string {
if ip == netaddr.IPv4(1, 2, 3, 4) {
return "special"
}
return ""
}))
// Test non-special IP.
if got, err := fwd.packetListener(netaddr.IP{}); err != nil {
t.Fatal(err)
} else if got != stdNetPacketListener {
t.Errorf("for IP zero value, didn't get expected packet listener")
}
select {
case v := <-configCall:
t.Errorf("unexpected ListenConfig call, with tunName %q", v)
default:
}
// Test that our special IP generates a call to initListenConfig.
if got, err := fwd.packetListener(specialIP); err != nil {
t.Fatal(err)
} else if got == stdNetPacketListener {
t.Errorf("special IP returned std packet listener; expected unique one")
}
if v, ok := <-configCall; !ok {
t.Errorf("didn't get ListenConfig call")
} else if v != "special" {
t.Errorf("got tunName %q; want 'special'", v)
}
}
type linkSelFunc func(ip netaddr.IP) string
func (f linkSelFunc) PickLink(ip netaddr.IP) string { return f(ip) }

53
net/netns/netns_macios.go Normal file
View File

@ -0,0 +1,53 @@
// Copyright (c) 2021 Tailscale Inc & AUTHORS All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// +build darwin ios
package netns
import (
"errors"
"log"
"net"
"strings"
"syscall"
"golang.org/x/sys/unix"
)
// SetListenConfigInterfaceIndex sets lc.Control such that sockets are bound
// to the provided interface index.
func SetListenConfigInterfaceIndex(lc *net.ListenConfig, ifIndex int) error {
if lc == nil {
return errors.New("nil ListenConfig")
}
if lc.Control != nil {
return errors.New("ListenConfig.Control already set")
}
lc.Control = func(network, address string, c syscall.RawConn) error {
var sockErr error
err := c.Control(func(fd uintptr) {
sockErr = bindInterface(fd, network, address, ifIndex)
if sockErr != nil {
log.Printf("netns: bind(%q, %q) on index %v: %v", network, address, ifIndex, sockErr)
}
})
if err != nil {
return err
}
return sockErr
}
return nil
}
func bindInterface(fd uintptr, network, address string, ifIndex int) error {
v6 := strings.Contains(address, "]:") || strings.HasSuffix(network, "6") // hacky test for v6
proto := unix.IPPROTO_IP
opt := unix.IP_BOUND_IF
if v6 {
proto = unix.IPPROTO_IPV6
opt = unix.IPV6_BOUND_IF
}
return unix.SetsockoptInt(int(fd), proto, opt, ifIndex)
}

View File

@ -15,19 +15,19 @@
)
func ResourceCheck(tb testing.TB) {
tb.Helper()
startN, startStacks := goroutines()
tb.Cleanup(func() {
if tb.Failed() {
// Something else went wrong.
return
}
tb.Helper()
// Goroutines might be still exiting.
for i := 0; i < 100; i++ {
if runtime.NumGoroutine() <= startN {
return
}
time.Sleep(1 * time.Millisecond)
time.Sleep(5 * time.Millisecond)
}
endN, endStacks := goroutines()
tb.Logf("goroutine diff:\n%v\n", cmp.Diff(startStacks, endStacks))

View File

@ -101,6 +101,10 @@ type userspaceEngine struct {
// incorrectly sent to us.
isLocalAddr atomic.Value // of func(netaddr.IP)bool
// isDNSIPOverTailscale reports the whether a DNS resolver's IP
// is being routed over Tailscale.
isDNSIPOverTailscale atomic.Value // of func(netaddr.IP)bool
wgLock sync.Mutex // serializes all wgdev operations; see lock order comment below
lastCfgFull wgcfg.Config
lastRouterSig string // of router.Config
@ -242,6 +246,7 @@ func NewUserspaceEngine(logf logger.Logf, conf Config) (_ Engine, reterr error)
confListenPort: conf.ListenPort,
}
e.isLocalAddr.Store(tsaddr.NewContainsIPFunc(nil))
e.isDNSIPOverTailscale.Store(tsaddr.NewContainsIPFunc(nil))
if conf.LinkMonitor != nil {
e.linkMon = conf.LinkMonitor
@ -255,7 +260,8 @@ func NewUserspaceEngine(logf logger.Logf, conf Config) (_ Engine, reterr error)
e.linkMonOwned = true
}
e.dns = dns.NewManager(logf, conf.DNS, e.linkMon)
tunName, _ := conf.Tun.Name()
e.dns = dns.NewManager(logf, conf.DNS, e.linkMon, fwdDNSLinkSelector{e, tunName})
logf("link state: %+v", e.linkMon.InterfaceState())
@ -767,6 +773,13 @@ func (e *userspaceEngine) Reconfig(cfg *wgcfg.Config, routerCfg *router.Config,
return ErrNoChanges
}
// TODO(bradfitz,danderson): maybe delete this isDNSIPOverTailscale
// field and delete the resolver.ForwardLinkSelector hook and
// instead have ipnlocal populate a map of DNS IP => linkName and
// put that in the *dns.Config instead, and plumb it down to the
// dns.Manager. Maybe also with isLocalAddr above.
e.isDNSIPOverTailscale.Store(tsaddr.NewContainsIPFunc(dnsIPsOverTailscale(dnsCfg, routerCfg)))
// See if any peers have changed disco keys, which means they've restarted.
// If so, we need to update the wireguard-go/device.Device in two phases:
// once without the node which has restarted, to clear its wireguard session key,
@ -1362,3 +1375,49 @@ func (p closeOnErrorPool) closeAllIfError(errp *error) {
}
}
}
// ipInPrefixes reports whether ip is in any of pp.
func ipInPrefixes(ip netaddr.IP, pp []netaddr.IPPrefix) bool {
for _, p := range pp {
if p.Contains(ip) {
return true
}
}
return false
}
// dnsIPsOverTailscale returns the IPPrefixes of DNS resolver IPs that are
// routed over Tailscale. The returned value does not contain duplicates is
// not necessarily sorted.
func dnsIPsOverTailscale(dnsCfg *dns.Config, routerCfg *router.Config) (ret []netaddr.IPPrefix) {
m := map[netaddr.IP]bool{}
for _, resolvers := range dnsCfg.Routes {
for _, resolver := range resolvers {
ip := resolver.IP()
if ipInPrefixes(ip, routerCfg.Routes) && !ipInPrefixes(ip, routerCfg.LocalRoutes) {
m[ip] = true
}
}
}
ret = make([]netaddr.IPPrefix, 0, len(m))
for ip := range m {
ret = append(ret, netaddr.IPPrefixFrom(ip, ip.BitLen()))
}
return ret
}
// fwdDNSLinkSelector is userspaceEngine's resolver.ForwardLinkSelector, to pick
// which network interface to send DNS queries out of.
type fwdDNSLinkSelector struct {
ue *userspaceEngine
tunName string
}
func (ls fwdDNSLinkSelector) PickLink(ip netaddr.IP) (linkName string) {
if ls.ue.isDNSIPOverTailscale.Load().(func(netaddr.IP) bool)(ip) {
return ls.tunName
}
return ""
}