mirror of
https://github.com/tailscale/tailscale.git
synced 2025-01-10 10:03:43 +00:00
c2144c44a3
We currently have two executions paths where (*forwarder).forwardWithDestChan returns nil, rather than an error, without sending a DNS response to responseChan. These paths are accompanied by a comment that reads: // Returning an error will cause an internal retry, there is // nothing we can do if parsing failed. Just drop the packet. But it is not (or no longer longer) accurate: returning an error from forwardWithDestChan does not currently cause a retry. Moreover, although these paths are currently unreachable due to implementation details, if (*forwarder).forwardWithDestChan were to return nil without sending a response to responseChan, it would cause a deadlock at one call site and a panic at another. Therefore, we update (*forwarder).forwardWithDestChan to return errors in those two paths and remove comments that were no longer accurate and misleading. Updates #cleanup Updates #13571 Signed-off-by: Nick Hill <mykola.khyl@gmail.com>
1196 lines
35 KiB
Go
1196 lines
35 KiB
Go
// Copyright (c) Tailscale Inc & AUTHORS
|
|
// SPDX-License-Identifier: BSD-3-Clause
|
|
|
|
package resolver
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"crypto/sha256"
|
|
"crypto/tls"
|
|
"encoding/base64"
|
|
"encoding/binary"
|
|
"errors"
|
|
"fmt"
|
|
"io"
|
|
"net"
|
|
"net/http"
|
|
"net/netip"
|
|
"net/url"
|
|
"sort"
|
|
"strings"
|
|
"sync"
|
|
"sync/atomic"
|
|
"time"
|
|
|
|
dns "golang.org/x/net/dns/dnsmessage"
|
|
"tailscale.com/control/controlknobs"
|
|
"tailscale.com/envknob"
|
|
"tailscale.com/health"
|
|
"tailscale.com/net/dns/publicdns"
|
|
"tailscale.com/net/dnscache"
|
|
"tailscale.com/net/neterror"
|
|
"tailscale.com/net/netmon"
|
|
"tailscale.com/net/sockstats"
|
|
"tailscale.com/net/tsdial"
|
|
"tailscale.com/types/dnstype"
|
|
"tailscale.com/types/logger"
|
|
"tailscale.com/types/nettype"
|
|
"tailscale.com/util/cloudenv"
|
|
"tailscale.com/util/dnsname"
|
|
"tailscale.com/util/race"
|
|
"tailscale.com/version"
|
|
)
|
|
|
|
// headerBytes is the number of bytes in a DNS message header.
|
|
const headerBytes = 12
|
|
|
|
// dnsFlagTruncated is set in the flags word when the packet is truncated.
|
|
const dnsFlagTruncated = 0x200
|
|
|
|
// truncatedFlagSet returns true if the DNS packet signals that it has
|
|
// been truncated. False is also returned if the packet was too small
|
|
// to be valid.
|
|
func truncatedFlagSet(pkt []byte) bool {
|
|
if len(pkt) < headerBytes {
|
|
return false
|
|
}
|
|
return (binary.BigEndian.Uint16(pkt[2:4]) & dnsFlagTruncated) != 0
|
|
}
|
|
|
|
const (
|
|
// dohIdleConnTimeout is how long to keep idle HTTP connections
|
|
// open to DNS-over-HTTPS servers. 10 seconds is a sensible
|
|
// default, as it's long enough to handle a burst of queries
|
|
// coming in a row, but short enough to not keep idle connections
|
|
// open for too long. In theory, idle connections could be kept
|
|
// open for a long time without any battery impact as no traffic
|
|
// is supposed to be flowing on them.
|
|
// However, in practice, DoH servers will send TCP keepalives (e.g.
|
|
// NextDNS sends them every ~10s). Handling these keepalives
|
|
// wakes up the modem, and that uses battery. Therefore, we keep
|
|
// the idle timeout low enough to allow idle connections to be
|
|
// closed during an extended period with no DNS queries, killing
|
|
// keepalive network activity.
|
|
dohIdleConnTimeout = 10 * time.Second
|
|
|
|
// dohTransportTimeout is how much of a head start to give a DoH query
|
|
// that was upgraded from a well-known public DNS provider's IP before
|
|
// normal UDP mode is attempted as a fallback.
|
|
dohHeadStart = 500 * time.Millisecond
|
|
|
|
// wellKnownHostBackupDelay is how long to artificially delay upstream
|
|
// DNS queries to the "fallback" DNS server IP for a known provider
|
|
// (e.g. how long to wait to query Google's 8.8.4.4 after 8.8.8.8).
|
|
wellKnownHostBackupDelay = 200 * time.Millisecond
|
|
|
|
// udpRaceTimeout is the timeout after which we will start a DNS query
|
|
// over TCP while waiting for the UDP query to complete.
|
|
udpRaceTimeout = 2 * time.Second
|
|
|
|
// tcpQueryTimeout is the timeout for a DNS query performed over TCP.
|
|
// It matches the default 5sec timeout of the 'dig' utility.
|
|
tcpQueryTimeout = 5 * time.Second
|
|
)
|
|
|
|
// txid identifies a DNS transaction.
|
|
//
|
|
// As the standard DNS Request ID is only 16 bits, we extend it:
|
|
// the lower 32 bits are the zero-extended bits of the DNS Request ID;
|
|
// the upper 32 bits are the CRC32 checksum of the first question in the request.
|
|
// This makes probability of txid collision negligible.
|
|
type txid uint64
|
|
|
|
// getTxID computes the txid of the given DNS packet.
|
|
func getTxID(packet []byte) txid {
|
|
if len(packet) < headerBytes {
|
|
return 0
|
|
}
|
|
|
|
dnsid := binary.BigEndian.Uint16(packet[0:2])
|
|
// Previously, we hashed the question and combined it with the original txid
|
|
// which was useful when concurrent queries were multiplexed on a single
|
|
// local source port. We encountered some situations where the DNS server
|
|
// canonicalizes the question in the response (uppercase converted to
|
|
// lowercase in this case), which resulted in responses that we couldn't
|
|
// match to the original request due to hash mismatches.
|
|
return txid(dnsid)
|
|
}
|
|
|
|
func getRCode(packet []byte) dns.RCode {
|
|
if len(packet) < headerBytes {
|
|
// treat invalid packets as a refusal
|
|
return dns.RCode(5)
|
|
}
|
|
// get bottom 4 bits of 3rd byte
|
|
return dns.RCode(packet[3] & 0x0F)
|
|
}
|
|
|
|
// clampEDNSSize attempts to limit the maximum EDNS response size. This is not
|
|
// an exhaustive solution, instead only easy cases are currently handled in the
|
|
// interest of speed and reduced complexity. Only OPT records at the very end of
|
|
// the message with no option codes are addressed.
|
|
// TODO: handle more situations if we discover that they happen often
|
|
func clampEDNSSize(packet []byte, maxSize uint16) {
|
|
// optFixedBytes is the size of an OPT record with no option codes.
|
|
const optFixedBytes = 11
|
|
const edns0Version = 0
|
|
|
|
if len(packet) < headerBytes+optFixedBytes {
|
|
return
|
|
}
|
|
|
|
arCount := binary.BigEndian.Uint16(packet[10:12])
|
|
if arCount == 0 {
|
|
// OPT shows up in an AR, so there must be no OPT
|
|
return
|
|
}
|
|
|
|
// https://datatracker.ietf.org/doc/html/rfc6891#section-6.1.2
|
|
opt := packet[len(packet)-optFixedBytes:]
|
|
|
|
if opt[0] != 0 {
|
|
// OPT NAME must be 0 (root domain)
|
|
return
|
|
}
|
|
if dns.Type(binary.BigEndian.Uint16(opt[1:3])) != dns.TypeOPT {
|
|
// Not an OPT record
|
|
return
|
|
}
|
|
requestedSize := binary.BigEndian.Uint16(opt[3:5])
|
|
// Ignore extended RCODE in opt[5]
|
|
if opt[6] != edns0Version {
|
|
// Be conservative and don't touch unknown versions.
|
|
return
|
|
}
|
|
// Ignore flags in opt[6:9]
|
|
if binary.BigEndian.Uint16(opt[9:11]) != 0 {
|
|
// RDLEN must be 0 (no variable length data). We're at the end of the
|
|
// packet so this should be 0 anyway)..
|
|
return
|
|
}
|
|
|
|
if requestedSize <= maxSize {
|
|
return
|
|
}
|
|
|
|
// Clamp the maximum size
|
|
binary.BigEndian.PutUint16(opt[3:5], maxSize)
|
|
}
|
|
|
|
// dnsForwarderFailing should be raised when the forwarder is unable to reach the
|
|
// upstream resolvers. This is a high severity warning as it results in "no internet".
|
|
// This warning must be cleared when the forwarder is working again.
|
|
//
|
|
// We allow for 5 second grace period to ensure this is not raised for spurious errors
|
|
// under the assumption that DNS queries are relatively frequent and a subsequent
|
|
// successful query will clear any one-off errors.
|
|
var dnsForwarderFailing = health.Register(&health.Warnable{
|
|
Code: "dns-forward-failing",
|
|
Title: "DNS unavailable",
|
|
Severity: health.SeverityMedium,
|
|
DependsOn: []*health.Warnable{health.NetworkStatusWarnable},
|
|
Text: health.StaticMessage("Tailscale can't reach the configured DNS servers. Internet connectivity may be affected."),
|
|
ImpactsConnectivity: true,
|
|
TimeToVisible: 15 * time.Second,
|
|
})
|
|
|
|
type route struct {
|
|
Suffix dnsname.FQDN
|
|
Resolvers []resolverAndDelay
|
|
}
|
|
|
|
// resolverAndDelay is an upstream DNS resolver and a delay for how
|
|
// long to wait before querying it.
|
|
type resolverAndDelay struct {
|
|
// name is the upstream resolver.
|
|
name *dnstype.Resolver
|
|
|
|
// startDelay is an amount to delay this resolver at
|
|
// start. It's used when, say, there are four Google or
|
|
// Cloudflare DNS IPs (two IPv4 + two IPv6) and we don't want
|
|
// to race all four at once.
|
|
startDelay time.Duration
|
|
}
|
|
|
|
// forwarder forwards DNS packets to a number of upstream nameservers.
|
|
type forwarder struct {
|
|
logf logger.Logf
|
|
netMon *netmon.Monitor // always non-nil
|
|
linkSel ForwardLinkSelector // TODO(bradfitz): remove this when tsdial.Dialer absorbs it
|
|
dialer *tsdial.Dialer
|
|
health *health.Tracker // always non-nil
|
|
|
|
controlKnobs *controlknobs.Knobs // or nil
|
|
|
|
ctx context.Context // good until Close
|
|
ctxCancel context.CancelFunc // closes ctx
|
|
|
|
mu sync.Mutex // guards following
|
|
|
|
dohClient map[string]*http.Client // urlBase -> client
|
|
|
|
// routes are per-suffix resolvers to use, with
|
|
// the most specific routes first.
|
|
routes []route
|
|
// cloudHostFallback are last resort resolvers to use if no per-suffix
|
|
// resolver matches. These are only populated on cloud hosts where the
|
|
// platform provides a well-known recursive resolver.
|
|
//
|
|
// That is, if we're running on GCP or AWS where there's always a well-known
|
|
// IP of a recursive resolver, return that rather than having callers return
|
|
// SERVFAIL. This fixes both normal 100.100.100.100 resolution when
|
|
// /etc/resolv.conf is missing/corrupt, and the peerapi ExitDNS stub
|
|
// resolver lookup.
|
|
cloudHostFallback []resolverAndDelay
|
|
|
|
// missingUpstreamRecovery, if non-nil, is set called when a SERVFAIL is
|
|
// returned due to missing upstream resolvers.
|
|
//
|
|
// This should attempt to properly (re)set the upstream resolvers.
|
|
missingUpstreamRecovery func()
|
|
}
|
|
|
|
func newForwarder(logf logger.Logf, netMon *netmon.Monitor, linkSel ForwardLinkSelector, dialer *tsdial.Dialer, health *health.Tracker, knobs *controlknobs.Knobs) *forwarder {
|
|
if netMon == nil {
|
|
panic("nil netMon")
|
|
}
|
|
f := &forwarder{
|
|
logf: logger.WithPrefix(logf, "forward: "),
|
|
netMon: netMon,
|
|
linkSel: linkSel,
|
|
dialer: dialer,
|
|
health: health,
|
|
controlKnobs: knobs,
|
|
missingUpstreamRecovery: func() {},
|
|
}
|
|
f.ctx, f.ctxCancel = context.WithCancel(context.Background())
|
|
return f
|
|
}
|
|
|
|
func (f *forwarder) Close() error {
|
|
f.ctxCancel()
|
|
return nil
|
|
}
|
|
|
|
// resolversWithDelays maps from a set of DNS server names to a slice of a type
|
|
// that included a startDelay, upgrading any well-known DoH (DNS-over-HTTP)
|
|
// servers in the process, insert a DoH lookup first before UDP fallbacks.
|
|
func resolversWithDelays(resolvers []*dnstype.Resolver) []resolverAndDelay {
|
|
rr := make([]resolverAndDelay, 0, len(resolvers)+2)
|
|
|
|
type dohState uint8
|
|
const addedDoH = dohState(1)
|
|
const addedDoHAndDontAddUDP = dohState(2)
|
|
|
|
// Add the known DoH ones first, starting immediately.
|
|
didDoH := map[string]dohState{}
|
|
for _, r := range resolvers {
|
|
ipp, ok := r.IPPort()
|
|
if !ok {
|
|
continue
|
|
}
|
|
dohBase, dohOnly, ok := publicdns.DoHEndpointFromIP(ipp.Addr())
|
|
if !ok || didDoH[dohBase] != 0 {
|
|
continue
|
|
}
|
|
if dohOnly {
|
|
didDoH[dohBase] = addedDoHAndDontAddUDP
|
|
} else {
|
|
didDoH[dohBase] = addedDoH
|
|
}
|
|
rr = append(rr, resolverAndDelay{name: &dnstype.Resolver{Addr: dohBase}})
|
|
}
|
|
|
|
type hostAndFam struct {
|
|
host string // some arbitrary string representing DNS host (currently the DoH base)
|
|
bits uint8 // either 32 or 128 for IPv4 vs IPv6s address family
|
|
}
|
|
done := map[hostAndFam]int{}
|
|
for _, r := range resolvers {
|
|
ipp, ok := r.IPPort()
|
|
if !ok {
|
|
// Pass non-IP ones through unchanged, without delay.
|
|
// (e.g. DNS-over-ExitDNS when using an exit node)
|
|
rr = append(rr, resolverAndDelay{name: r})
|
|
continue
|
|
}
|
|
ip := ipp.Addr()
|
|
var startDelay time.Duration
|
|
if host, _, ok := publicdns.DoHEndpointFromIP(ip); ok {
|
|
if didDoH[host] == addedDoHAndDontAddUDP {
|
|
continue
|
|
}
|
|
// We already did the DoH query early. These
|
|
// are for normal dns53 UDP queries.
|
|
startDelay = dohHeadStart
|
|
key := hostAndFam{host, uint8(ip.BitLen())}
|
|
if done[key] > 0 {
|
|
startDelay += wellKnownHostBackupDelay
|
|
}
|
|
done[key]++
|
|
}
|
|
rr = append(rr, resolverAndDelay{
|
|
name: r,
|
|
startDelay: startDelay,
|
|
})
|
|
}
|
|
return rr
|
|
}
|
|
|
|
var (
|
|
cloudResolversOnce sync.Once
|
|
cloudResolversLazy []resolverAndDelay
|
|
)
|
|
|
|
func cloudResolvers() []resolverAndDelay {
|
|
cloudResolversOnce.Do(func() {
|
|
if ip := cloudenv.Get().ResolverIP(); ip != "" {
|
|
cloudResolver := []*dnstype.Resolver{{Addr: ip}}
|
|
cloudResolversLazy = resolversWithDelays(cloudResolver)
|
|
}
|
|
})
|
|
return cloudResolversLazy
|
|
}
|
|
|
|
// setRoutes sets the routes to use for DNS forwarding. It's called by
|
|
// Resolver.SetConfig on reconfig.
|
|
//
|
|
// The memory referenced by routesBySuffix should not be modified.
|
|
func (f *forwarder) setRoutes(routesBySuffix map[dnsname.FQDN][]*dnstype.Resolver) {
|
|
routes := make([]route, 0, len(routesBySuffix))
|
|
|
|
cloudHostFallback := cloudResolvers()
|
|
for suffix, rs := range routesBySuffix {
|
|
if suffix == "." && len(rs) == 0 && len(cloudHostFallback) > 0 {
|
|
routes = append(routes, route{
|
|
Suffix: suffix,
|
|
Resolvers: cloudHostFallback,
|
|
})
|
|
} else {
|
|
routes = append(routes, route{
|
|
Suffix: suffix,
|
|
Resolvers: resolversWithDelays(rs),
|
|
})
|
|
}
|
|
}
|
|
|
|
if cloudenv.Get().HasInternalTLD() && len(cloudHostFallback) > 0 {
|
|
if _, ok := routesBySuffix["internal."]; !ok {
|
|
routes = append(routes, route{
|
|
Suffix: "internal.",
|
|
Resolvers: cloudHostFallback,
|
|
})
|
|
}
|
|
}
|
|
|
|
// Sort from longest prefix to shortest.
|
|
sort.Slice(routes, func(i, j int) bool {
|
|
return routes[i].Suffix.NumLabels() > routes[j].Suffix.NumLabels()
|
|
})
|
|
|
|
f.mu.Lock()
|
|
defer f.mu.Unlock()
|
|
f.routes = routes
|
|
f.cloudHostFallback = cloudHostFallback
|
|
}
|
|
|
|
var stdNetPacketListener nettype.PacketListenerWithNetIP = nettype.MakePacketListenerWithNetIP(new(net.ListenConfig))
|
|
|
|
func (f *forwarder) packetListener(ip netip.Addr) (nettype.PacketListenerWithNetIP, error) {
|
|
if f.linkSel == nil || initListenConfig == nil {
|
|
return stdNetPacketListener, nil
|
|
}
|
|
linkName := f.linkSel.PickLink(ip)
|
|
if linkName == "" {
|
|
return stdNetPacketListener, nil
|
|
}
|
|
lc := new(net.ListenConfig)
|
|
if err := initListenConfig(lc, f.netMon, linkName); err != nil {
|
|
return nil, err
|
|
}
|
|
return nettype.MakePacketListenerWithNetIP(lc), nil
|
|
}
|
|
|
|
// getKnownDoHClientForProvider returns an HTTP client for a specific DoH
|
|
// provider named by its DoH base URL (like "https://dns.google/dns-query").
|
|
//
|
|
// The returned client race/Happy Eyeballs dials all IPs for urlBase (usually
|
|
// 4), as statically known by the publicdns package.
|
|
func (f *forwarder) getKnownDoHClientForProvider(urlBase string) (c *http.Client, ok bool) {
|
|
f.mu.Lock()
|
|
defer f.mu.Unlock()
|
|
if c, ok := f.dohClient[urlBase]; ok {
|
|
return c, true
|
|
}
|
|
allIPs := publicdns.DoHIPsOfBase(urlBase)
|
|
if len(allIPs) == 0 {
|
|
return nil, false
|
|
}
|
|
dohURL, err := url.Parse(urlBase)
|
|
if err != nil {
|
|
return nil, false
|
|
}
|
|
|
|
dialer := dnscache.Dialer(f.getDialerType(), &dnscache.Resolver{
|
|
SingleHost: dohURL.Hostname(),
|
|
SingleHostStaticResult: allIPs,
|
|
Logf: f.logf,
|
|
})
|
|
tlsConfig := &tls.Config{
|
|
// Enforce TLS 1.3, as all of our supported DNS-over-HTTPS servers are compatible with it
|
|
// (see tailscale.com/net/dns/publicdns/publicdns.go).
|
|
MinVersion: tls.VersionTLS13,
|
|
}
|
|
c = &http.Client{
|
|
Transport: &http.Transport{
|
|
ForceAttemptHTTP2: true,
|
|
IdleConnTimeout: dohIdleConnTimeout,
|
|
// On mobile platforms TCP KeepAlive is disabled in the dialer,
|
|
// ensure that we timeout if the connection appears to be hung.
|
|
ResponseHeaderTimeout: 10 * time.Second,
|
|
MaxIdleConnsPerHost: 1,
|
|
DialContext: func(ctx context.Context, netw, addr string) (net.Conn, error) {
|
|
if !strings.HasPrefix(netw, "tcp") {
|
|
return nil, fmt.Errorf("unexpected network %q", netw)
|
|
}
|
|
return dialer(ctx, netw, addr)
|
|
},
|
|
TLSClientConfig: tlsConfig,
|
|
},
|
|
}
|
|
if f.dohClient == nil {
|
|
f.dohClient = map[string]*http.Client{}
|
|
}
|
|
f.dohClient[urlBase] = c
|
|
return c, true
|
|
}
|
|
|
|
const dohType = "application/dns-message"
|
|
|
|
func (f *forwarder) sendDoH(ctx context.Context, urlBase string, c *http.Client, packet []byte) ([]byte, error) {
|
|
ctx = sockstats.WithSockStats(ctx, sockstats.LabelDNSForwarderDoH, f.logf)
|
|
metricDNSFwdDoH.Add(1)
|
|
req, err := http.NewRequestWithContext(ctx, "POST", urlBase, bytes.NewReader(packet))
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
req.Header.Set("Content-Type", dohType)
|
|
req.Header.Set("Accept", dohType)
|
|
req.Header.Set("User-Agent", "tailscaled/"+version.Long())
|
|
|
|
hres, err := c.Do(req)
|
|
if err != nil {
|
|
metricDNSFwdDoHErrorTransport.Add(1)
|
|
return nil, err
|
|
}
|
|
defer hres.Body.Close()
|
|
if hres.StatusCode != 200 {
|
|
metricDNSFwdDoHErrorStatus.Add(1)
|
|
if hres.StatusCode/100 == 5 {
|
|
// Translate 5xx HTTP server errors into SERVFAIL DNS responses.
|
|
return nil, fmt.Errorf("%w: %s", errServerFailure, hres.Status)
|
|
}
|
|
return nil, errors.New(hres.Status)
|
|
}
|
|
if ct := hres.Header.Get("Content-Type"); ct != dohType {
|
|
metricDNSFwdDoHErrorCT.Add(1)
|
|
return nil, fmt.Errorf("unexpected response Content-Type %q", ct)
|
|
}
|
|
res, err := io.ReadAll(hres.Body)
|
|
if err != nil {
|
|
metricDNSFwdDoHErrorBody.Add(1)
|
|
}
|
|
if truncatedFlagSet(res) {
|
|
metricDNSFwdTruncated.Add(1)
|
|
}
|
|
return res, err
|
|
}
|
|
|
|
var (
|
|
verboseDNSForward = envknob.RegisterBool("TS_DEBUG_DNS_FORWARD_SEND")
|
|
skipTCPRetry = envknob.RegisterBool("TS_DNS_FORWARD_SKIP_TCP_RETRY")
|
|
|
|
// For correlating log messages in the send() function; only used when
|
|
// verboseDNSForward() is true.
|
|
forwarderCount atomic.Uint64
|
|
)
|
|
|
|
// send sends packet to dst. It is best effort.
|
|
//
|
|
// send expects the reply to have the same txid as txidOut.
|
|
func (f *forwarder) send(ctx context.Context, fq *forwardQuery, rr resolverAndDelay) (ret []byte, err error) {
|
|
if verboseDNSForward() {
|
|
id := forwarderCount.Add(1)
|
|
domain, typ, _ := nameFromQuery(fq.packet)
|
|
f.logf("forwarder.send(%q, %d, %v, %d) [%d] ...", rr.name.Addr, fq.txid, typ, len(domain), id)
|
|
defer func() {
|
|
f.logf("forwarder.send(%q, %d, %v, %d) [%d] = %v, %v", rr.name.Addr, fq.txid, typ, len(domain), id, len(ret), err)
|
|
}()
|
|
}
|
|
if strings.HasPrefix(rr.name.Addr, "http://") {
|
|
return f.sendDoH(ctx, rr.name.Addr, f.dialer.PeerAPIHTTPClient(), fq.packet)
|
|
}
|
|
if strings.HasPrefix(rr.name.Addr, "https://") {
|
|
// Only known DoH providers are supported currently. Specifically, we
|
|
// only support DoH providers where we can TCP connect to them on port
|
|
// 443 at the same IP address they serve normal UDP DNS from (1.1.1.1,
|
|
// 8.8.8.8, 9.9.9.9, etc.) That's why OpenDNS and custom DoH providers
|
|
// aren't currently supported. There's no backup DNS resolution path for
|
|
// them.
|
|
urlBase := rr.name.Addr
|
|
if hc, ok := f.getKnownDoHClientForProvider(urlBase); ok {
|
|
return f.sendDoH(ctx, urlBase, hc, fq.packet)
|
|
}
|
|
metricDNSFwdErrorType.Add(1)
|
|
return nil, fmt.Errorf("arbitrary https:// resolvers not supported yet")
|
|
}
|
|
if strings.HasPrefix(rr.name.Addr, "tls://") {
|
|
metricDNSFwdErrorType.Add(1)
|
|
return nil, fmt.Errorf("tls:// resolvers not supported yet")
|
|
}
|
|
|
|
ctx, cancel := context.WithCancel(ctx)
|
|
defer cancel()
|
|
|
|
isUDPQuery := fq.family == "udp"
|
|
skipTCP := skipTCPRetry() || (f.controlKnobs != nil && f.controlKnobs.DisableDNSForwarderTCPRetries.Load())
|
|
|
|
// Print logs about retries if this was because of a truncated response.
|
|
var explicitRetry atomic.Bool // true if truncated UDP response retried
|
|
defer func() {
|
|
if !explicitRetry.Load() {
|
|
return
|
|
}
|
|
if err == nil {
|
|
f.logf("forwarder.send(%q): successfully retried via TCP", rr.name.Addr)
|
|
} else {
|
|
f.logf("forwarder.send(%q): could not retry via TCP: %v", rr.name.Addr, err)
|
|
}
|
|
}()
|
|
|
|
firstUDP := func(ctx context.Context) ([]byte, error) {
|
|
resp, err := f.sendUDP(ctx, fq, rr)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
if !truncatedFlagSet(resp) {
|
|
// Successful, non-truncated response; no retry.
|
|
return resp, nil
|
|
}
|
|
|
|
// If this is a UDP query, return it regardless of whether the
|
|
// response is truncated or not; the client can retry
|
|
// communicating with tailscaled over TCP. There's no point
|
|
// falling back to TCP for a truncated query if we can't return
|
|
// the results to the client.
|
|
if isUDPQuery {
|
|
return resp, nil
|
|
}
|
|
|
|
if skipTCP {
|
|
// Envknob or control knob disabled the TCP retry behaviour;
|
|
// just return what we have.
|
|
return resp, nil
|
|
}
|
|
|
|
// This is a TCP query from the client, and the UDP response
|
|
// from the upstream DNS server is truncated; map this to an
|
|
// error to cause our retry helper to immediately kick off the
|
|
// TCP retry.
|
|
explicitRetry.Store(true)
|
|
return nil, truncatedResponseError{resp}
|
|
}
|
|
thenTCP := func(ctx context.Context) ([]byte, error) {
|
|
// If we're skipping the TCP fallback, then wait until the
|
|
// context is canceled and return that error (i.e. not
|
|
// returning anything).
|
|
if skipTCP {
|
|
<-ctx.Done()
|
|
return nil, ctx.Err()
|
|
}
|
|
|
|
return f.sendTCP(ctx, fq, rr)
|
|
}
|
|
|
|
// If the input query is TCP, then don't have a timeout between
|
|
// starting UDP and TCP.
|
|
timeout := udpRaceTimeout
|
|
if !isUDPQuery {
|
|
timeout = 0
|
|
}
|
|
|
|
// Kick off the race between the UDP and TCP queries.
|
|
rh := race.New(timeout, firstUDP, thenTCP)
|
|
resp, err := rh.Start(ctx)
|
|
if err == nil {
|
|
return resp, nil
|
|
}
|
|
|
|
// If we got a truncated UDP response, return that instead of an error.
|
|
var trErr truncatedResponseError
|
|
if errors.As(err, &trErr) {
|
|
return trErr.res, nil
|
|
}
|
|
return nil, err
|
|
}
|
|
|
|
type truncatedResponseError struct {
|
|
res []byte
|
|
}
|
|
|
|
func (tr truncatedResponseError) Error() string { return "response truncated" }
|
|
|
|
var errServerFailure = errors.New("response code indicates server issue")
|
|
var errTxIDMismatch = errors.New("txid doesn't match")
|
|
|
|
func (f *forwarder) sendUDP(ctx context.Context, fq *forwardQuery, rr resolverAndDelay) (ret []byte, err error) {
|
|
ipp, ok := rr.name.IPPort()
|
|
if !ok {
|
|
metricDNSFwdErrorType.Add(1)
|
|
return nil, fmt.Errorf("unrecognized resolver type %q", rr.name.Addr)
|
|
}
|
|
metricDNSFwdUDP.Add(1)
|
|
ctx = sockstats.WithSockStats(ctx, sockstats.LabelDNSForwarderUDP, f.logf)
|
|
|
|
ln, err := f.packetListener(ipp.Addr())
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
// Specify the exact UDP family to work around https://github.com/golang/go/issues/52264
|
|
udpFam := "udp4"
|
|
if ipp.Addr().Is6() {
|
|
udpFam = "udp6"
|
|
}
|
|
conn, err := ln.ListenPacket(ctx, udpFam, ":0")
|
|
if err != nil {
|
|
f.logf("ListenPacket failed: %v", err)
|
|
return nil, err
|
|
}
|
|
defer conn.Close()
|
|
|
|
fq.closeOnCtxDone.Add(conn)
|
|
defer fq.closeOnCtxDone.Remove(conn)
|
|
|
|
if _, err := conn.WriteToUDPAddrPort(fq.packet, ipp); err != nil {
|
|
metricDNSFwdUDPErrorWrite.Add(1)
|
|
if err := ctx.Err(); err != nil {
|
|
return nil, err
|
|
}
|
|
return nil, err
|
|
}
|
|
metricDNSFwdUDPWrote.Add(1)
|
|
|
|
// The 1 extra byte is to detect packet truncation.
|
|
out := make([]byte, maxResponseBytes+1)
|
|
n, _, err := conn.ReadFromUDPAddrPort(out)
|
|
if err != nil {
|
|
if err := ctx.Err(); err != nil {
|
|
return nil, err
|
|
}
|
|
if neterror.PacketWasTruncated(err) {
|
|
err = nil
|
|
} else {
|
|
metricDNSFwdUDPErrorRead.Add(1)
|
|
return nil, err
|
|
}
|
|
}
|
|
truncated := n > maxResponseBytes
|
|
if truncated {
|
|
n = maxResponseBytes
|
|
}
|
|
if n < headerBytes {
|
|
f.logf("recv: packet too small (%d bytes)", n)
|
|
}
|
|
out = out[:n]
|
|
txid := getTxID(out)
|
|
if txid != fq.txid {
|
|
metricDNSFwdUDPErrorTxID.Add(1)
|
|
return nil, errTxIDMismatch
|
|
}
|
|
rcode := getRCode(out)
|
|
// don't forward transient errors back to the client when the server fails
|
|
if rcode == dns.RCodeServerFailure {
|
|
f.logf("recv: response code indicating server failure: %d", rcode)
|
|
metricDNSFwdUDPErrorServer.Add(1)
|
|
return nil, errServerFailure
|
|
}
|
|
|
|
if truncated {
|
|
// Set the truncated bit if it wasn't already.
|
|
flags := binary.BigEndian.Uint16(out[2:4])
|
|
flags |= dnsFlagTruncated
|
|
binary.BigEndian.PutUint16(out[2:4], flags)
|
|
|
|
// TODO(#2067): Remove any incomplete records? RFC 1035 section 6.2
|
|
// states that truncation should head drop so that the authority
|
|
// section can be preserved if possible. However, the UDP read with
|
|
// a too-small buffer has already dropped the end, so that's the
|
|
// best we can do.
|
|
}
|
|
|
|
if truncatedFlagSet(out) {
|
|
metricDNSFwdTruncated.Add(1)
|
|
}
|
|
|
|
clampEDNSSize(out, maxResponseBytes)
|
|
metricDNSFwdUDPSuccess.Add(1)
|
|
return out, nil
|
|
}
|
|
|
|
func (f *forwarder) getDialerType() dnscache.DialContextFunc {
|
|
if f.controlKnobs != nil && f.controlKnobs.UserDialUseRoutes.Load() {
|
|
// It is safe to use UserDial as it dials external servers without going through Tailscale
|
|
// and closes connections on interface change in the same way as SystemDial does,
|
|
// thus preventing DNS resolution issues when switching between WiFi and cellular,
|
|
// but can also dial an internal DNS server on the Tailnet or via a subnet router.
|
|
//
|
|
// TODO(nickkhyl): Update tsdial.Dialer to reuse the bart.Table we create in net/tstun.Wrapper
|
|
// to avoid having two bart tables in memory, especially on iOS. Once that's done,
|
|
// we can get rid of the nodeAttr/control knob and always use UserDial for DNS.
|
|
//
|
|
// See https://github.com/tailscale/tailscale/issues/12027.
|
|
return f.dialer.UserDial
|
|
}
|
|
return f.dialer.SystemDial
|
|
}
|
|
|
|
func (f *forwarder) sendTCP(ctx context.Context, fq *forwardQuery, rr resolverAndDelay) (ret []byte, err error) {
|
|
ipp, ok := rr.name.IPPort()
|
|
if !ok {
|
|
metricDNSFwdErrorType.Add(1)
|
|
return nil, fmt.Errorf("unrecognized resolver type %q", rr.name.Addr)
|
|
}
|
|
metricDNSFwdTCP.Add(1)
|
|
ctx = sockstats.WithSockStats(ctx, sockstats.LabelDNSForwarderTCP, f.logf)
|
|
|
|
// Specify the exact family to work around https://github.com/golang/go/issues/52264
|
|
tcpFam := "tcp4"
|
|
if ipp.Addr().Is6() {
|
|
tcpFam = "tcp6"
|
|
}
|
|
|
|
ctx, cancel := context.WithTimeout(ctx, tcpQueryTimeout)
|
|
defer cancel()
|
|
|
|
conn, err := f.getDialerType()(ctx, tcpFam, ipp.String())
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
defer conn.Close()
|
|
|
|
fq.closeOnCtxDone.Add(conn)
|
|
defer fq.closeOnCtxDone.Remove(conn)
|
|
|
|
ctxOrErr := func(err2 error) ([]byte, error) {
|
|
if err := ctx.Err(); err != nil {
|
|
return nil, err
|
|
}
|
|
return nil, err2
|
|
}
|
|
|
|
// Write the query to the server.
|
|
query := make([]byte, len(fq.packet)+2)
|
|
binary.BigEndian.PutUint16(query, uint16(len(fq.packet)))
|
|
copy(query[2:], fq.packet)
|
|
if _, err := conn.Write(query); err != nil {
|
|
metricDNSFwdTCPErrorWrite.Add(1)
|
|
return ctxOrErr(err)
|
|
}
|
|
|
|
metricDNSFwdTCPWrote.Add(1)
|
|
|
|
// Read the header length back from the server
|
|
var length uint16
|
|
if err := binary.Read(conn, binary.BigEndian, &length); err != nil {
|
|
metricDNSFwdTCPErrorRead.Add(1)
|
|
return ctxOrErr(err)
|
|
}
|
|
|
|
// Now read the response
|
|
out := make([]byte, length)
|
|
n, err := io.ReadFull(conn, out)
|
|
if err != nil {
|
|
metricDNSFwdTCPErrorRead.Add(1)
|
|
return ctxOrErr(err)
|
|
}
|
|
|
|
if n < int(length) {
|
|
f.logf("sendTCP: packet too small (%d bytes)", n)
|
|
return nil, io.ErrUnexpectedEOF
|
|
}
|
|
out = out[:n]
|
|
txid := getTxID(out)
|
|
if txid != fq.txid {
|
|
metricDNSFwdTCPErrorTxID.Add(1)
|
|
return nil, errTxIDMismatch
|
|
}
|
|
|
|
rcode := getRCode(out)
|
|
|
|
// don't forward transient errors back to the client when the server fails
|
|
if rcode == dns.RCodeServerFailure {
|
|
f.logf("sendTCP: response code indicating server failure: %d", rcode)
|
|
metricDNSFwdTCPErrorServer.Add(1)
|
|
return nil, errServerFailure
|
|
}
|
|
|
|
// TODO(andrew): do we need to do this?
|
|
//clampEDNSSize(out, maxResponseBytes)
|
|
metricDNSFwdTCPSuccess.Add(1)
|
|
return out, nil
|
|
}
|
|
|
|
// resolvers returns the resolvers to use for domain.
|
|
func (f *forwarder) resolvers(domain dnsname.FQDN) []resolverAndDelay {
|
|
f.mu.Lock()
|
|
routes := f.routes
|
|
cloudHostFallback := f.cloudHostFallback
|
|
f.mu.Unlock()
|
|
for _, route := range routes {
|
|
if route.Suffix == "." || route.Suffix.Contains(domain) {
|
|
return route.Resolvers
|
|
}
|
|
}
|
|
return cloudHostFallback // or nil if no fallback
|
|
}
|
|
|
|
// GetUpstreamResolvers returns the resolvers that would be used to resolve
|
|
// the given FQDN.
|
|
func (f *forwarder) GetUpstreamResolvers(name dnsname.FQDN) []*dnstype.Resolver {
|
|
resolvers := f.resolvers(name)
|
|
upstreamResolvers := make([]*dnstype.Resolver, 0, len(resolvers))
|
|
for _, r := range resolvers {
|
|
upstreamResolvers = append(upstreamResolvers, r.name)
|
|
}
|
|
return upstreamResolvers
|
|
}
|
|
|
|
// forwardQuery is information and state about a forwarded DNS query that's
|
|
// being sent to 1 or more upstreams.
|
|
//
|
|
// In the case of racing against multiple equivalent upstreams
|
|
// (e.g. Google or CloudFlare's 4 DNS IPs: 2 IPv4 + 2 IPv6), this type
|
|
// handles racing them more intelligently than just blasting away 4
|
|
// queries at once.
|
|
type forwardQuery struct {
|
|
txid txid
|
|
packet []byte
|
|
family string // "tcp" or "udp"
|
|
|
|
// closeOnCtxDone lets send register values to Close if the
|
|
// caller's ctx expires. This avoids send from allocating its
|
|
// own waiting goroutine to interrupt the ReadFrom, as memory
|
|
// is tight on iOS and we want the number of pending DNS
|
|
// lookups to be bursty without too much associated
|
|
// goroutine/memory cost.
|
|
closeOnCtxDone *closePool
|
|
|
|
// TODO(bradfitz): add race delay state:
|
|
// mu sync.Mutex
|
|
// ...
|
|
}
|
|
|
|
// forwardWithDestChan forwards the query to all upstream nameservers
|
|
// and waits for the first response.
|
|
//
|
|
// It either sends to responseChan and returns nil, or returns a
|
|
// non-nil error (without sending to the channel).
|
|
//
|
|
// If resolvers is non-empty, it's used explicitly (notably, for exit
|
|
// node DNS proxy queries), otherwise f.resolvers is used.
|
|
func (f *forwarder) forwardWithDestChan(ctx context.Context, query packet, responseChan chan<- packet, resolvers ...resolverAndDelay) error {
|
|
metricDNSFwd.Add(1)
|
|
domain, typ, err := nameFromQuery(query.bs)
|
|
if err != nil {
|
|
metricDNSFwdErrorName.Add(1)
|
|
return err
|
|
}
|
|
|
|
// Guarantee that the ctx we use below is done when this function returns.
|
|
ctx, cancel := context.WithCancel(ctx)
|
|
defer cancel()
|
|
|
|
// Drop DNS service discovery spam, primarily for battery life
|
|
// on mobile. Things like Spotify on iOS generate this traffic,
|
|
// when browsing for LAN devices. But even when filtering this
|
|
// out, playing on Sonos still works.
|
|
if hasRDNSBonjourPrefix(domain) {
|
|
metricDNSFwdDropBonjour.Add(1)
|
|
res, err := nxDomainResponse(query)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
select {
|
|
case <-ctx.Done():
|
|
return fmt.Errorf("waiting to send NXDOMAIN: %w", ctx.Err())
|
|
case responseChan <- res:
|
|
return nil
|
|
}
|
|
}
|
|
|
|
if fl := fwdLogAtomic.Load(); fl != nil {
|
|
fl.addName(string(domain))
|
|
}
|
|
|
|
clampEDNSSize(query.bs, maxResponseBytes)
|
|
|
|
if len(resolvers) == 0 {
|
|
resolvers = f.resolvers(domain)
|
|
if len(resolvers) == 0 {
|
|
metricDNSFwdErrorNoUpstream.Add(1)
|
|
f.health.SetUnhealthy(dnsForwarderFailing, health.Args{health.ArgDNSServers: ""})
|
|
f.logf("no upstream resolvers set, returning SERVFAIL")
|
|
|
|
// Attempt to recompile the DNS configuration
|
|
// If we are being asked to forward queries and we have no
|
|
// nameservers, the network is in a bad state.
|
|
if f.missingUpstreamRecovery != nil {
|
|
f.missingUpstreamRecovery()
|
|
}
|
|
|
|
res, err := servfailResponse(query)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
select {
|
|
case <-ctx.Done():
|
|
return fmt.Errorf("waiting to send SERVFAIL: %w", ctx.Err())
|
|
case responseChan <- res:
|
|
return nil
|
|
}
|
|
} else {
|
|
f.health.SetHealthy(dnsForwarderFailing)
|
|
}
|
|
}
|
|
|
|
fq := &forwardQuery{
|
|
txid: getTxID(query.bs),
|
|
packet: query.bs,
|
|
family: query.family,
|
|
closeOnCtxDone: new(closePool),
|
|
}
|
|
defer fq.closeOnCtxDone.Close()
|
|
|
|
if verboseDNSForward() {
|
|
domainSha256 := sha256.Sum256([]byte(domain))
|
|
domainSig := base64.RawStdEncoding.EncodeToString(domainSha256[:3])
|
|
f.logf("request(%d, %v, %d, %s) %d...", fq.txid, typ, len(domain), domainSig, len(fq.packet))
|
|
}
|
|
|
|
resc := make(chan []byte, 1) // it's fine buffered or not
|
|
errc := make(chan error, 1) // it's fine buffered or not too
|
|
for i := range resolvers {
|
|
go func(rr *resolverAndDelay) {
|
|
if rr.startDelay > 0 {
|
|
timer := time.NewTimer(rr.startDelay)
|
|
select {
|
|
case <-timer.C:
|
|
case <-ctx.Done():
|
|
timer.Stop()
|
|
return
|
|
}
|
|
}
|
|
resb, err := f.send(ctx, fq, *rr)
|
|
if err != nil {
|
|
err = fmt.Errorf("resolving using %q: %w", rr.name.Addr, err)
|
|
select {
|
|
case errc <- err:
|
|
case <-ctx.Done():
|
|
}
|
|
return
|
|
}
|
|
select {
|
|
case resc <- resb:
|
|
case <-ctx.Done():
|
|
}
|
|
}(&resolvers[i])
|
|
}
|
|
|
|
var firstErr error
|
|
var numErr int
|
|
for {
|
|
select {
|
|
case v := <-resc:
|
|
select {
|
|
case <-ctx.Done():
|
|
metricDNSFwdErrorContext.Add(1)
|
|
return fmt.Errorf("waiting to send response: %w", ctx.Err())
|
|
case responseChan <- packet{v, query.family, query.addr}:
|
|
if verboseDNSForward() {
|
|
f.logf("response(%d, %v, %d) = %d, nil", fq.txid, typ, len(domain), len(v))
|
|
}
|
|
metricDNSFwdSuccess.Add(1)
|
|
f.health.SetHealthy(dnsForwarderFailing)
|
|
return nil
|
|
}
|
|
case err := <-errc:
|
|
if firstErr == nil {
|
|
firstErr = err
|
|
}
|
|
numErr++
|
|
if numErr == len(resolvers) {
|
|
if errors.Is(firstErr, errServerFailure) {
|
|
res, err := servfailResponse(query)
|
|
if err != nil {
|
|
f.logf("building servfail response: %v", err)
|
|
return firstErr
|
|
}
|
|
|
|
select {
|
|
case <-ctx.Done():
|
|
metricDNSFwdErrorContext.Add(1)
|
|
metricDNSFwdErrorContextGotError.Add(1)
|
|
var resolverAddrs []string
|
|
for _, rr := range resolvers {
|
|
resolverAddrs = append(resolverAddrs, rr.name.Addr)
|
|
}
|
|
f.health.SetUnhealthy(dnsForwarderFailing, health.Args{health.ArgDNSServers: strings.Join(resolverAddrs, ",")})
|
|
case responseChan <- res:
|
|
if verboseDNSForward() {
|
|
f.logf("forwarder response(%d, %v, %d) = %d, %v", fq.txid, typ, len(domain), len(res.bs), firstErr)
|
|
}
|
|
return nil
|
|
}
|
|
}
|
|
return firstErr
|
|
}
|
|
case <-ctx.Done():
|
|
metricDNSFwdErrorContext.Add(1)
|
|
if firstErr != nil {
|
|
metricDNSFwdErrorContextGotError.Add(1)
|
|
return firstErr
|
|
}
|
|
|
|
// If we haven't got an error or a successful response,
|
|
// include all resolvers in the error message so we can
|
|
// at least see what what servers we're trying to
|
|
// query.
|
|
var resolverAddrs []string
|
|
for _, rr := range resolvers {
|
|
resolverAddrs = append(resolverAddrs, rr.name.Addr)
|
|
}
|
|
f.health.SetUnhealthy(dnsForwarderFailing, health.Args{health.ArgDNSServers: strings.Join(resolverAddrs, ",")})
|
|
return fmt.Errorf("waiting for response or error from %v: %w", resolverAddrs, ctx.Err())
|
|
}
|
|
}
|
|
}
|
|
|
|
var initListenConfig func(_ *net.ListenConfig, _ *netmon.Monitor, tunName string) error
|
|
|
|
// nameFromQuery extracts the normalized query name from bs.
|
|
func nameFromQuery(bs []byte) (dnsname.FQDN, dns.Type, error) {
|
|
var parser dns.Parser
|
|
|
|
hdr, err := parser.Start(bs)
|
|
if err != nil {
|
|
return "", 0, err
|
|
}
|
|
if hdr.Response {
|
|
return "", 0, errNotQuery
|
|
}
|
|
|
|
q, err := parser.Question()
|
|
if err != nil {
|
|
return "", 0, err
|
|
}
|
|
|
|
n := q.Name.Data[:q.Name.Length]
|
|
fqdn, err := dnsname.ToFQDN(rawNameToLower(n))
|
|
if err != nil {
|
|
return "", 0, err
|
|
}
|
|
return fqdn, q.Type, nil
|
|
}
|
|
|
|
// nxDomainResponse returns an NXDomain DNS reply for the provided request.
|
|
func nxDomainResponse(req packet) (res packet, err error) {
|
|
p := dnsParserPool.Get().(*dnsParser)
|
|
defer dnsParserPool.Put(p)
|
|
|
|
if err := p.parseQuery(req.bs); err != nil {
|
|
return packet{}, err
|
|
}
|
|
|
|
h := p.Header
|
|
h.Response = true
|
|
h.RecursionAvailable = h.RecursionDesired
|
|
h.RCode = dns.RCodeNameError
|
|
b := dns.NewBuilder(nil, h)
|
|
// TODO(bradfitz): should we add an SOA record in the Authority
|
|
// section too? (for the nxdomain negative caching TTL)
|
|
// For which zone? Does iOS care?
|
|
b.StartQuestions()
|
|
b.Question(p.Question)
|
|
res.bs, err = b.Finish()
|
|
res.addr = req.addr
|
|
return res, err
|
|
}
|
|
|
|
// servfailResponse returns a SERVFAIL error reply for the provided request.
|
|
func servfailResponse(req packet) (res packet, err error) {
|
|
p := dnsParserPool.Get().(*dnsParser)
|
|
defer dnsParserPool.Put(p)
|
|
|
|
if err := p.parseQuery(req.bs); err != nil {
|
|
return packet{}, err
|
|
}
|
|
|
|
h := p.Header
|
|
h.Response = true
|
|
h.Authoritative = true
|
|
h.RCode = dns.RCodeServerFailure
|
|
b := dns.NewBuilder(nil, h)
|
|
b.StartQuestions()
|
|
b.Question(p.Question)
|
|
res.bs, err = b.Finish()
|
|
res.addr = req.addr
|
|
return res, err
|
|
}
|
|
|
|
// closePool is a dynamic set of io.Closers to close as a group.
|
|
// It's intended to be Closed at most once.
|
|
//
|
|
// The zero value is ready for use.
|
|
type closePool struct {
|
|
mu sync.Mutex
|
|
m map[io.Closer]bool
|
|
closed bool
|
|
}
|
|
|
|
func (p *closePool) Add(c io.Closer) {
|
|
p.mu.Lock()
|
|
defer p.mu.Unlock()
|
|
if p.closed {
|
|
c.Close()
|
|
return
|
|
}
|
|
if p.m == nil {
|
|
p.m = map[io.Closer]bool{}
|
|
}
|
|
p.m[c] = true
|
|
}
|
|
|
|
func (p *closePool) Remove(c io.Closer) {
|
|
p.mu.Lock()
|
|
defer p.mu.Unlock()
|
|
if p.closed {
|
|
return
|
|
}
|
|
delete(p.m, c)
|
|
}
|
|
|
|
func (p *closePool) Close() error {
|
|
p.mu.Lock()
|
|
defer p.mu.Unlock()
|
|
if p.closed {
|
|
return nil
|
|
}
|
|
p.closed = true
|
|
for c := range p.m {
|
|
c.Close()
|
|
}
|
|
return nil
|
|
}
|