mirror of
https://github.com/tailscale/tailscale.git
synced 2024-11-29 04:55:31 +00:00
net/dns{,/resolver}: refactor DNS forwarder, send out of right link on macOS/iOS
Fixes #2224 Fixes tailscale/corp#2045 Signed-off-by: Brad Fitzpatrick <bradfitz@tailscale.com>
This commit is contained in:
parent
597fa3d3c3
commit
45e64f2e1a
@ -9,14 +9,12 @@
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"log"
|
||||
"net"
|
||||
"strings"
|
||||
"syscall"
|
||||
|
||||
"golang.org/x/sys/unix"
|
||||
"inet.af/netaddr"
|
||||
"tailscale.com/net/interfaces"
|
||||
"tailscale.com/net/netns"
|
||||
)
|
||||
|
||||
func init() {
|
||||
@ -32,29 +30,7 @@ func initListenConfigNetworkExtension(nc *net.ListenConfig, ip netaddr.IP, st *i
|
||||
if !ok {
|
||||
return fmt.Errorf("no interface with name %q", tunIfName)
|
||||
}
|
||||
nc.Control = func(network, address string, c syscall.RawConn) error {
|
||||
var sockErr error
|
||||
err := c.Control(func(fd uintptr) {
|
||||
sockErr = bindIf(fd, network, address, tunIf.Index)
|
||||
log.Printf("peerapi: bind(%q, %q) on index %v = %v", network, address, tunIf.Index, sockErr)
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return sockErr
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func bindIf(fd uintptr, network, address string, ifIndex int) error {
|
||||
v6 := strings.Contains(address, "]:") || strings.HasSuffix(network, "6") // hacky test for v6
|
||||
proto := unix.IPPROTO_IP
|
||||
opt := unix.IP_BOUND_IF
|
||||
if v6 {
|
||||
proto = unix.IPPROTO_IPV6
|
||||
opt = unix.IPV6_BOUND_IF
|
||||
}
|
||||
return unix.SetsockoptInt(int(fd), proto, opt, ifIndex)
|
||||
return netns.SetListenConfigInterfaceIndex(nc, tunIf.Index)
|
||||
}
|
||||
|
||||
func peerDialControlFuncNetworkExtension(b *LocalBackend) func(network, address string, c syscall.RawConn) error {
|
||||
@ -68,17 +44,12 @@ func peerDialControlFuncNetworkExtension(b *LocalBackend) func(network, address
|
||||
index = tunIf.Index
|
||||
}
|
||||
}
|
||||
var lc net.ListenConfig
|
||||
netns.SetListenConfigInterfaceIndex(&lc, index)
|
||||
return func(network, address string, c syscall.RawConn) error {
|
||||
if index == -1 {
|
||||
return errors.New("failed to find TUN interface to bind to")
|
||||
}
|
||||
var sockErr error
|
||||
err := c.Control(func(fd uintptr) {
|
||||
sockErr = bindIf(fd, network, address, index)
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return sockErr
|
||||
return lc.Control(network, address, c)
|
||||
}
|
||||
}
|
||||
|
@ -38,11 +38,11 @@ type Manager struct {
|
||||
}
|
||||
|
||||
// NewManagers created a new manager from the given config.
|
||||
func NewManager(logf logger.Logf, oscfg OSConfigurator, linkMon *monitor.Mon) *Manager {
|
||||
func NewManager(logf logger.Logf, oscfg OSConfigurator, linkMon *monitor.Mon, linkSel resolver.ForwardLinkSelector) *Manager {
|
||||
logf = logger.WithPrefix(logf, "dns: ")
|
||||
m := &Manager{
|
||||
logf: logf,
|
||||
resolver: resolver.New(logf, linkMon),
|
||||
resolver: resolver.New(logf, linkMon, linkSel),
|
||||
os: oscfg,
|
||||
}
|
||||
m.logf("using %T", m.os)
|
||||
@ -207,7 +207,7 @@ func Cleanup(logf logger.Logf, interfaceName string) {
|
||||
logf("creating dns cleanup: %v", err)
|
||||
return
|
||||
}
|
||||
dns := NewManager(logf, oscfg, nil)
|
||||
dns := NewManager(logf, oscfg, nil, nil)
|
||||
if err := dns.Down(); err != nil {
|
||||
logf("dns down: %v", err)
|
||||
}
|
||||
|
@ -376,7 +376,7 @@ func TestManager(t *testing.T) {
|
||||
SplitDNS: test.split,
|
||||
BaseConfig: test.bs,
|
||||
}
|
||||
m := NewManager(t.Logf, &f, nil)
|
||||
m := NewManager(t.Logf, &f, nil, nil)
|
||||
m.resolver.TestOnlySetHook(f.SetResolver)
|
||||
|
||||
if err := m.Set(test.in); err != nil {
|
||||
|
@ -9,41 +9,30 @@
|
||||
"context"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"fmt"
|
||||
"hash/crc32"
|
||||
"io"
|
||||
"math/rand"
|
||||
"net"
|
||||
"sync"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
dns "golang.org/x/net/dns/dnsmessage"
|
||||
"inet.af/netaddr"
|
||||
"tailscale.com/logtail/backoff"
|
||||
"tailscale.com/types/logger"
|
||||
"tailscale.com/util/dnsname"
|
||||
"tailscale.com/wgengine/monitor"
|
||||
)
|
||||
|
||||
// headerBytes is the number of bytes in a DNS message header.
|
||||
const headerBytes = 12
|
||||
|
||||
// connCount is the number of UDP connections to use for forwarding.
|
||||
const connCount = 32
|
||||
|
||||
const (
|
||||
// cleanupInterval is the interval between purged of timed-out entries from txMap.
|
||||
cleanupInterval = 30 * time.Second
|
||||
// responseTimeout is the maximal amount of time to wait for a DNS response.
|
||||
responseTimeout = 5 * time.Second
|
||||
)
|
||||
|
||||
var errNoUpstreams = errors.New("upstream nameservers not set")
|
||||
|
||||
type forwardingRecord struct {
|
||||
src netaddr.IPPort
|
||||
createdAt time.Time
|
||||
}
|
||||
|
||||
// txid identifies a DNS transaction.
|
||||
//
|
||||
// As the standard DNS Request ID is only 16 bits, we extend it:
|
||||
@ -100,178 +89,164 @@ func getTxID(packet []byte) txid {
|
||||
}
|
||||
|
||||
type route struct {
|
||||
suffix dnsname.FQDN
|
||||
resolvers []netaddr.IPPort
|
||||
Suffix dnsname.FQDN
|
||||
Resolvers []netaddr.IPPort
|
||||
}
|
||||
|
||||
// forwarder forwards DNS packets to a number of upstream nameservers.
|
||||
type forwarder struct {
|
||||
logf logger.Logf
|
||||
logf logger.Logf
|
||||
linkMon *monitor.Mon
|
||||
linkSel ForwardLinkSelector
|
||||
|
||||
ctx context.Context // good until Close
|
||||
ctxCancel context.CancelFunc // closes ctx
|
||||
|
||||
// responses is a channel by which responses are returned.
|
||||
responses chan packet
|
||||
// closed signals all goroutines to stop.
|
||||
closed chan struct{}
|
||||
// wg signals when all goroutines have stopped.
|
||||
wg sync.WaitGroup
|
||||
|
||||
// conns are the UDP connections used for forwarding.
|
||||
// A random one is selected for each request, regardless of the target upstream.
|
||||
conns []*fwdConn
|
||||
mu sync.Mutex // guards following
|
||||
|
||||
mu sync.Mutex
|
||||
// routes are per-suffix resolvers to use.
|
||||
routes []route // most specific routes first
|
||||
txMap map[txid]forwardingRecord // txids to in-flight requests
|
||||
// routes are per-suffix resolvers to use, with
|
||||
// the most specific routes first.
|
||||
routes []route
|
||||
}
|
||||
|
||||
func init() {
|
||||
rand.Seed(time.Now().UnixNano())
|
||||
}
|
||||
|
||||
func newForwarder(logf logger.Logf, responses chan packet) *forwarder {
|
||||
ret := &forwarder{
|
||||
func newForwarder(logf logger.Logf, responses chan packet, linkMon *monitor.Mon, linkSel ForwardLinkSelector) *forwarder {
|
||||
f := &forwarder{
|
||||
logf: logger.WithPrefix(logf, "forward: "),
|
||||
linkMon: linkMon,
|
||||
linkSel: linkSel,
|
||||
responses: responses,
|
||||
closed: make(chan struct{}),
|
||||
conns: make([]*fwdConn, connCount),
|
||||
txMap: make(map[txid]forwardingRecord),
|
||||
}
|
||||
|
||||
ret.wg.Add(connCount + 1)
|
||||
for idx := range ret.conns {
|
||||
ret.conns[idx] = newFwdConn(ret.logf, idx)
|
||||
go ret.recv(ret.conns[idx])
|
||||
}
|
||||
go ret.cleanMap()
|
||||
|
||||
return ret
|
||||
f.ctx, f.ctxCancel = context.WithCancel(context.Background())
|
||||
return f
|
||||
}
|
||||
|
||||
func (f *forwarder) Close() {
|
||||
select {
|
||||
case <-f.closed:
|
||||
return
|
||||
default:
|
||||
// continue
|
||||
}
|
||||
close(f.closed)
|
||||
|
||||
for _, conn := range f.conns {
|
||||
conn.close()
|
||||
}
|
||||
|
||||
f.wg.Wait()
|
||||
}
|
||||
|
||||
func (f *forwarder) rebindFromNetworkChange() {
|
||||
for _, c := range f.conns {
|
||||
c.mu.Lock()
|
||||
c.reconnectLocked()
|
||||
c.mu.Unlock()
|
||||
}
|
||||
func (f *forwarder) Close() error {
|
||||
f.ctxCancel()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (f *forwarder) setRoutes(routes []route) {
|
||||
f.mu.Lock()
|
||||
defer f.mu.Unlock()
|
||||
f.routes = routes
|
||||
f.mu.Unlock()
|
||||
}
|
||||
|
||||
var stdNetPacketListener packetListener = new(net.ListenConfig)
|
||||
|
||||
type packetListener interface {
|
||||
ListenPacket(ctx context.Context, network, address string) (net.PacketConn, error)
|
||||
}
|
||||
|
||||
func (f *forwarder) packetListener(ip netaddr.IP) (packetListener, error) {
|
||||
if f.linkSel == nil || initListenConfig == nil {
|
||||
return stdNetPacketListener, nil
|
||||
}
|
||||
linkName := f.linkSel.PickLink(ip)
|
||||
if linkName == "" {
|
||||
return stdNetPacketListener, nil
|
||||
}
|
||||
lc := new(net.ListenConfig)
|
||||
if err := initListenConfig(lc, f.linkMon, linkName); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return lc, nil
|
||||
}
|
||||
|
||||
// send sends packet to dst. It is best effort.
|
||||
func (f *forwarder) send(packet []byte, dst netaddr.IPPort) {
|
||||
connIdx := rand.Intn(connCount)
|
||||
conn := f.conns[connIdx]
|
||||
conn.send(packet, dst)
|
||||
}
|
||||
//
|
||||
// send expects the reply to have the same txid as txidOut.
|
||||
//
|
||||
// The provided closeOnCtxDone lets send register values to Close if
|
||||
// the caller's ctx expires. This avoids send from allocating its own
|
||||
// waiting goroutine to interrupt the ReadFrom, as memory is tight on
|
||||
// iOS and we want the number of pending DNS lookups to be bursty
|
||||
// without too much associated goroutine/memory cost.
|
||||
func (f *forwarder) send(ctx context.Context, txidOut txid, closeOnCtxDone *closePool, packet []byte, dst netaddr.IPPort) ([]byte, error) {
|
||||
// TODO(bradfitz): if dst.IP is 8.8.8.8 or 8.8.4.4 or 1.1.1.1, etc, or
|
||||
// something dynamically probed earlier to support DoH or DoT,
|
||||
// do that here instead.
|
||||
|
||||
func (f *forwarder) recv(conn *fwdConn) {
|
||||
defer f.wg.Done()
|
||||
ln, err := f.packetListener(dst.IP())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
conn, err := ln.ListenPacket(ctx, "udp", ":0")
|
||||
if err != nil {
|
||||
f.logf("ListenPacket failed: %v", err)
|
||||
return nil, err
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-f.closed:
|
||||
return
|
||||
default:
|
||||
closeOnCtxDone.Add(conn)
|
||||
defer closeOnCtxDone.Remove(conn)
|
||||
|
||||
if _, err := conn.WriteTo(packet, dst.UDPAddr()); err != nil {
|
||||
if err := ctx.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// The 1 extra byte is to detect packet truncation.
|
||||
out := make([]byte, maxResponseBytes+1)
|
||||
n := conn.read(out)
|
||||
var truncated bool
|
||||
if n > maxResponseBytes {
|
||||
n = maxResponseBytes
|
||||
truncated = true
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// The 1 extra byte is to detect packet truncation.
|
||||
out := make([]byte, maxResponseBytes+1)
|
||||
n, _, err := conn.ReadFrom(out)
|
||||
if err != nil {
|
||||
if err := ctx.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if n == 0 {
|
||||
continue
|
||||
}
|
||||
if n < headerBytes {
|
||||
f.logf("recv: packet too small (%d bytes)", n)
|
||||
}
|
||||
|
||||
out = out[:n]
|
||||
txid := getTxID(out)
|
||||
|
||||
if truncated {
|
||||
const dnsFlagTruncated = 0x200
|
||||
flags := binary.BigEndian.Uint16(out[2:4])
|
||||
flags |= dnsFlagTruncated
|
||||
binary.BigEndian.PutUint16(out[2:4], flags)
|
||||
|
||||
// TODO(#2067): Remove any incomplete records? RFC 1035 section 6.2
|
||||
// states that truncation should head drop so that the authority
|
||||
// section can be preserved if possible. However, the UDP read with
|
||||
// a too-small buffer has already dropped the end, so that's the
|
||||
// best we can do.
|
||||
}
|
||||
|
||||
f.mu.Lock()
|
||||
|
||||
record, found := f.txMap[txid]
|
||||
// At most one nameserver will return a response:
|
||||
// the first one to do so will delete txid from the map.
|
||||
if !found {
|
||||
f.mu.Unlock()
|
||||
continue
|
||||
}
|
||||
delete(f.txMap, txid)
|
||||
|
||||
f.mu.Unlock()
|
||||
|
||||
pkt := packet{out, record.src}
|
||||
select {
|
||||
case <-f.closed:
|
||||
return
|
||||
case f.responses <- pkt:
|
||||
// continue
|
||||
if packetWasTruncated(err) {
|
||||
err = nil
|
||||
} else {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
truncated := n > maxResponseBytes
|
||||
if truncated {
|
||||
n = maxResponseBytes
|
||||
}
|
||||
if n < headerBytes {
|
||||
f.logf("recv: packet too small (%d bytes)", n)
|
||||
}
|
||||
out = out[:n]
|
||||
txid := getTxID(out)
|
||||
if txid != txidOut {
|
||||
return nil, errors.New("txid doesn't match")
|
||||
}
|
||||
|
||||
if truncated {
|
||||
const dnsFlagTruncated = 0x200
|
||||
flags := binary.BigEndian.Uint16(out[2:4])
|
||||
flags |= dnsFlagTruncated
|
||||
binary.BigEndian.PutUint16(out[2:4], flags)
|
||||
|
||||
// TODO(#2067): Remove any incomplete records? RFC 1035 section 6.2
|
||||
// states that truncation should head drop so that the authority
|
||||
// section can be preserved if possible. However, the UDP read with
|
||||
// a too-small buffer has already dropped the end, so that's the
|
||||
// best we can do.
|
||||
}
|
||||
|
||||
return out, nil
|
||||
}
|
||||
|
||||
// cleanMap periodically deletes timed-out forwarding records from f.txMap to bound growth.
|
||||
func (f *forwarder) cleanMap() {
|
||||
defer f.wg.Done()
|
||||
|
||||
t := time.NewTicker(cleanupInterval)
|
||||
defer t.Stop()
|
||||
|
||||
var now time.Time
|
||||
for {
|
||||
select {
|
||||
case <-f.closed:
|
||||
return
|
||||
case now = <-t.C:
|
||||
// continue
|
||||
// resolvers returns the resolvers to use for domain.
|
||||
func (f *forwarder) resolvers(domain dnsname.FQDN) []netaddr.IPPort {
|
||||
f.mu.Lock()
|
||||
routes := f.routes
|
||||
f.mu.Unlock()
|
||||
for _, route := range routes {
|
||||
if route.Suffix == "." || route.Suffix.Contains(domain) {
|
||||
return route.Resolvers
|
||||
}
|
||||
|
||||
f.mu.Lock()
|
||||
for k, v := range f.txMap {
|
||||
if now.Sub(v.createdAt) > responseTimeout {
|
||||
delete(f.txMap, k)
|
||||
}
|
||||
}
|
||||
f.mu.Unlock()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// forward forwards the query to all upstream nameservers and returns the first response.
|
||||
@ -283,225 +258,60 @@ func (f *forwarder) forward(query packet) error {
|
||||
|
||||
txid := getTxID(query.bs)
|
||||
|
||||
f.mu.Lock()
|
||||
routes := f.routes
|
||||
f.mu.Unlock()
|
||||
|
||||
var resolvers []netaddr.IPPort
|
||||
for _, route := range routes {
|
||||
if route.suffix != "." && !route.suffix.Contains(domain) {
|
||||
continue
|
||||
}
|
||||
resolvers = route.resolvers
|
||||
break
|
||||
}
|
||||
resolvers := f.resolvers(domain)
|
||||
if len(resolvers) == 0 {
|
||||
return errNoUpstreams
|
||||
}
|
||||
|
||||
f.mu.Lock()
|
||||
f.txMap[txid] = forwardingRecord{
|
||||
src: query.addr,
|
||||
createdAt: time.Now(),
|
||||
}
|
||||
f.mu.Unlock()
|
||||
closeOnCtxDone := new(closePool)
|
||||
defer closeOnCtxDone.Close()
|
||||
|
||||
// TODO(#2066): EDNS size clamping
|
||||
ctx, cancel := context.WithTimeout(f.ctx, responseTimeout)
|
||||
defer cancel()
|
||||
|
||||
for _, resolver := range resolvers {
|
||||
f.send(query.bs, resolver)
|
||||
}
|
||||
resc := make(chan []byte, 1)
|
||||
var (
|
||||
mu sync.Mutex
|
||||
firstErr error
|
||||
)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// A fwdConn manages a single connection used to forward DNS requests.
|
||||
// Net link changes can cause a *net.UDPConn to become permanently unusable, particularly on macOS.
|
||||
// fwdConn detects such situations and transparently creates new connections.
|
||||
type fwdConn struct {
|
||||
// logf allows a fwdConn to log.
|
||||
logf logger.Logf
|
||||
|
||||
// change allows calls to read to block until a the network connection has been replaced.
|
||||
change *sync.Cond
|
||||
|
||||
// mu protects fields that follow it; it is also change's Locker.
|
||||
mu sync.Mutex
|
||||
// closed tracks whether fwdConn has been permanently closed.
|
||||
closed bool
|
||||
// conn is the current active connection.
|
||||
conn net.PacketConn
|
||||
}
|
||||
|
||||
func newFwdConn(logf logger.Logf, idx int) *fwdConn {
|
||||
c := new(fwdConn)
|
||||
c.logf = logger.WithPrefix(logf, fmt.Sprintf("fwdConn %d: ", idx))
|
||||
c.change = sync.NewCond(&c.mu)
|
||||
// c.conn is created lazily in send
|
||||
return c
|
||||
}
|
||||
|
||||
// send sends packet to dst using c's connection.
|
||||
// It is best effort. It is UDP, after all. Failures are logged.
|
||||
func (c *fwdConn) send(packet []byte, dst netaddr.IPPort) {
|
||||
var b *backoff.Backoff // lazily initialized, since it is not needed in the common case
|
||||
backOff := func(err error) {
|
||||
if b == nil {
|
||||
b = backoff.NewBackoff("dns-fwdConn-send", c.logf, 30*time.Second)
|
||||
}
|
||||
b.BackOff(context.Background(), err)
|
||||
}
|
||||
|
||||
for {
|
||||
// Gather the current connection.
|
||||
// We can't hold the lock while we call WriteTo.
|
||||
c.mu.Lock()
|
||||
conn := c.conn
|
||||
closed := c.closed
|
||||
if closed {
|
||||
c.mu.Unlock()
|
||||
return
|
||||
}
|
||||
if conn == nil {
|
||||
c.reconnectLocked()
|
||||
c.mu.Unlock()
|
||||
continue
|
||||
}
|
||||
c.mu.Unlock()
|
||||
|
||||
_, err := conn.WriteTo(packet, dst.UDPAddr())
|
||||
if err == nil {
|
||||
// Success
|
||||
return
|
||||
}
|
||||
if errors.Is(err, net.ErrClosed) {
|
||||
// We intentionally closed this connection.
|
||||
// It has been replaced by a new connection. Try again.
|
||||
continue
|
||||
}
|
||||
// Something else went wrong.
|
||||
// We have three choices here: try again, give up, or create a new connection.
|
||||
var opErr *net.OpError
|
||||
if !errors.As(err, &opErr) {
|
||||
// Weird. All errors from the net package should be *net.OpError. Bail.
|
||||
c.logf("send: non-*net.OpErr %v (%T)", err, err)
|
||||
return
|
||||
}
|
||||
if opErr.Temporary() || opErr.Timeout() {
|
||||
// I doubt that either of these can happen (this is UDP),
|
||||
// but go ahead and try again.
|
||||
backOff(err)
|
||||
continue
|
||||
}
|
||||
if errors.Is(err, syscall.EHOSTUNREACH) {
|
||||
// "No route to host." The network stack is fine, but
|
||||
// can't talk to this destination. Not much we can do
|
||||
// about that, don't spam logs.
|
||||
return
|
||||
}
|
||||
if networkIsDown(err) {
|
||||
// Fail.
|
||||
c.logf("send: network is down")
|
||||
return
|
||||
}
|
||||
if networkIsUnreachable(err) {
|
||||
// This can be caused by a link change.
|
||||
// Replace the existing connection with a new one.
|
||||
c.mu.Lock()
|
||||
// It's possible that multiple senders discovered simultaneously
|
||||
// that the network is unreachable. Avoid reconnecting multiple times:
|
||||
// Only reconnect if the current connection is the one that we
|
||||
// discovered to be problematic.
|
||||
if c.conn == conn {
|
||||
backOff(err)
|
||||
c.reconnectLocked()
|
||||
for _, ipp := range resolvers {
|
||||
go func(ipp netaddr.IPPort) {
|
||||
resb, err := f.send(ctx, txid, closeOnCtxDone, query.bs, ipp)
|
||||
if err != nil {
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
if firstErr == nil {
|
||||
firstErr = err
|
||||
}
|
||||
return
|
||||
}
|
||||
c.mu.Unlock()
|
||||
// Try again with our new network connection.
|
||||
continue
|
||||
select {
|
||||
case resc <- resb:
|
||||
default:
|
||||
}
|
||||
}(ipp)
|
||||
}
|
||||
|
||||
select {
|
||||
case v := <-resc:
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
case f.responses <- packet{v, query.addr}:
|
||||
return nil
|
||||
}
|
||||
// Unrecognized error. Fail.
|
||||
c.logf("send: unrecognized error: %v", err)
|
||||
return
|
||||
case <-ctx.Done():
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
if firstErr != nil {
|
||||
return firstErr
|
||||
}
|
||||
return ctx.Err()
|
||||
}
|
||||
}
|
||||
|
||||
// read waits for a response from c's connection.
|
||||
// It returns the number of bytes read, which may be 0
|
||||
// in case of an error or a closed connection.
|
||||
func (c *fwdConn) read(out []byte) int {
|
||||
for {
|
||||
// Gather the current connection.
|
||||
// We can't hold the lock while we call ReadFrom.
|
||||
c.mu.Lock()
|
||||
conn := c.conn
|
||||
closed := c.closed
|
||||
if closed {
|
||||
c.mu.Unlock()
|
||||
return 0
|
||||
}
|
||||
if conn == nil {
|
||||
// There is no current connection.
|
||||
// Wait for the connection to change, then try again.
|
||||
c.change.Wait()
|
||||
c.mu.Unlock()
|
||||
continue
|
||||
}
|
||||
c.mu.Unlock()
|
||||
|
||||
n, _, err := conn.ReadFrom(out)
|
||||
if err == nil || packetWasTruncated(err) {
|
||||
// Success.
|
||||
return n
|
||||
}
|
||||
if errors.Is(err, net.ErrClosed) {
|
||||
// We intentionally closed this connection.
|
||||
// It has been replaced by a new connection. Try again.
|
||||
continue
|
||||
}
|
||||
|
||||
c.logf("read: unrecognized error: %v", err)
|
||||
return 0
|
||||
}
|
||||
}
|
||||
|
||||
// reconnectLocked replaces the current connection with a new one.
|
||||
// c.mu must be locked.
|
||||
func (c *fwdConn) reconnectLocked() {
|
||||
c.closeConnLocked()
|
||||
// Make a new connection.
|
||||
conn, err := net.ListenPacket("udp", "")
|
||||
if err != nil {
|
||||
c.logf("ListenPacket failed: %v", err)
|
||||
} else {
|
||||
c.conn = conn
|
||||
}
|
||||
// Broadcast that a new connection is available.
|
||||
c.change.Broadcast()
|
||||
}
|
||||
|
||||
// closeCurrentConn closes the current connection.
|
||||
// c.mu must be locked.
|
||||
func (c *fwdConn) closeConnLocked() {
|
||||
if c.conn == nil {
|
||||
return
|
||||
}
|
||||
c.conn.Close() // unblocks all readers/writers, they'll pick up the next connection.
|
||||
c.conn = nil
|
||||
}
|
||||
|
||||
// close permanently closes c.
|
||||
func (c *fwdConn) close() {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
if c.closed {
|
||||
return
|
||||
}
|
||||
c.closed = true
|
||||
c.closeConnLocked()
|
||||
// Unblock any remaining readers.
|
||||
c.change.Broadcast()
|
||||
}
|
||||
var initListenConfig func(_ *net.ListenConfig, _ *monitor.Mon, tunName string) error
|
||||
|
||||
// nameFromQuery extracts the normalized query name from bs.
|
||||
func nameFromQuery(bs []byte) (dnsname.FQDN, error) {
|
||||
@ -523,3 +333,48 @@ func nameFromQuery(bs []byte) (dnsname.FQDN, error) {
|
||||
n := q.Name.Data[:q.Name.Length]
|
||||
return dnsname.ToFQDN(rawNameToLower(n))
|
||||
}
|
||||
|
||||
// closePool is a dynamic set of io.Closers to close as a group.
|
||||
// It's intended to be Closed at most once.
|
||||
//
|
||||
// The zero value is ready for use.
|
||||
type closePool struct {
|
||||
mu sync.Mutex
|
||||
m map[io.Closer]bool
|
||||
closed bool
|
||||
}
|
||||
|
||||
func (p *closePool) Add(c io.Closer) {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
if p.closed {
|
||||
c.Close()
|
||||
return
|
||||
}
|
||||
if p.m == nil {
|
||||
p.m = map[io.Closer]bool{}
|
||||
}
|
||||
p.m[c] = true
|
||||
}
|
||||
|
||||
func (p *closePool) Remove(c io.Closer) {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
if p.closed {
|
||||
return
|
||||
}
|
||||
delete(p.m, c)
|
||||
}
|
||||
|
||||
func (p *closePool) Close() error {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
if p.closed {
|
||||
return nil
|
||||
}
|
||||
p.closed = true
|
||||
for c := range p.m {
|
||||
c.Close()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
27
net/dns/resolver/macios_ext.go
Normal file
27
net/dns/resolver/macios_ext.go
Normal file
@ -0,0 +1,27 @@
|
||||
// Copyright (c) 2021 Tailscale Inc & AUTHORS All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
// +build darwin,ts_macext ios,ts_macext
|
||||
|
||||
package resolver
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net"
|
||||
|
||||
"tailscale.com/net/netns"
|
||||
"tailscale.com/wgengine/monitor"
|
||||
)
|
||||
|
||||
func init() {
|
||||
initListenConfig = initListenConfigNetworkExtension
|
||||
}
|
||||
|
||||
func initListenConfigNetworkExtension(nc *net.ListenConfig, mon *monitor.Mon, tunName string) error {
|
||||
nif, ok := mon.InterfaceState().Interface[tunName]
|
||||
if !ok {
|
||||
return errors.New("utun not found")
|
||||
}
|
||||
return netns.SetListenConfigInterfaceIndex(nc, nif.Interface.Index)
|
||||
}
|
@ -9,14 +9,15 @@
|
||||
import (
|
||||
"encoding/hex"
|
||||
"errors"
|
||||
"runtime"
|
||||
"sort"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
dns "golang.org/x/net/dns/dnsmessage"
|
||||
"inet.af/netaddr"
|
||||
"tailscale.com/net/interfaces"
|
||||
"tailscale.com/types/logger"
|
||||
"tailscale.com/util/dnsname"
|
||||
"tailscale.com/wgengine/monitor"
|
||||
@ -27,10 +28,20 @@
|
||||
// truncation in a platform-agnostic way.
|
||||
const maxResponseBytes = 4095
|
||||
|
||||
// queueSize is the maximal number of DNS requests that can await polling.
|
||||
// maxActiveQueries returns the maximal number of DNS requests that be
|
||||
// can running.
|
||||
// If EnqueueRequest is called when this many requests are already pending,
|
||||
// the request will be dropped to avoid blocking the caller.
|
||||
const queueSize = 64
|
||||
func maxActiveQueries() int32 {
|
||||
if runtime.GOOS == "ios" {
|
||||
// For memory paranoia reasons on iOS, match the
|
||||
// historical Tailscale 1.x..1.8 behavior for now
|
||||
// (just before the 1.10 release).
|
||||
return 64
|
||||
}
|
||||
// But for other platforms, allow more burstiness:
|
||||
return 256
|
||||
}
|
||||
|
||||
// defaultTTL is the TTL of all responses from Resolver.
|
||||
const defaultTTL = 600 * time.Second
|
||||
@ -75,13 +86,12 @@ type Config struct {
|
||||
type Resolver struct {
|
||||
logf logger.Logf
|
||||
linkMon *monitor.Mon // or nil
|
||||
unregLinkMon func() // or nil
|
||||
saveConfigForTests func(cfg Config) // used in tests to capture resolver config
|
||||
// forwarder forwards requests to upstream nameservers.
|
||||
forwarder *forwarder
|
||||
|
||||
// queue is a buffered channel holding DNS requests queued for resolution.
|
||||
queue chan packet
|
||||
activeQueriesAtomic int32 // number of DNS queries in flight
|
||||
|
||||
// responses is an unbuffered channel to which responses are returned.
|
||||
responses chan packet
|
||||
// errors is an unbuffered channel to which errors are returned.
|
||||
@ -98,27 +108,26 @@ type Resolver struct {
|
||||
ipToHost map[netaddr.IP]dnsname.FQDN
|
||||
}
|
||||
|
||||
type ForwardLinkSelector interface {
|
||||
// PickLink returns which network device should be used to query
|
||||
// the DNS server at the given IP.
|
||||
// The empty string means to use an unspecified default.
|
||||
PickLink(netaddr.IP) (linkName string)
|
||||
}
|
||||
|
||||
// New returns a new resolver.
|
||||
// linkMon optionally specifies a link monitor to use for socket rebinding.
|
||||
func New(logf logger.Logf, linkMon *monitor.Mon) *Resolver {
|
||||
func New(logf logger.Logf, linkMon *monitor.Mon, linkSel ForwardLinkSelector) *Resolver {
|
||||
r := &Resolver{
|
||||
logf: logger.WithPrefix(logf, "dns: "),
|
||||
linkMon: linkMon,
|
||||
queue: make(chan packet, queueSize),
|
||||
responses: make(chan packet),
|
||||
errors: make(chan error),
|
||||
closed: make(chan struct{}),
|
||||
hostToIP: map[dnsname.FQDN][]netaddr.IP{},
|
||||
ipToHost: map[netaddr.IP]dnsname.FQDN{},
|
||||
}
|
||||
r.forwarder = newForwarder(r.logf, r.responses)
|
||||
if r.linkMon != nil {
|
||||
r.unregLinkMon = r.linkMon.RegisterChangeCallback(r.onLinkMonitorChange)
|
||||
}
|
||||
|
||||
r.wg.Add(1)
|
||||
go r.poll()
|
||||
|
||||
r.forwarder = newForwarder(r.logf, r.responses, linkMon, linkSel)
|
||||
return r
|
||||
}
|
||||
|
||||
@ -140,13 +149,13 @@ func (r *Resolver) SetConfig(cfg Config) error {
|
||||
|
||||
for suffix, ips := range cfg.Routes {
|
||||
routes = append(routes, route{
|
||||
suffix: suffix,
|
||||
resolvers: ips,
|
||||
Suffix: suffix,
|
||||
Resolvers: ips,
|
||||
})
|
||||
}
|
||||
// Sort from longest prefix to shortest.
|
||||
sort.Slice(routes, func(i, j int) bool {
|
||||
return routes[i].suffix.NumLabels() > routes[j].suffix.NumLabels()
|
||||
return routes[i].Suffix.NumLabels() > routes[j].Suffix.NumLabels()
|
||||
})
|
||||
|
||||
r.forwarder.setRoutes(routes)
|
||||
@ -170,19 +179,7 @@ func (r *Resolver) Close() {
|
||||
}
|
||||
close(r.closed)
|
||||
|
||||
if r.unregLinkMon != nil {
|
||||
r.unregLinkMon()
|
||||
}
|
||||
|
||||
r.forwarder.Close()
|
||||
r.wg.Wait()
|
||||
}
|
||||
|
||||
func (r *Resolver) onLinkMonitorChange(changed bool, state *interfaces.State) {
|
||||
if !changed {
|
||||
return
|
||||
}
|
||||
r.forwarder.rebindFromNetworkChange()
|
||||
}
|
||||
|
||||
// EnqueueRequest places the given DNS request in the resolver's queue.
|
||||
@ -192,11 +189,14 @@ func (r *Resolver) EnqueueRequest(bs []byte, from netaddr.IPPort) error {
|
||||
select {
|
||||
case <-r.closed:
|
||||
return ErrClosed
|
||||
case r.queue <- packet{bs, from}:
|
||||
return nil
|
||||
default:
|
||||
}
|
||||
if n := atomic.AddInt32(&r.activeQueriesAtomic, 1); n > maxActiveQueries() {
|
||||
atomic.AddInt32(&r.activeQueriesAtomic, -1)
|
||||
return errFullQueue
|
||||
}
|
||||
go r.handleQuery(packet{bs, from})
|
||||
return nil
|
||||
}
|
||||
|
||||
// NextResponse returns a DNS response to a previously enqueued request.
|
||||
@ -291,53 +291,34 @@ func (r *Resolver) resolveLocal(domain dnsname.FQDN, typ dns.Type) (netaddr.IP,
|
||||
// resolveReverse returns the unique domain name that maps to the given address.
|
||||
func (r *Resolver) resolveLocalReverse(ip netaddr.IP) (dnsname.FQDN, dns.RCode) {
|
||||
r.mu.Lock()
|
||||
ips := r.ipToHost
|
||||
r.mu.Unlock()
|
||||
|
||||
name, found := ips[ip]
|
||||
if !found {
|
||||
defer r.mu.Unlock()
|
||||
name, ok := r.ipToHost[ip]
|
||||
if !ok {
|
||||
return "", dns.RCodeNameError
|
||||
}
|
||||
return name, dns.RCodeSuccess
|
||||
}
|
||||
|
||||
func (r *Resolver) poll() {
|
||||
defer r.wg.Done()
|
||||
func (r *Resolver) handleQuery(pkt packet) {
|
||||
defer atomic.AddInt32(&r.activeQueriesAtomic, -1)
|
||||
|
||||
var pkt packet
|
||||
for {
|
||||
out, err := r.respond(pkt.bs)
|
||||
if err == errNotOurName {
|
||||
err = r.forwarder.forward(pkt)
|
||||
if err == nil {
|
||||
// forward will send response into r.responses, nothing to do.
|
||||
return
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
select {
|
||||
case <-r.closed:
|
||||
return
|
||||
case pkt = <-r.queue:
|
||||
// continue
|
||||
case r.errors <- err:
|
||||
}
|
||||
|
||||
out, err := r.respond(pkt.bs)
|
||||
|
||||
if err == errNotOurName {
|
||||
err = r.forwarder.forward(pkt)
|
||||
if err == nil {
|
||||
// forward will send response into r.responses, nothing to do.
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
select {
|
||||
case <-r.closed:
|
||||
return
|
||||
case r.errors <- err:
|
||||
// continue
|
||||
}
|
||||
} else {
|
||||
pkt.bs = out
|
||||
select {
|
||||
case <-r.closed:
|
||||
return
|
||||
case r.responses <- pkt:
|
||||
// continue
|
||||
}
|
||||
} else {
|
||||
select {
|
||||
case <-r.closed:
|
||||
case r.responses <- packet{out, pkt.addr}:
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -351,28 +332,44 @@ type response struct {
|
||||
IP netaddr.IP
|
||||
}
|
||||
|
||||
// parseQuery parses the query in given packet into a response struct.
|
||||
// if the parse is successful, resp.Name contains the normalized name being queried.
|
||||
// TODO: stuffing the query name in resp.Name temporarily is a hack. Clean it up.
|
||||
func parseQuery(query []byte, resp *response) error {
|
||||
var parser dns.Parser
|
||||
var err error
|
||||
var dnsParserPool = &sync.Pool{
|
||||
New: func() interface{} {
|
||||
return new(dnsParser)
|
||||
},
|
||||
}
|
||||
|
||||
resp.Header, err = parser.Start(query)
|
||||
// dnsParser parses DNS queries using x/net/dns/dnsmessage.
|
||||
// These structs are pooled with dnsParserPool.
|
||||
type dnsParser struct {
|
||||
Header dns.Header
|
||||
Question dns.Question
|
||||
|
||||
parser dns.Parser
|
||||
}
|
||||
|
||||
func (p *dnsParser) response() *response {
|
||||
return &response{Header: p.Header, Question: p.Question}
|
||||
}
|
||||
|
||||
// zeroParser clears parser so it doesn't retain its most recently
|
||||
// parsed DNS query's []byte while it's sitting in a sync.Pool.
|
||||
// It's not useful to keep anyway: the next Start will do the same.
|
||||
func (p *dnsParser) zeroParser() { p.parser = dns.Parser{} }
|
||||
|
||||
// parseQuery parses the query in given packet into p.Header and
|
||||
// p.Question.
|
||||
func (p *dnsParser) parseQuery(query []byte) error {
|
||||
defer p.zeroParser()
|
||||
var err error
|
||||
p.Header, err = p.parser.Start(query)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if resp.Header.Response {
|
||||
if p.Header.Response {
|
||||
return errNotQuery
|
||||
}
|
||||
|
||||
resp.Question, err = parser.Question()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
p.Question, err = p.parser.Question()
|
||||
return err
|
||||
}
|
||||
|
||||
// marshalARecord serializes an A record into an active builder.
|
||||
@ -624,12 +621,13 @@ func (r *Resolver) respondReverse(query []byte, name dnsname.FQDN, resp *respons
|
||||
// respond returns a DNS response to query if it can be resolved locally.
|
||||
// Otherwise, it returns errNotOurName.
|
||||
func (r *Resolver) respond(query []byte) ([]byte, error) {
|
||||
resp := new(response)
|
||||
parser := dnsParserPool.Get().(*dnsParser)
|
||||
defer dnsParserPool.Put(parser)
|
||||
|
||||
// ParseQuery is sufficiently fast to run on every DNS packet.
|
||||
// This is considerably simpler than extracting the name by hand
|
||||
// to shave off microseconds in case of delegation.
|
||||
err := parseQuery(query, resp)
|
||||
err := parser.parseQuery(query)
|
||||
// We will not return this error: it is the sender's fault.
|
||||
if err != nil {
|
||||
if errors.Is(err, dns.ErrSectionDone) {
|
||||
@ -637,13 +635,15 @@ func (r *Resolver) respond(query []byte) ([]byte, error) {
|
||||
} else {
|
||||
r.logf("parseQuery(%02x): %v", query, err)
|
||||
}
|
||||
resp := parser.response()
|
||||
resp.Header.RCode = dns.RCodeFormatError
|
||||
return marshalResponse(resp)
|
||||
}
|
||||
rawName := resp.Question.Name.Data[:resp.Question.Name.Length]
|
||||
rawName := parser.Question.Name.Data[:parser.Question.Name.Length]
|
||||
name, err := dnsname.ToFQDN(rawNameToLower(rawName))
|
||||
if err != nil {
|
||||
// DNS packet unexpectedly contains an invalid FQDN.
|
||||
resp := parser.response()
|
||||
resp.Header.RCode = dns.RCodeFormatError
|
||||
return marshalResponse(resp)
|
||||
}
|
||||
@ -651,15 +651,17 @@ func (r *Resolver) respond(query []byte) ([]byte, error) {
|
||||
// Always try to handle reverse lookups; delegate inside when not found.
|
||||
// This way, queries for existent nodes do not leak,
|
||||
// but we behave gracefully if non-Tailscale nodes exist in CGNATRange.
|
||||
if resp.Question.Type == dns.TypePTR {
|
||||
return r.respondReverse(query, name, resp)
|
||||
if parser.Question.Type == dns.TypePTR {
|
||||
return r.respondReverse(query, name, parser.response())
|
||||
}
|
||||
|
||||
resp.IP, resp.Header.RCode = r.resolveLocal(name, resp.Question.Type)
|
||||
// This return code is special: it requests forwarding.
|
||||
if resp.Header.RCode == dns.RCodeRefused {
|
||||
return nil, errNotOurName
|
||||
ip, rcode := r.resolveLocal(name, parser.Question.Type)
|
||||
if rcode == dns.RCodeRefused {
|
||||
return nil, errNotOurName // sentinel error return value: it requests forwarding
|
||||
}
|
||||
|
||||
resp := parser.response()
|
||||
resp.Header.RCode = rcode
|
||||
resp.IP = ip
|
||||
return marshalResponse(resp)
|
||||
}
|
||||
|
@ -8,6 +8,7 @@
|
||||
"bytes"
|
||||
"encoding/hex"
|
||||
"errors"
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"net"
|
||||
"runtime"
|
||||
@ -17,6 +18,7 @@
|
||||
"inet.af/netaddr"
|
||||
"tailscale.com/tstest"
|
||||
"tailscale.com/util/dnsname"
|
||||
"tailscale.com/wgengine/monitor"
|
||||
)
|
||||
|
||||
var testipv4 = netaddr.MustParseIP("1.2.3.4")
|
||||
@ -128,7 +130,9 @@ func unpackResponse(payload []byte) (dnsResponse, error) {
|
||||
}
|
||||
|
||||
func syncRespond(r *Resolver, query []byte) ([]byte, error) {
|
||||
r.EnqueueRequest(query, netaddr.IPPort{})
|
||||
if err := r.EnqueueRequest(query, netaddr.IPPort{}); err != nil {
|
||||
return nil, fmt.Errorf("EnqueueRequest: %w", err)
|
||||
}
|
||||
payload, _, err := r.NextResponse()
|
||||
return payload, err
|
||||
}
|
||||
@ -211,8 +215,12 @@ func TestRDNSNameToIPv6(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func newResolver(t testing.TB) *Resolver {
|
||||
return New(t.Logf, nil /* no link monitor */, nil /* no link selector */)
|
||||
}
|
||||
|
||||
func TestResolveLocal(t *testing.T) {
|
||||
r := New(t.Logf, nil)
|
||||
r := newResolver(t)
|
||||
defer r.Close()
|
||||
|
||||
r.SetConfig(dnsCfg)
|
||||
@ -252,7 +260,7 @@ func TestResolveLocal(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestResolveLocalReverse(t *testing.T) {
|
||||
r := New(t.Logf, nil)
|
||||
r := newResolver(t)
|
||||
defer r.Close()
|
||||
|
||||
r.SetConfig(dnsCfg)
|
||||
@ -362,7 +370,7 @@ func TestDelegate(t *testing.T) {
|
||||
"huge.txt.", resolveToTXT(hugeTXT))
|
||||
defer v6server.Shutdown()
|
||||
|
||||
r := New(t.Logf, nil)
|
||||
r := newResolver(t)
|
||||
defer r.Close()
|
||||
|
||||
cfg := dnsCfg
|
||||
@ -474,7 +482,7 @@ func TestDelegateSplitRoute(t *testing.T) {
|
||||
"test.other.", resolveToIP(test4, test6, "dns.other."))
|
||||
defer server2.Shutdown()
|
||||
|
||||
r := New(t.Logf, nil)
|
||||
r := newResolver(t)
|
||||
defer r.Close()
|
||||
|
||||
cfg := dnsCfg
|
||||
@ -531,7 +539,7 @@ func TestDelegateCollision(t *testing.T) {
|
||||
"test.site.", resolveToIP(testipv4, testipv6, "dns.test.site."))
|
||||
defer server.Shutdown()
|
||||
|
||||
r := New(t.Logf, nil)
|
||||
r := newResolver(t)
|
||||
defer r.Close()
|
||||
|
||||
cfg := dnsCfg
|
||||
@ -745,7 +753,7 @@ func TestDelegateCollision(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestFull(t *testing.T) {
|
||||
r := New(t.Logf, nil)
|
||||
r := newResolver(t)
|
||||
defer r.Close()
|
||||
|
||||
r.SetConfig(dnsCfg)
|
||||
@ -781,7 +789,7 @@ func TestFull(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestAllocs(t *testing.T) {
|
||||
r := New(t.Logf, nil)
|
||||
r := newResolver(t)
|
||||
defer r.Close()
|
||||
r.SetConfig(dnsCfg)
|
||||
|
||||
@ -835,7 +843,7 @@ func BenchmarkFull(b *testing.B) {
|
||||
"test.site.", resolveToIP(testipv4, testipv6, "dns.test.site."))
|
||||
defer server.Shutdown()
|
||||
|
||||
r := New(b.Logf, nil)
|
||||
r := newResolver(b)
|
||||
defer r.Close()
|
||||
|
||||
cfg := dnsCfg
|
||||
@ -872,3 +880,58 @@ func TestMarshalResponseFormatError(t *testing.T) {
|
||||
}
|
||||
t.Logf("response: %q", v)
|
||||
}
|
||||
|
||||
func TestForwardLinkSelection(t *testing.T) {
|
||||
old := initListenConfig
|
||||
defer func() { initListenConfig = old }()
|
||||
|
||||
configCall := make(chan string, 1)
|
||||
initListenConfig = func(nc *net.ListenConfig, mon *monitor.Mon, tunName string) error {
|
||||
select {
|
||||
case configCall <- tunName:
|
||||
return nil
|
||||
default:
|
||||
t.Error("buffer full")
|
||||
return errors.New("buffer full")
|
||||
}
|
||||
}
|
||||
|
||||
// specialIP is some IP we pretend that our link selector
|
||||
// routes differently.
|
||||
specialIP := netaddr.IPv4(1, 2, 3, 4)
|
||||
|
||||
fwd := newForwarder(t.Logf, nil, nil, linkSelFunc(func(ip netaddr.IP) string {
|
||||
if ip == netaddr.IPv4(1, 2, 3, 4) {
|
||||
return "special"
|
||||
}
|
||||
return ""
|
||||
}))
|
||||
|
||||
// Test non-special IP.
|
||||
if got, err := fwd.packetListener(netaddr.IP{}); err != nil {
|
||||
t.Fatal(err)
|
||||
} else if got != stdNetPacketListener {
|
||||
t.Errorf("for IP zero value, didn't get expected packet listener")
|
||||
}
|
||||
select {
|
||||
case v := <-configCall:
|
||||
t.Errorf("unexpected ListenConfig call, with tunName %q", v)
|
||||
default:
|
||||
}
|
||||
|
||||
// Test that our special IP generates a call to initListenConfig.
|
||||
if got, err := fwd.packetListener(specialIP); err != nil {
|
||||
t.Fatal(err)
|
||||
} else if got == stdNetPacketListener {
|
||||
t.Errorf("special IP returned std packet listener; expected unique one")
|
||||
}
|
||||
if v, ok := <-configCall; !ok {
|
||||
t.Errorf("didn't get ListenConfig call")
|
||||
} else if v != "special" {
|
||||
t.Errorf("got tunName %q; want 'special'", v)
|
||||
}
|
||||
}
|
||||
|
||||
type linkSelFunc func(ip netaddr.IP) string
|
||||
|
||||
func (f linkSelFunc) PickLink(ip netaddr.IP) string { return f(ip) }
|
||||
|
53
net/netns/netns_macios.go
Normal file
53
net/netns/netns_macios.go
Normal file
@ -0,0 +1,53 @@
|
||||
// Copyright (c) 2021 Tailscale Inc & AUTHORS All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
// +build darwin ios
|
||||
|
||||
package netns
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"log"
|
||||
"net"
|
||||
"strings"
|
||||
"syscall"
|
||||
|
||||
"golang.org/x/sys/unix"
|
||||
)
|
||||
|
||||
// SetListenConfigInterfaceIndex sets lc.Control such that sockets are bound
|
||||
// to the provided interface index.
|
||||
func SetListenConfigInterfaceIndex(lc *net.ListenConfig, ifIndex int) error {
|
||||
if lc == nil {
|
||||
return errors.New("nil ListenConfig")
|
||||
}
|
||||
if lc.Control != nil {
|
||||
return errors.New("ListenConfig.Control already set")
|
||||
}
|
||||
lc.Control = func(network, address string, c syscall.RawConn) error {
|
||||
var sockErr error
|
||||
err := c.Control(func(fd uintptr) {
|
||||
sockErr = bindInterface(fd, network, address, ifIndex)
|
||||
if sockErr != nil {
|
||||
log.Printf("netns: bind(%q, %q) on index %v: %v", network, address, ifIndex, sockErr)
|
||||
}
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return sockErr
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func bindInterface(fd uintptr, network, address string, ifIndex int) error {
|
||||
v6 := strings.Contains(address, "]:") || strings.HasSuffix(network, "6") // hacky test for v6
|
||||
proto := unix.IPPROTO_IP
|
||||
opt := unix.IP_BOUND_IF
|
||||
if v6 {
|
||||
proto = unix.IPPROTO_IPV6
|
||||
opt = unix.IPV6_BOUND_IF
|
||||
}
|
||||
return unix.SetsockoptInt(int(fd), proto, opt, ifIndex)
|
||||
}
|
@ -15,19 +15,19 @@
|
||||
)
|
||||
|
||||
func ResourceCheck(tb testing.TB) {
|
||||
tb.Helper()
|
||||
startN, startStacks := goroutines()
|
||||
tb.Cleanup(func() {
|
||||
if tb.Failed() {
|
||||
// Something else went wrong.
|
||||
return
|
||||
}
|
||||
tb.Helper()
|
||||
// Goroutines might be still exiting.
|
||||
for i := 0; i < 100; i++ {
|
||||
if runtime.NumGoroutine() <= startN {
|
||||
return
|
||||
}
|
||||
time.Sleep(1 * time.Millisecond)
|
||||
time.Sleep(5 * time.Millisecond)
|
||||
}
|
||||
endN, endStacks := goroutines()
|
||||
tb.Logf("goroutine diff:\n%v\n", cmp.Diff(startStacks, endStacks))
|
||||
|
@ -101,6 +101,10 @@ type userspaceEngine struct {
|
||||
// incorrectly sent to us.
|
||||
isLocalAddr atomic.Value // of func(netaddr.IP)bool
|
||||
|
||||
// isDNSIPOverTailscale reports the whether a DNS resolver's IP
|
||||
// is being routed over Tailscale.
|
||||
isDNSIPOverTailscale atomic.Value // of func(netaddr.IP)bool
|
||||
|
||||
wgLock sync.Mutex // serializes all wgdev operations; see lock order comment below
|
||||
lastCfgFull wgcfg.Config
|
||||
lastRouterSig string // of router.Config
|
||||
@ -242,6 +246,7 @@ func NewUserspaceEngine(logf logger.Logf, conf Config) (_ Engine, reterr error)
|
||||
confListenPort: conf.ListenPort,
|
||||
}
|
||||
e.isLocalAddr.Store(tsaddr.NewContainsIPFunc(nil))
|
||||
e.isDNSIPOverTailscale.Store(tsaddr.NewContainsIPFunc(nil))
|
||||
|
||||
if conf.LinkMonitor != nil {
|
||||
e.linkMon = conf.LinkMonitor
|
||||
@ -255,7 +260,8 @@ func NewUserspaceEngine(logf logger.Logf, conf Config) (_ Engine, reterr error)
|
||||
e.linkMonOwned = true
|
||||
}
|
||||
|
||||
e.dns = dns.NewManager(logf, conf.DNS, e.linkMon)
|
||||
tunName, _ := conf.Tun.Name()
|
||||
e.dns = dns.NewManager(logf, conf.DNS, e.linkMon, fwdDNSLinkSelector{e, tunName})
|
||||
|
||||
logf("link state: %+v", e.linkMon.InterfaceState())
|
||||
|
||||
@ -767,6 +773,13 @@ func (e *userspaceEngine) Reconfig(cfg *wgcfg.Config, routerCfg *router.Config,
|
||||
return ErrNoChanges
|
||||
}
|
||||
|
||||
// TODO(bradfitz,danderson): maybe delete this isDNSIPOverTailscale
|
||||
// field and delete the resolver.ForwardLinkSelector hook and
|
||||
// instead have ipnlocal populate a map of DNS IP => linkName and
|
||||
// put that in the *dns.Config instead, and plumb it down to the
|
||||
// dns.Manager. Maybe also with isLocalAddr above.
|
||||
e.isDNSIPOverTailscale.Store(tsaddr.NewContainsIPFunc(dnsIPsOverTailscale(dnsCfg, routerCfg)))
|
||||
|
||||
// See if any peers have changed disco keys, which means they've restarted.
|
||||
// If so, we need to update the wireguard-go/device.Device in two phases:
|
||||
// once without the node which has restarted, to clear its wireguard session key,
|
||||
@ -1362,3 +1375,49 @@ func (p closeOnErrorPool) closeAllIfError(errp *error) {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ipInPrefixes reports whether ip is in any of pp.
|
||||
func ipInPrefixes(ip netaddr.IP, pp []netaddr.IPPrefix) bool {
|
||||
for _, p := range pp {
|
||||
if p.Contains(ip) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// dnsIPsOverTailscale returns the IPPrefixes of DNS resolver IPs that are
|
||||
// routed over Tailscale. The returned value does not contain duplicates is
|
||||
// not necessarily sorted.
|
||||
func dnsIPsOverTailscale(dnsCfg *dns.Config, routerCfg *router.Config) (ret []netaddr.IPPrefix) {
|
||||
m := map[netaddr.IP]bool{}
|
||||
|
||||
for _, resolvers := range dnsCfg.Routes {
|
||||
for _, resolver := range resolvers {
|
||||
ip := resolver.IP()
|
||||
if ipInPrefixes(ip, routerCfg.Routes) && !ipInPrefixes(ip, routerCfg.LocalRoutes) {
|
||||
m[ip] = true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
ret = make([]netaddr.IPPrefix, 0, len(m))
|
||||
for ip := range m {
|
||||
ret = append(ret, netaddr.IPPrefixFrom(ip, ip.BitLen()))
|
||||
}
|
||||
return ret
|
||||
}
|
||||
|
||||
// fwdDNSLinkSelector is userspaceEngine's resolver.ForwardLinkSelector, to pick
|
||||
// which network interface to send DNS queries out of.
|
||||
type fwdDNSLinkSelector struct {
|
||||
ue *userspaceEngine
|
||||
tunName string
|
||||
}
|
||||
|
||||
func (ls fwdDNSLinkSelector) PickLink(ip netaddr.IP) (linkName string) {
|
||||
if ls.ue.isDNSIPOverTailscale.Load().(func(netaddr.IP) bool)(ip) {
|
||||
return ls.tunName
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user