From 45e64f2e1abdb3c9feced24697c11c667eb0351d Mon Sep 17 00:00:00 2001 From: Brad Fitzpatrick Date: Tue, 22 Jun 2021 21:53:43 -0700 Subject: [PATCH] net/dns{,/resolver}: refactor DNS forwarder, send out of right link on macOS/iOS Fixes #2224 Fixes tailscale/corp#2045 Signed-off-by: Brad Fitzpatrick --- ipn/ipnlocal/peerapi_macios_ext.go | 39 +- net/dns/manager.go | 6 +- net/dns/manager_test.go | 2 +- net/dns/resolver/forwarder.go | 563 +++++++++++------------------ net/dns/resolver/macios_ext.go | 27 ++ net/dns/resolver/tsdns.go | 194 +++++----- net/dns/resolver/tsdns_test.go | 81 ++++- net/netns/netns_macios.go | 53 +++ tstest/resource.go | 4 +- wgengine/userspace.go | 61 +++- 10 files changed, 530 insertions(+), 500 deletions(-) create mode 100644 net/dns/resolver/macios_ext.go create mode 100644 net/netns/netns_macios.go diff --git a/ipn/ipnlocal/peerapi_macios_ext.go b/ipn/ipnlocal/peerapi_macios_ext.go index 081430bcf..24568aeba 100644 --- a/ipn/ipnlocal/peerapi_macios_ext.go +++ b/ipn/ipnlocal/peerapi_macios_ext.go @@ -9,14 +9,12 @@ package ipnlocal import ( "errors" "fmt" - "log" "net" - "strings" "syscall" - "golang.org/x/sys/unix" "inet.af/netaddr" "tailscale.com/net/interfaces" + "tailscale.com/net/netns" ) func init() { @@ -32,29 +30,7 @@ func initListenConfigNetworkExtension(nc *net.ListenConfig, ip netaddr.IP, st *i if !ok { return fmt.Errorf("no interface with name %q", tunIfName) } - nc.Control = func(network, address string, c syscall.RawConn) error { - var sockErr error - err := c.Control(func(fd uintptr) { - sockErr = bindIf(fd, network, address, tunIf.Index) - log.Printf("peerapi: bind(%q, %q) on index %v = %v", network, address, tunIf.Index, sockErr) - }) - if err != nil { - return err - } - return sockErr - } - return nil -} - -func bindIf(fd uintptr, network, address string, ifIndex int) error { - v6 := strings.Contains(address, "]:") || strings.HasSuffix(network, "6") // hacky test for v6 - proto := unix.IPPROTO_IP - opt := unix.IP_BOUND_IF - if v6 { - proto = unix.IPPROTO_IPV6 - opt = unix.IPV6_BOUND_IF - } - return unix.SetsockoptInt(int(fd), proto, opt, ifIndex) + return netns.SetListenConfigInterfaceIndex(nc, tunIf.Index) } func peerDialControlFuncNetworkExtension(b *LocalBackend) func(network, address string, c syscall.RawConn) error { @@ -68,17 +44,12 @@ func peerDialControlFuncNetworkExtension(b *LocalBackend) func(network, address index = tunIf.Index } } + var lc net.ListenConfig + netns.SetListenConfigInterfaceIndex(&lc, index) return func(network, address string, c syscall.RawConn) error { if index == -1 { return errors.New("failed to find TUN interface to bind to") } - var sockErr error - err := c.Control(func(fd uintptr) { - sockErr = bindIf(fd, network, address, index) - }) - if err != nil { - return err - } - return sockErr + return lc.Control(network, address, c) } } diff --git a/net/dns/manager.go b/net/dns/manager.go index 85d87c7bd..398db6837 100644 --- a/net/dns/manager.go +++ b/net/dns/manager.go @@ -38,11 +38,11 @@ type Manager struct { } // NewManagers created a new manager from the given config. -func NewManager(logf logger.Logf, oscfg OSConfigurator, linkMon *monitor.Mon) *Manager { +func NewManager(logf logger.Logf, oscfg OSConfigurator, linkMon *monitor.Mon, linkSel resolver.ForwardLinkSelector) *Manager { logf = logger.WithPrefix(logf, "dns: ") m := &Manager{ logf: logf, - resolver: resolver.New(logf, linkMon), + resolver: resolver.New(logf, linkMon, linkSel), os: oscfg, } m.logf("using %T", m.os) @@ -207,7 +207,7 @@ func Cleanup(logf logger.Logf, interfaceName string) { logf("creating dns cleanup: %v", err) return } - dns := NewManager(logf, oscfg, nil) + dns := NewManager(logf, oscfg, nil, nil) if err := dns.Down(); err != nil { logf("dns down: %v", err) } diff --git a/net/dns/manager_test.go b/net/dns/manager_test.go index 8835f2aec..8b1c32e49 100644 --- a/net/dns/manager_test.go +++ b/net/dns/manager_test.go @@ -376,7 +376,7 @@ func TestManager(t *testing.T) { SplitDNS: test.split, BaseConfig: test.bs, } - m := NewManager(t.Logf, &f, nil) + m := NewManager(t.Logf, &f, nil, nil) m.resolver.TestOnlySetHook(f.SetResolver) if err := m.Set(test.in); err != nil { diff --git a/net/dns/resolver/forwarder.go b/net/dns/resolver/forwarder.go index 6178e26fe..866b4f045 100644 --- a/net/dns/resolver/forwarder.go +++ b/net/dns/resolver/forwarder.go @@ -9,41 +9,30 @@ import ( "context" "encoding/binary" "errors" - "fmt" "hash/crc32" + "io" "math/rand" "net" "sync" - "syscall" "time" dns "golang.org/x/net/dns/dnsmessage" "inet.af/netaddr" - "tailscale.com/logtail/backoff" "tailscale.com/types/logger" "tailscale.com/util/dnsname" + "tailscale.com/wgengine/monitor" ) // headerBytes is the number of bytes in a DNS message header. const headerBytes = 12 -// connCount is the number of UDP connections to use for forwarding. -const connCount = 32 - const ( - // cleanupInterval is the interval between purged of timed-out entries from txMap. - cleanupInterval = 30 * time.Second // responseTimeout is the maximal amount of time to wait for a DNS response. responseTimeout = 5 * time.Second ) var errNoUpstreams = errors.New("upstream nameservers not set") -type forwardingRecord struct { - src netaddr.IPPort - createdAt time.Time -} - // txid identifies a DNS transaction. // // As the standard DNS Request ID is only 16 bits, we extend it: @@ -100,178 +89,164 @@ func getTxID(packet []byte) txid { } type route struct { - suffix dnsname.FQDN - resolvers []netaddr.IPPort + Suffix dnsname.FQDN + Resolvers []netaddr.IPPort } // forwarder forwards DNS packets to a number of upstream nameservers. type forwarder struct { - logf logger.Logf + logf logger.Logf + linkMon *monitor.Mon + linkSel ForwardLinkSelector + + ctx context.Context // good until Close + ctxCancel context.CancelFunc // closes ctx // responses is a channel by which responses are returned. responses chan packet - // closed signals all goroutines to stop. - closed chan struct{} - // wg signals when all goroutines have stopped. - wg sync.WaitGroup - // conns are the UDP connections used for forwarding. - // A random one is selected for each request, regardless of the target upstream. - conns []*fwdConn + mu sync.Mutex // guards following - mu sync.Mutex - // routes are per-suffix resolvers to use. - routes []route // most specific routes first - txMap map[txid]forwardingRecord // txids to in-flight requests + // routes are per-suffix resolvers to use, with + // the most specific routes first. + routes []route } func init() { rand.Seed(time.Now().UnixNano()) } -func newForwarder(logf logger.Logf, responses chan packet) *forwarder { - ret := &forwarder{ +func newForwarder(logf logger.Logf, responses chan packet, linkMon *monitor.Mon, linkSel ForwardLinkSelector) *forwarder { + f := &forwarder{ logf: logger.WithPrefix(logf, "forward: "), + linkMon: linkMon, + linkSel: linkSel, responses: responses, - closed: make(chan struct{}), - conns: make([]*fwdConn, connCount), - txMap: make(map[txid]forwardingRecord), } - - ret.wg.Add(connCount + 1) - for idx := range ret.conns { - ret.conns[idx] = newFwdConn(ret.logf, idx) - go ret.recv(ret.conns[idx]) - } - go ret.cleanMap() - - return ret + f.ctx, f.ctxCancel = context.WithCancel(context.Background()) + return f } -func (f *forwarder) Close() { - select { - case <-f.closed: - return - default: - // continue - } - close(f.closed) - - for _, conn := range f.conns { - conn.close() - } - - f.wg.Wait() -} - -func (f *forwarder) rebindFromNetworkChange() { - for _, c := range f.conns { - c.mu.Lock() - c.reconnectLocked() - c.mu.Unlock() - } +func (f *forwarder) Close() error { + f.ctxCancel() + return nil } func (f *forwarder) setRoutes(routes []route) { f.mu.Lock() + defer f.mu.Unlock() f.routes = routes - f.mu.Unlock() +} + +var stdNetPacketListener packetListener = new(net.ListenConfig) + +type packetListener interface { + ListenPacket(ctx context.Context, network, address string) (net.PacketConn, error) +} + +func (f *forwarder) packetListener(ip netaddr.IP) (packetListener, error) { + if f.linkSel == nil || initListenConfig == nil { + return stdNetPacketListener, nil + } + linkName := f.linkSel.PickLink(ip) + if linkName == "" { + return stdNetPacketListener, nil + } + lc := new(net.ListenConfig) + if err := initListenConfig(lc, f.linkMon, linkName); err != nil { + return nil, err + } + return lc, nil } // send sends packet to dst. It is best effort. -func (f *forwarder) send(packet []byte, dst netaddr.IPPort) { - connIdx := rand.Intn(connCount) - conn := f.conns[connIdx] - conn.send(packet, dst) -} +// +// send expects the reply to have the same txid as txidOut. +// +// The provided closeOnCtxDone lets send register values to Close if +// the caller's ctx expires. This avoids send from allocating its own +// waiting goroutine to interrupt the ReadFrom, as memory is tight on +// iOS and we want the number of pending DNS lookups to be bursty +// without too much associated goroutine/memory cost. +func (f *forwarder) send(ctx context.Context, txidOut txid, closeOnCtxDone *closePool, packet []byte, dst netaddr.IPPort) ([]byte, error) { + // TODO(bradfitz): if dst.IP is 8.8.8.8 or 8.8.4.4 or 1.1.1.1, etc, or + // something dynamically probed earlier to support DoH or DoT, + // do that here instead. -func (f *forwarder) recv(conn *fwdConn) { - defer f.wg.Done() + ln, err := f.packetListener(dst.IP()) + if err != nil { + return nil, err + } + conn, err := ln.ListenPacket(ctx, "udp", ":0") + if err != nil { + f.logf("ListenPacket failed: %v", err) + return nil, err + } + defer conn.Close() - for { - select { - case <-f.closed: - return - default: + closeOnCtxDone.Add(conn) + defer closeOnCtxDone.Remove(conn) + + if _, err := conn.WriteTo(packet, dst.UDPAddr()); err != nil { + if err := ctx.Err(); err != nil { + return nil, err } - // The 1 extra byte is to detect packet truncation. - out := make([]byte, maxResponseBytes+1) - n := conn.read(out) - var truncated bool - if n > maxResponseBytes { - n = maxResponseBytes - truncated = true + return nil, err + } + + // The 1 extra byte is to detect packet truncation. + out := make([]byte, maxResponseBytes+1) + n, _, err := conn.ReadFrom(out) + if err != nil { + if err := ctx.Err(); err != nil { + return nil, err } - if n == 0 { - continue - } - if n < headerBytes { - f.logf("recv: packet too small (%d bytes)", n) - } - - out = out[:n] - txid := getTxID(out) - - if truncated { - const dnsFlagTruncated = 0x200 - flags := binary.BigEndian.Uint16(out[2:4]) - flags |= dnsFlagTruncated - binary.BigEndian.PutUint16(out[2:4], flags) - - // TODO(#2067): Remove any incomplete records? RFC 1035 section 6.2 - // states that truncation should head drop so that the authority - // section can be preserved if possible. However, the UDP read with - // a too-small buffer has already dropped the end, so that's the - // best we can do. - } - - f.mu.Lock() - - record, found := f.txMap[txid] - // At most one nameserver will return a response: - // the first one to do so will delete txid from the map. - if !found { - f.mu.Unlock() - continue - } - delete(f.txMap, txid) - - f.mu.Unlock() - - pkt := packet{out, record.src} - select { - case <-f.closed: - return - case f.responses <- pkt: - // continue + if packetWasTruncated(err) { + err = nil + } else { + return nil, err } } + truncated := n > maxResponseBytes + if truncated { + n = maxResponseBytes + } + if n < headerBytes { + f.logf("recv: packet too small (%d bytes)", n) + } + out = out[:n] + txid := getTxID(out) + if txid != txidOut { + return nil, errors.New("txid doesn't match") + } + + if truncated { + const dnsFlagTruncated = 0x200 + flags := binary.BigEndian.Uint16(out[2:4]) + flags |= dnsFlagTruncated + binary.BigEndian.PutUint16(out[2:4], flags) + + // TODO(#2067): Remove any incomplete records? RFC 1035 section 6.2 + // states that truncation should head drop so that the authority + // section can be preserved if possible. However, the UDP read with + // a too-small buffer has already dropped the end, so that's the + // best we can do. + } + + return out, nil } -// cleanMap periodically deletes timed-out forwarding records from f.txMap to bound growth. -func (f *forwarder) cleanMap() { - defer f.wg.Done() - - t := time.NewTicker(cleanupInterval) - defer t.Stop() - - var now time.Time - for { - select { - case <-f.closed: - return - case now = <-t.C: - // continue +// resolvers returns the resolvers to use for domain. +func (f *forwarder) resolvers(domain dnsname.FQDN) []netaddr.IPPort { + f.mu.Lock() + routes := f.routes + f.mu.Unlock() + for _, route := range routes { + if route.Suffix == "." || route.Suffix.Contains(domain) { + return route.Resolvers } - - f.mu.Lock() - for k, v := range f.txMap { - if now.Sub(v.createdAt) > responseTimeout { - delete(f.txMap, k) - } - } - f.mu.Unlock() } + return nil } // forward forwards the query to all upstream nameservers and returns the first response. @@ -283,225 +258,60 @@ func (f *forwarder) forward(query packet) error { txid := getTxID(query.bs) - f.mu.Lock() - routes := f.routes - f.mu.Unlock() - - var resolvers []netaddr.IPPort - for _, route := range routes { - if route.suffix != "." && !route.suffix.Contains(domain) { - continue - } - resolvers = route.resolvers - break - } + resolvers := f.resolvers(domain) if len(resolvers) == 0 { return errNoUpstreams } - f.mu.Lock() - f.txMap[txid] = forwardingRecord{ - src: query.addr, - createdAt: time.Now(), - } - f.mu.Unlock() + closeOnCtxDone := new(closePool) + defer closeOnCtxDone.Close() - // TODO(#2066): EDNS size clamping + ctx, cancel := context.WithTimeout(f.ctx, responseTimeout) + defer cancel() - for _, resolver := range resolvers { - f.send(query.bs, resolver) - } + resc := make(chan []byte, 1) + var ( + mu sync.Mutex + firstErr error + ) - return nil -} - -// A fwdConn manages a single connection used to forward DNS requests. -// Net link changes can cause a *net.UDPConn to become permanently unusable, particularly on macOS. -// fwdConn detects such situations and transparently creates new connections. -type fwdConn struct { - // logf allows a fwdConn to log. - logf logger.Logf - - // change allows calls to read to block until a the network connection has been replaced. - change *sync.Cond - - // mu protects fields that follow it; it is also change's Locker. - mu sync.Mutex - // closed tracks whether fwdConn has been permanently closed. - closed bool - // conn is the current active connection. - conn net.PacketConn -} - -func newFwdConn(logf logger.Logf, idx int) *fwdConn { - c := new(fwdConn) - c.logf = logger.WithPrefix(logf, fmt.Sprintf("fwdConn %d: ", idx)) - c.change = sync.NewCond(&c.mu) - // c.conn is created lazily in send - return c -} - -// send sends packet to dst using c's connection. -// It is best effort. It is UDP, after all. Failures are logged. -func (c *fwdConn) send(packet []byte, dst netaddr.IPPort) { - var b *backoff.Backoff // lazily initialized, since it is not needed in the common case - backOff := func(err error) { - if b == nil { - b = backoff.NewBackoff("dns-fwdConn-send", c.logf, 30*time.Second) - } - b.BackOff(context.Background(), err) - } - - for { - // Gather the current connection. - // We can't hold the lock while we call WriteTo. - c.mu.Lock() - conn := c.conn - closed := c.closed - if closed { - c.mu.Unlock() - return - } - if conn == nil { - c.reconnectLocked() - c.mu.Unlock() - continue - } - c.mu.Unlock() - - _, err := conn.WriteTo(packet, dst.UDPAddr()) - if err == nil { - // Success - return - } - if errors.Is(err, net.ErrClosed) { - // We intentionally closed this connection. - // It has been replaced by a new connection. Try again. - continue - } - // Something else went wrong. - // We have three choices here: try again, give up, or create a new connection. - var opErr *net.OpError - if !errors.As(err, &opErr) { - // Weird. All errors from the net package should be *net.OpError. Bail. - c.logf("send: non-*net.OpErr %v (%T)", err, err) - return - } - if opErr.Temporary() || opErr.Timeout() { - // I doubt that either of these can happen (this is UDP), - // but go ahead and try again. - backOff(err) - continue - } - if errors.Is(err, syscall.EHOSTUNREACH) { - // "No route to host." The network stack is fine, but - // can't talk to this destination. Not much we can do - // about that, don't spam logs. - return - } - if networkIsDown(err) { - // Fail. - c.logf("send: network is down") - return - } - if networkIsUnreachable(err) { - // This can be caused by a link change. - // Replace the existing connection with a new one. - c.mu.Lock() - // It's possible that multiple senders discovered simultaneously - // that the network is unreachable. Avoid reconnecting multiple times: - // Only reconnect if the current connection is the one that we - // discovered to be problematic. - if c.conn == conn { - backOff(err) - c.reconnectLocked() + for _, ipp := range resolvers { + go func(ipp netaddr.IPPort) { + resb, err := f.send(ctx, txid, closeOnCtxDone, query.bs, ipp) + if err != nil { + mu.Lock() + defer mu.Unlock() + if firstErr == nil { + firstErr = err + } + return } - c.mu.Unlock() - // Try again with our new network connection. - continue + select { + case resc <- resb: + default: + } + }(ipp) + } + + select { + case v := <-resc: + select { + case <-ctx.Done(): + return ctx.Err() + case f.responses <- packet{v, query.addr}: + return nil } - // Unrecognized error. Fail. - c.logf("send: unrecognized error: %v", err) - return + case <-ctx.Done(): + mu.Lock() + defer mu.Unlock() + if firstErr != nil { + return firstErr + } + return ctx.Err() } } -// read waits for a response from c's connection. -// It returns the number of bytes read, which may be 0 -// in case of an error or a closed connection. -func (c *fwdConn) read(out []byte) int { - for { - // Gather the current connection. - // We can't hold the lock while we call ReadFrom. - c.mu.Lock() - conn := c.conn - closed := c.closed - if closed { - c.mu.Unlock() - return 0 - } - if conn == nil { - // There is no current connection. - // Wait for the connection to change, then try again. - c.change.Wait() - c.mu.Unlock() - continue - } - c.mu.Unlock() - - n, _, err := conn.ReadFrom(out) - if err == nil || packetWasTruncated(err) { - // Success. - return n - } - if errors.Is(err, net.ErrClosed) { - // We intentionally closed this connection. - // It has been replaced by a new connection. Try again. - continue - } - - c.logf("read: unrecognized error: %v", err) - return 0 - } -} - -// reconnectLocked replaces the current connection with a new one. -// c.mu must be locked. -func (c *fwdConn) reconnectLocked() { - c.closeConnLocked() - // Make a new connection. - conn, err := net.ListenPacket("udp", "") - if err != nil { - c.logf("ListenPacket failed: %v", err) - } else { - c.conn = conn - } - // Broadcast that a new connection is available. - c.change.Broadcast() -} - -// closeCurrentConn closes the current connection. -// c.mu must be locked. -func (c *fwdConn) closeConnLocked() { - if c.conn == nil { - return - } - c.conn.Close() // unblocks all readers/writers, they'll pick up the next connection. - c.conn = nil -} - -// close permanently closes c. -func (c *fwdConn) close() { - c.mu.Lock() - defer c.mu.Unlock() - if c.closed { - return - } - c.closed = true - c.closeConnLocked() - // Unblock any remaining readers. - c.change.Broadcast() -} +var initListenConfig func(_ *net.ListenConfig, _ *monitor.Mon, tunName string) error // nameFromQuery extracts the normalized query name from bs. func nameFromQuery(bs []byte) (dnsname.FQDN, error) { @@ -523,3 +333,48 @@ func nameFromQuery(bs []byte) (dnsname.FQDN, error) { n := q.Name.Data[:q.Name.Length] return dnsname.ToFQDN(rawNameToLower(n)) } + +// closePool is a dynamic set of io.Closers to close as a group. +// It's intended to be Closed at most once. +// +// The zero value is ready for use. +type closePool struct { + mu sync.Mutex + m map[io.Closer]bool + closed bool +} + +func (p *closePool) Add(c io.Closer) { + p.mu.Lock() + defer p.mu.Unlock() + if p.closed { + c.Close() + return + } + if p.m == nil { + p.m = map[io.Closer]bool{} + } + p.m[c] = true +} + +func (p *closePool) Remove(c io.Closer) { + p.mu.Lock() + defer p.mu.Unlock() + if p.closed { + return + } + delete(p.m, c) +} + +func (p *closePool) Close() error { + p.mu.Lock() + defer p.mu.Unlock() + if p.closed { + return nil + } + p.closed = true + for c := range p.m { + c.Close() + } + return nil +} diff --git a/net/dns/resolver/macios_ext.go b/net/dns/resolver/macios_ext.go new file mode 100644 index 000000000..b19dcdeb8 --- /dev/null +++ b/net/dns/resolver/macios_ext.go @@ -0,0 +1,27 @@ +// Copyright (c) 2021 Tailscale Inc & AUTHORS All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// +build darwin,ts_macext ios,ts_macext + +package resolver + +import ( + "errors" + "net" + + "tailscale.com/net/netns" + "tailscale.com/wgengine/monitor" +) + +func init() { + initListenConfig = initListenConfigNetworkExtension +} + +func initListenConfigNetworkExtension(nc *net.ListenConfig, mon *monitor.Mon, tunName string) error { + nif, ok := mon.InterfaceState().Interface[tunName] + if !ok { + return errors.New("utun not found") + } + return netns.SetListenConfigInterfaceIndex(nc, nif.Interface.Index) +} diff --git a/net/dns/resolver/tsdns.go b/net/dns/resolver/tsdns.go index 105efa720..96738779f 100644 --- a/net/dns/resolver/tsdns.go +++ b/net/dns/resolver/tsdns.go @@ -9,14 +9,15 @@ package resolver import ( "encoding/hex" "errors" + "runtime" "sort" "strings" "sync" + "sync/atomic" "time" dns "golang.org/x/net/dns/dnsmessage" "inet.af/netaddr" - "tailscale.com/net/interfaces" "tailscale.com/types/logger" "tailscale.com/util/dnsname" "tailscale.com/wgengine/monitor" @@ -27,10 +28,20 @@ import ( // truncation in a platform-agnostic way. const maxResponseBytes = 4095 -// queueSize is the maximal number of DNS requests that can await polling. +// maxActiveQueries returns the maximal number of DNS requests that be +// can running. // If EnqueueRequest is called when this many requests are already pending, // the request will be dropped to avoid blocking the caller. -const queueSize = 64 +func maxActiveQueries() int32 { + if runtime.GOOS == "ios" { + // For memory paranoia reasons on iOS, match the + // historical Tailscale 1.x..1.8 behavior for now + // (just before the 1.10 release). + return 64 + } + // But for other platforms, allow more burstiness: + return 256 +} // defaultTTL is the TTL of all responses from Resolver. const defaultTTL = 600 * time.Second @@ -75,13 +86,12 @@ type Config struct { type Resolver struct { logf logger.Logf linkMon *monitor.Mon // or nil - unregLinkMon func() // or nil saveConfigForTests func(cfg Config) // used in tests to capture resolver config // forwarder forwards requests to upstream nameservers. forwarder *forwarder - // queue is a buffered channel holding DNS requests queued for resolution. - queue chan packet + activeQueriesAtomic int32 // number of DNS queries in flight + // responses is an unbuffered channel to which responses are returned. responses chan packet // errors is an unbuffered channel to which errors are returned. @@ -98,27 +108,26 @@ type Resolver struct { ipToHost map[netaddr.IP]dnsname.FQDN } +type ForwardLinkSelector interface { + // PickLink returns which network device should be used to query + // the DNS server at the given IP. + // The empty string means to use an unspecified default. + PickLink(netaddr.IP) (linkName string) +} + // New returns a new resolver. // linkMon optionally specifies a link monitor to use for socket rebinding. -func New(logf logger.Logf, linkMon *monitor.Mon) *Resolver { +func New(logf logger.Logf, linkMon *monitor.Mon, linkSel ForwardLinkSelector) *Resolver { r := &Resolver{ logf: logger.WithPrefix(logf, "dns: "), linkMon: linkMon, - queue: make(chan packet, queueSize), responses: make(chan packet), errors: make(chan error), closed: make(chan struct{}), hostToIP: map[dnsname.FQDN][]netaddr.IP{}, ipToHost: map[netaddr.IP]dnsname.FQDN{}, } - r.forwarder = newForwarder(r.logf, r.responses) - if r.linkMon != nil { - r.unregLinkMon = r.linkMon.RegisterChangeCallback(r.onLinkMonitorChange) - } - - r.wg.Add(1) - go r.poll() - + r.forwarder = newForwarder(r.logf, r.responses, linkMon, linkSel) return r } @@ -140,13 +149,13 @@ func (r *Resolver) SetConfig(cfg Config) error { for suffix, ips := range cfg.Routes { routes = append(routes, route{ - suffix: suffix, - resolvers: ips, + Suffix: suffix, + Resolvers: ips, }) } // Sort from longest prefix to shortest. sort.Slice(routes, func(i, j int) bool { - return routes[i].suffix.NumLabels() > routes[j].suffix.NumLabels() + return routes[i].Suffix.NumLabels() > routes[j].Suffix.NumLabels() }) r.forwarder.setRoutes(routes) @@ -170,19 +179,7 @@ func (r *Resolver) Close() { } close(r.closed) - if r.unregLinkMon != nil { - r.unregLinkMon() - } - r.forwarder.Close() - r.wg.Wait() -} - -func (r *Resolver) onLinkMonitorChange(changed bool, state *interfaces.State) { - if !changed { - return - } - r.forwarder.rebindFromNetworkChange() } // EnqueueRequest places the given DNS request in the resolver's queue. @@ -192,11 +189,14 @@ func (r *Resolver) EnqueueRequest(bs []byte, from netaddr.IPPort) error { select { case <-r.closed: return ErrClosed - case r.queue <- packet{bs, from}: - return nil default: + } + if n := atomic.AddInt32(&r.activeQueriesAtomic, 1); n > maxActiveQueries() { + atomic.AddInt32(&r.activeQueriesAtomic, -1) return errFullQueue } + go r.handleQuery(packet{bs, from}) + return nil } // NextResponse returns a DNS response to a previously enqueued request. @@ -291,53 +291,34 @@ func (r *Resolver) resolveLocal(domain dnsname.FQDN, typ dns.Type) (netaddr.IP, // resolveReverse returns the unique domain name that maps to the given address. func (r *Resolver) resolveLocalReverse(ip netaddr.IP) (dnsname.FQDN, dns.RCode) { r.mu.Lock() - ips := r.ipToHost - r.mu.Unlock() - - name, found := ips[ip] - if !found { + defer r.mu.Unlock() + name, ok := r.ipToHost[ip] + if !ok { return "", dns.RCodeNameError } return name, dns.RCodeSuccess } -func (r *Resolver) poll() { - defer r.wg.Done() +func (r *Resolver) handleQuery(pkt packet) { + defer atomic.AddInt32(&r.activeQueriesAtomic, -1) - var pkt packet - for { + out, err := r.respond(pkt.bs) + if err == errNotOurName { + err = r.forwarder.forward(pkt) + if err == nil { + // forward will send response into r.responses, nothing to do. + return + } + } + if err != nil { select { case <-r.closed: - return - case pkt = <-r.queue: - // continue + case r.errors <- err: } - - out, err := r.respond(pkt.bs) - - if err == errNotOurName { - err = r.forwarder.forward(pkt) - if err == nil { - // forward will send response into r.responses, nothing to do. - continue - } - } - - if err != nil { - select { - case <-r.closed: - return - case r.errors <- err: - // continue - } - } else { - pkt.bs = out - select { - case <-r.closed: - return - case r.responses <- pkt: - // continue - } + } else { + select { + case <-r.closed: + case r.responses <- packet{out, pkt.addr}: } } } @@ -351,28 +332,44 @@ type response struct { IP netaddr.IP } -// parseQuery parses the query in given packet into a response struct. -// if the parse is successful, resp.Name contains the normalized name being queried. -// TODO: stuffing the query name in resp.Name temporarily is a hack. Clean it up. -func parseQuery(query []byte, resp *response) error { - var parser dns.Parser - var err error +var dnsParserPool = &sync.Pool{ + New: func() interface{} { + return new(dnsParser) + }, +} - resp.Header, err = parser.Start(query) +// dnsParser parses DNS queries using x/net/dns/dnsmessage. +// These structs are pooled with dnsParserPool. +type dnsParser struct { + Header dns.Header + Question dns.Question + + parser dns.Parser +} + +func (p *dnsParser) response() *response { + return &response{Header: p.Header, Question: p.Question} +} + +// zeroParser clears parser so it doesn't retain its most recently +// parsed DNS query's []byte while it's sitting in a sync.Pool. +// It's not useful to keep anyway: the next Start will do the same. +func (p *dnsParser) zeroParser() { p.parser = dns.Parser{} } + +// parseQuery parses the query in given packet into p.Header and +// p.Question. +func (p *dnsParser) parseQuery(query []byte) error { + defer p.zeroParser() + var err error + p.Header, err = p.parser.Start(query) if err != nil { return err } - - if resp.Header.Response { + if p.Header.Response { return errNotQuery } - - resp.Question, err = parser.Question() - if err != nil { - return err - } - - return nil + p.Question, err = p.parser.Question() + return err } // marshalARecord serializes an A record into an active builder. @@ -624,12 +621,13 @@ func (r *Resolver) respondReverse(query []byte, name dnsname.FQDN, resp *respons // respond returns a DNS response to query if it can be resolved locally. // Otherwise, it returns errNotOurName. func (r *Resolver) respond(query []byte) ([]byte, error) { - resp := new(response) + parser := dnsParserPool.Get().(*dnsParser) + defer dnsParserPool.Put(parser) // ParseQuery is sufficiently fast to run on every DNS packet. // This is considerably simpler than extracting the name by hand // to shave off microseconds in case of delegation. - err := parseQuery(query, resp) + err := parser.parseQuery(query) // We will not return this error: it is the sender's fault. if err != nil { if errors.Is(err, dns.ErrSectionDone) { @@ -637,13 +635,15 @@ func (r *Resolver) respond(query []byte) ([]byte, error) { } else { r.logf("parseQuery(%02x): %v", query, err) } + resp := parser.response() resp.Header.RCode = dns.RCodeFormatError return marshalResponse(resp) } - rawName := resp.Question.Name.Data[:resp.Question.Name.Length] + rawName := parser.Question.Name.Data[:parser.Question.Name.Length] name, err := dnsname.ToFQDN(rawNameToLower(rawName)) if err != nil { // DNS packet unexpectedly contains an invalid FQDN. + resp := parser.response() resp.Header.RCode = dns.RCodeFormatError return marshalResponse(resp) } @@ -651,15 +651,17 @@ func (r *Resolver) respond(query []byte) ([]byte, error) { // Always try to handle reverse lookups; delegate inside when not found. // This way, queries for existent nodes do not leak, // but we behave gracefully if non-Tailscale nodes exist in CGNATRange. - if resp.Question.Type == dns.TypePTR { - return r.respondReverse(query, name, resp) + if parser.Question.Type == dns.TypePTR { + return r.respondReverse(query, name, parser.response()) } - resp.IP, resp.Header.RCode = r.resolveLocal(name, resp.Question.Type) - // This return code is special: it requests forwarding. - if resp.Header.RCode == dns.RCodeRefused { - return nil, errNotOurName + ip, rcode := r.resolveLocal(name, parser.Question.Type) + if rcode == dns.RCodeRefused { + return nil, errNotOurName // sentinel error return value: it requests forwarding } + resp := parser.response() + resp.Header.RCode = rcode + resp.IP = ip return marshalResponse(resp) } diff --git a/net/dns/resolver/tsdns_test.go b/net/dns/resolver/tsdns_test.go index bdfb559fb..857a8ba08 100644 --- a/net/dns/resolver/tsdns_test.go +++ b/net/dns/resolver/tsdns_test.go @@ -8,6 +8,7 @@ import ( "bytes" "encoding/hex" "errors" + "fmt" "math/rand" "net" "runtime" @@ -17,6 +18,7 @@ import ( "inet.af/netaddr" "tailscale.com/tstest" "tailscale.com/util/dnsname" + "tailscale.com/wgengine/monitor" ) var testipv4 = netaddr.MustParseIP("1.2.3.4") @@ -128,7 +130,9 @@ func unpackResponse(payload []byte) (dnsResponse, error) { } func syncRespond(r *Resolver, query []byte) ([]byte, error) { - r.EnqueueRequest(query, netaddr.IPPort{}) + if err := r.EnqueueRequest(query, netaddr.IPPort{}); err != nil { + return nil, fmt.Errorf("EnqueueRequest: %w", err) + } payload, _, err := r.NextResponse() return payload, err } @@ -211,8 +215,12 @@ func TestRDNSNameToIPv6(t *testing.T) { } } +func newResolver(t testing.TB) *Resolver { + return New(t.Logf, nil /* no link monitor */, nil /* no link selector */) +} + func TestResolveLocal(t *testing.T) { - r := New(t.Logf, nil) + r := newResolver(t) defer r.Close() r.SetConfig(dnsCfg) @@ -252,7 +260,7 @@ func TestResolveLocal(t *testing.T) { } func TestResolveLocalReverse(t *testing.T) { - r := New(t.Logf, nil) + r := newResolver(t) defer r.Close() r.SetConfig(dnsCfg) @@ -362,7 +370,7 @@ func TestDelegate(t *testing.T) { "huge.txt.", resolveToTXT(hugeTXT)) defer v6server.Shutdown() - r := New(t.Logf, nil) + r := newResolver(t) defer r.Close() cfg := dnsCfg @@ -474,7 +482,7 @@ func TestDelegateSplitRoute(t *testing.T) { "test.other.", resolveToIP(test4, test6, "dns.other.")) defer server2.Shutdown() - r := New(t.Logf, nil) + r := newResolver(t) defer r.Close() cfg := dnsCfg @@ -531,7 +539,7 @@ func TestDelegateCollision(t *testing.T) { "test.site.", resolveToIP(testipv4, testipv6, "dns.test.site.")) defer server.Shutdown() - r := New(t.Logf, nil) + r := newResolver(t) defer r.Close() cfg := dnsCfg @@ -745,7 +753,7 @@ var emptyResponse = []byte{ } func TestFull(t *testing.T) { - r := New(t.Logf, nil) + r := newResolver(t) defer r.Close() r.SetConfig(dnsCfg) @@ -781,7 +789,7 @@ func TestFull(t *testing.T) { } func TestAllocs(t *testing.T) { - r := New(t.Logf, nil) + r := newResolver(t) defer r.Close() r.SetConfig(dnsCfg) @@ -835,7 +843,7 @@ func BenchmarkFull(b *testing.B) { "test.site.", resolveToIP(testipv4, testipv6, "dns.test.site.")) defer server.Shutdown() - r := New(b.Logf, nil) + r := newResolver(b) defer r.Close() cfg := dnsCfg @@ -872,3 +880,58 @@ func TestMarshalResponseFormatError(t *testing.T) { } t.Logf("response: %q", v) } + +func TestForwardLinkSelection(t *testing.T) { + old := initListenConfig + defer func() { initListenConfig = old }() + + configCall := make(chan string, 1) + initListenConfig = func(nc *net.ListenConfig, mon *monitor.Mon, tunName string) error { + select { + case configCall <- tunName: + return nil + default: + t.Error("buffer full") + return errors.New("buffer full") + } + } + + // specialIP is some IP we pretend that our link selector + // routes differently. + specialIP := netaddr.IPv4(1, 2, 3, 4) + + fwd := newForwarder(t.Logf, nil, nil, linkSelFunc(func(ip netaddr.IP) string { + if ip == netaddr.IPv4(1, 2, 3, 4) { + return "special" + } + return "" + })) + + // Test non-special IP. + if got, err := fwd.packetListener(netaddr.IP{}); err != nil { + t.Fatal(err) + } else if got != stdNetPacketListener { + t.Errorf("for IP zero value, didn't get expected packet listener") + } + select { + case v := <-configCall: + t.Errorf("unexpected ListenConfig call, with tunName %q", v) + default: + } + + // Test that our special IP generates a call to initListenConfig. + if got, err := fwd.packetListener(specialIP); err != nil { + t.Fatal(err) + } else if got == stdNetPacketListener { + t.Errorf("special IP returned std packet listener; expected unique one") + } + if v, ok := <-configCall; !ok { + t.Errorf("didn't get ListenConfig call") + } else if v != "special" { + t.Errorf("got tunName %q; want 'special'", v) + } +} + +type linkSelFunc func(ip netaddr.IP) string + +func (f linkSelFunc) PickLink(ip netaddr.IP) string { return f(ip) } diff --git a/net/netns/netns_macios.go b/net/netns/netns_macios.go new file mode 100644 index 000000000..e1baf4a2d --- /dev/null +++ b/net/netns/netns_macios.go @@ -0,0 +1,53 @@ +// Copyright (c) 2021 Tailscale Inc & AUTHORS All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// +build darwin ios + +package netns + +import ( + "errors" + "log" + "net" + "strings" + "syscall" + + "golang.org/x/sys/unix" +) + +// SetListenConfigInterfaceIndex sets lc.Control such that sockets are bound +// to the provided interface index. +func SetListenConfigInterfaceIndex(lc *net.ListenConfig, ifIndex int) error { + if lc == nil { + return errors.New("nil ListenConfig") + } + if lc.Control != nil { + return errors.New("ListenConfig.Control already set") + } + lc.Control = func(network, address string, c syscall.RawConn) error { + var sockErr error + err := c.Control(func(fd uintptr) { + sockErr = bindInterface(fd, network, address, ifIndex) + if sockErr != nil { + log.Printf("netns: bind(%q, %q) on index %v: %v", network, address, ifIndex, sockErr) + } + }) + if err != nil { + return err + } + return sockErr + } + return nil +} + +func bindInterface(fd uintptr, network, address string, ifIndex int) error { + v6 := strings.Contains(address, "]:") || strings.HasSuffix(network, "6") // hacky test for v6 + proto := unix.IPPROTO_IP + opt := unix.IP_BOUND_IF + if v6 { + proto = unix.IPPROTO_IPV6 + opt = unix.IPV6_BOUND_IF + } + return unix.SetsockoptInt(int(fd), proto, opt, ifIndex) +} diff --git a/tstest/resource.go b/tstest/resource.go index 47828dea8..5b770986d 100644 --- a/tstest/resource.go +++ b/tstest/resource.go @@ -15,19 +15,19 @@ import ( ) func ResourceCheck(tb testing.TB) { + tb.Helper() startN, startStacks := goroutines() tb.Cleanup(func() { if tb.Failed() { // Something else went wrong. return } - tb.Helper() // Goroutines might be still exiting. for i := 0; i < 100; i++ { if runtime.NumGoroutine() <= startN { return } - time.Sleep(1 * time.Millisecond) + time.Sleep(5 * time.Millisecond) } endN, endStacks := goroutines() tb.Logf("goroutine diff:\n%v\n", cmp.Diff(startStacks, endStacks)) diff --git a/wgengine/userspace.go b/wgengine/userspace.go index 493ae4a1d..a1209761f 100644 --- a/wgengine/userspace.go +++ b/wgengine/userspace.go @@ -101,6 +101,10 @@ type userspaceEngine struct { // incorrectly sent to us. isLocalAddr atomic.Value // of func(netaddr.IP)bool + // isDNSIPOverTailscale reports the whether a DNS resolver's IP + // is being routed over Tailscale. + isDNSIPOverTailscale atomic.Value // of func(netaddr.IP)bool + wgLock sync.Mutex // serializes all wgdev operations; see lock order comment below lastCfgFull wgcfg.Config lastRouterSig string // of router.Config @@ -242,6 +246,7 @@ func NewUserspaceEngine(logf logger.Logf, conf Config) (_ Engine, reterr error) confListenPort: conf.ListenPort, } e.isLocalAddr.Store(tsaddr.NewContainsIPFunc(nil)) + e.isDNSIPOverTailscale.Store(tsaddr.NewContainsIPFunc(nil)) if conf.LinkMonitor != nil { e.linkMon = conf.LinkMonitor @@ -255,7 +260,8 @@ func NewUserspaceEngine(logf logger.Logf, conf Config) (_ Engine, reterr error) e.linkMonOwned = true } - e.dns = dns.NewManager(logf, conf.DNS, e.linkMon) + tunName, _ := conf.Tun.Name() + e.dns = dns.NewManager(logf, conf.DNS, e.linkMon, fwdDNSLinkSelector{e, tunName}) logf("link state: %+v", e.linkMon.InterfaceState()) @@ -767,6 +773,13 @@ func (e *userspaceEngine) Reconfig(cfg *wgcfg.Config, routerCfg *router.Config, return ErrNoChanges } + // TODO(bradfitz,danderson): maybe delete this isDNSIPOverTailscale + // field and delete the resolver.ForwardLinkSelector hook and + // instead have ipnlocal populate a map of DNS IP => linkName and + // put that in the *dns.Config instead, and plumb it down to the + // dns.Manager. Maybe also with isLocalAddr above. + e.isDNSIPOverTailscale.Store(tsaddr.NewContainsIPFunc(dnsIPsOverTailscale(dnsCfg, routerCfg))) + // See if any peers have changed disco keys, which means they've restarted. // If so, we need to update the wireguard-go/device.Device in two phases: // once without the node which has restarted, to clear its wireguard session key, @@ -1362,3 +1375,49 @@ func (p closeOnErrorPool) closeAllIfError(errp *error) { } } } + +// ipInPrefixes reports whether ip is in any of pp. +func ipInPrefixes(ip netaddr.IP, pp []netaddr.IPPrefix) bool { + for _, p := range pp { + if p.Contains(ip) { + return true + } + } + return false +} + +// dnsIPsOverTailscale returns the IPPrefixes of DNS resolver IPs that are +// routed over Tailscale. The returned value does not contain duplicates is +// not necessarily sorted. +func dnsIPsOverTailscale(dnsCfg *dns.Config, routerCfg *router.Config) (ret []netaddr.IPPrefix) { + m := map[netaddr.IP]bool{} + + for _, resolvers := range dnsCfg.Routes { + for _, resolver := range resolvers { + ip := resolver.IP() + if ipInPrefixes(ip, routerCfg.Routes) && !ipInPrefixes(ip, routerCfg.LocalRoutes) { + m[ip] = true + } + } + } + + ret = make([]netaddr.IPPrefix, 0, len(m)) + for ip := range m { + ret = append(ret, netaddr.IPPrefixFrom(ip, ip.BitLen())) + } + return ret +} + +// fwdDNSLinkSelector is userspaceEngine's resolver.ForwardLinkSelector, to pick +// which network interface to send DNS queries out of. +type fwdDNSLinkSelector struct { + ue *userspaceEngine + tunName string +} + +func (ls fwdDNSLinkSelector) PickLink(ip netaddr.IP) (linkName string) { + if ls.ue.isDNSIPOverTailscale.Load().(func(netaddr.IP) bool)(ip) { + return ls.tunName + } + return "" +}