// Copyright (c) Tailscale Inc & AUTHORS
// SPDX-License-Identifier: BSD-3-Clause

package resolver

import (
	"bytes"
	"context"
	"crypto/sha256"
	"crypto/tls"
	"encoding/base64"
	"encoding/binary"
	"errors"
	"fmt"
	"io"
	"net"
	"net/http"
	"net/netip"
	"net/url"
	"sort"
	"strings"
	"sync"
	"sync/atomic"
	"time"

	dns "golang.org/x/net/dns/dnsmessage"
	"tailscale.com/control/controlknobs"
	"tailscale.com/envknob"
	"tailscale.com/health"
	"tailscale.com/net/dns/publicdns"
	"tailscale.com/net/dnscache"
	"tailscale.com/net/neterror"
	"tailscale.com/net/netmon"
	"tailscale.com/net/sockstats"
	"tailscale.com/net/tsdial"
	"tailscale.com/types/dnstype"
	"tailscale.com/types/logger"
	"tailscale.com/types/nettype"
	"tailscale.com/util/cloudenv"
	"tailscale.com/util/dnsname"
	"tailscale.com/util/race"
	"tailscale.com/version"
)

// headerBytes is the number of bytes in a DNS message header.
const headerBytes = 12

// dnsFlagTruncated is set in the flags word when the packet is truncated.
const dnsFlagTruncated = 0x200

// truncatedFlagSet returns true if the DNS packet signals that it has
// been truncated. False is also returned if the packet was too small
// to be valid.
func truncatedFlagSet(pkt []byte) bool {
	if len(pkt) < headerBytes {
		return false
	}
	return (binary.BigEndian.Uint16(pkt[2:4]) & dnsFlagTruncated) != 0
}
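
// truncatedFlagSetSketch is a hypothetical illustration (not part of the
// upstream file): build a minimal 12-byte header with only the TC bit set
// in the flags word and confirm that truncatedFlagSet reports it.
func truncatedFlagSetSketch() bool {
	hdr := make([]byte, headerBytes)
	binary.BigEndian.PutUint16(hdr[0:2], 0x1234)           // arbitrary query ID
	binary.BigEndian.PutUint16(hdr[2:4], dnsFlagTruncated) // flags word with TC set
	return truncatedFlagSet(hdr) // true
}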

const (
	// dohIdleConnTimeout is how long to keep idle HTTP connections
	// open to DNS-over-HTTPS servers. 10 seconds is a sensible
	// default, as it's long enough to handle a burst of queries
	// coming in a row, but short enough to not keep idle connections
	// open for too long. In theory, idle connections could be kept
	// open for a long time without any battery impact as no traffic
	// is supposed to be flowing on them.
	// However, in practice, DoH servers will send TCP keepalives (e.g.
	// NextDNS sends them every ~10s). Handling these keepalives
	// wakes up the modem, and that uses battery. Therefore, we keep
	// the idle timeout low enough to allow idle connections to be
	// closed during an extended period with no DNS queries, killing
	// keepalive network activity.
	dohIdleConnTimeout = 10 * time.Second

	// dohHeadStart is how much of a head start to give a DoH query
	// that was upgraded from a well-known public DNS provider's IP before
	// normal UDP mode is attempted as a fallback.
	dohHeadStart = 500 * time.Millisecond

	// wellKnownHostBackupDelay is how long to artificially delay upstream
	// DNS queries to the "fallback" DNS server IP for a known provider
	// (e.g. how long to wait to query Google's 8.8.4.4 after 8.8.8.8).
	wellKnownHostBackupDelay = 200 * time.Millisecond

	// udpRaceTimeout is the timeout after which we will start a DNS query
	// over TCP while waiting for the UDP query to complete.
	udpRaceTimeout = 2 * time.Second

	// tcpQueryTimeout is the timeout for a DNS query performed over TCP.
	// It matches the default 5sec timeout of the 'dig' utility.
	tcpQueryTimeout = 5 * time.Second
)

// txid identifies a DNS transaction.
//
// As the standard DNS Request ID is only 16 bits, we extend it:
// the lower 32 bits are the zero-extended bits of the DNS Request ID;
// the upper 32 bits are the CRC32 checksum of the first question in the request.
// This makes probability of txid collision negligible.
type txid uint64

// getTxID computes the txid of the given DNS packet.
func getTxID(packet []byte) txid {
	if len(packet) < headerBytes {
		return 0
	}

	dnsid := binary.BigEndian.Uint16(packet[0:2])
	// Previously, we hashed the question and combined it with the original txid
	// which was useful when concurrent queries were multiplexed on a single
	// local source port. We encountered some situations where the DNS server
	// canonicalizes the question in the response (uppercase converted to
	// lowercase in this case), which resulted in responses that we couldn't
	// match to the original request due to hash mismatches.
	return txid(dnsid)
}

func getRCode(packet []byte) dns.RCode {
	if len(packet) < headerBytes {
		// treat invalid packets as a refusal
		return dns.RCode(5)
	}
	// get bottom 4 bits of 3rd byte
	return dns.RCode(packet[3] & 0x0F)
}

// clampEDNSSize attempts to limit the maximum EDNS response size. This is not
// an exhaustive solution, instead only easy cases are currently handled in the
// interest of speed and reduced complexity. Only OPT records at the very end of
// the message with no option codes are addressed.
// TODO: handle more situations if we discover that they happen often
func clampEDNSSize(packet []byte, maxSize uint16) {
	// optFixedBytes is the size of an OPT record with no option codes.
	const optFixedBytes = 11
	const edns0Version = 0

	if len(packet) < headerBytes+optFixedBytes {
		return
	}

	arCount := binary.BigEndian.Uint16(packet[10:12])
	if arCount == 0 {
		// OPT shows up in an AR, so there must be no OPT
		return
	}

	// https://datatracker.ietf.org/doc/html/rfc6891#section-6.1.2
	opt := packet[len(packet)-optFixedBytes:]

	if opt[0] != 0 {
		// OPT NAME must be 0 (root domain)
		return
	}
	if dns.Type(binary.BigEndian.Uint16(opt[1:3])) != dns.TypeOPT {
		// Not an OPT record
		return
	}
	requestedSize := binary.BigEndian.Uint16(opt[3:5])
	// Ignore extended RCODE in opt[5]
	if opt[6] != edns0Version {
		// Be conservative and don't touch unknown versions.
		return
	}
	// Ignore flags in opt[7:9]
	if binary.BigEndian.Uint16(opt[9:11]) != 0 {
		// RDLEN must be 0 (no variable-length data). We're at the end of the
		// packet, so this should be 0 anyway.
		return
	}

	if requestedSize <= maxSize {
		return
	}

	// Clamp the maximum size
	binary.BigEndian.PutUint16(opt[3:5], maxSize)
}
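
// clampEDNSSizeSketch is a hypothetical illustration (not part of the
// upstream file) of the one case clampEDNSSize handles: a query whose last
// record is an EDNS0 OPT pseudo-record with no options. It builds such a
// query advertising a 4096-byte UDP payload size and clamps it down to
// maxResponseBytes in place.
func clampEDNSSizeSketch() ([]byte, error) {
	b := dns.NewBuilder(nil, dns.Header{ID: 1, RecursionDesired: true})
	if err := b.StartQuestions(); err != nil {
		return nil, err
	}
	if err := b.Question(dns.Question{
		Name:  dns.MustNewName("example.com."),
		Type:  dns.TypeA,
		Class: dns.ClassINET,
	}); err != nil {
		return nil, err
	}
	if err := b.StartAdditionals(); err != nil {
		return nil, err
	}
	var opt dns.ResourceHeader
	if err := opt.SetEDNS0(4096, dns.RCodeSuccess, false); err != nil {
		return nil, err
	}
	if err := b.OPTResource(opt, dns.OPTResource{}); err != nil {
		return nil, err
	}
	pkt, err := b.Finish()
	if err != nil {
		return nil, err
	}
	clampEDNSSize(pkt, maxResponseBytes) // rewrites the requested size in the trailing OPT
	return pkt, nil
}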

// dnsForwarderFailing should be raised when the forwarder is unable to reach the
// upstream resolvers. This is a high-severity warning as it results in "no internet".
// This warning must be cleared when the forwarder is working again.
//
// We allow for a grace period (TimeToVisible below) to ensure this is not raised for
// spurious errors, under the assumption that DNS queries are relatively frequent and
// a subsequent successful query will clear any one-off errors.
var dnsForwarderFailing = health.Register(&health.Warnable{
	Code:                "dns-forward-failing",
	Title:               "DNS unavailable",
	Severity:            health.SeverityMedium,
	DependsOn:           []*health.Warnable{health.NetworkStatusWarnable},
	Text:                health.StaticMessage("Tailscale can't reach the configured DNS servers. Internet connectivity may be affected."),
	ImpactsConnectivity: true,
	TimeToVisible:       15 * time.Second,
})

type route struct {
	Suffix    dnsname.FQDN
	Resolvers []resolverAndDelay
}

// resolverAndDelay is an upstream DNS resolver and a delay for how
// long to wait before querying it.
type resolverAndDelay struct {
	// name is the upstream resolver.
	name *dnstype.Resolver

	// startDelay is an amount to delay this resolver at
	// start. It's used when, say, there are four Google or
	// Cloudflare DNS IPs (two IPv4 + two IPv6) and we don't want
	// to race all four at once.
	startDelay time.Duration
}

// forwarder forwards DNS packets to a number of upstream nameservers.
type forwarder struct {
	logf    logger.Logf
	netMon  *netmon.Monitor     // always non-nil
	linkSel ForwardLinkSelector // TODO(bradfitz): remove this when tsdial.Dialer absorbs it
	dialer  *tsdial.Dialer
	health  *health.Tracker // always non-nil

	controlKnobs *controlknobs.Knobs // or nil

	ctx       context.Context    // good until Close
	ctxCancel context.CancelFunc // closes ctx

	mu sync.Mutex // guards following

	dohClient map[string]*http.Client // urlBase -> client

	// routes are per-suffix resolvers to use, with
	// the most specific routes first.
	routes []route
	// cloudHostFallback are last resort resolvers to use if no per-suffix
	// resolver matches. These are only populated on cloud hosts where the
	// platform provides a well-known recursive resolver.
	//
	// That is, if we're running on GCP or AWS where there's always a well-known
	// IP of a recursive resolver, return that rather than having callers return
	// SERVFAIL. This fixes both normal 100.100.100.100 resolution when
	// /etc/resolv.conf is missing/corrupt, and the peerapi ExitDNS stub
	// resolver lookup.
	cloudHostFallback []resolverAndDelay

	// missingUpstreamRecovery, if non-nil, is called when a SERVFAIL is
	// returned due to missing upstream resolvers.
	//
	// This should attempt to properly (re)set the upstream resolvers.
	missingUpstreamRecovery func()
}

func newForwarder(logf logger.Logf, netMon *netmon.Monitor, linkSel ForwardLinkSelector, dialer *tsdial.Dialer, health *health.Tracker, knobs *controlknobs.Knobs) *forwarder {
	if netMon == nil {
		panic("nil netMon")
	}
	f := &forwarder{
		logf:                    logger.WithPrefix(logf, "forward: "),
		netMon:                  netMon,
		linkSel:                 linkSel,
		dialer:                  dialer,
		health:                  health,
		controlKnobs:            knobs,
		missingUpstreamRecovery: func() {},
	}
	f.ctx, f.ctxCancel = context.WithCancel(context.Background())
	return f
}

func (f *forwarder) Close() error {
	f.ctxCancel()
	return nil
}

// resolversWithDelays maps from a set of DNS server names to a slice of a type
// that includes a startDelay, upgrading any well-known DoH (DNS-over-HTTPS)
// servers in the process and inserting a DoH lookup before the UDP fallbacks.
func resolversWithDelays(resolvers []*dnstype.Resolver) []resolverAndDelay {
	rr := make([]resolverAndDelay, 0, len(resolvers)+2)

	type dohState uint8
	const addedDoH = dohState(1)
	const addedDoHAndDontAddUDP = dohState(2)

	// Add the known DoH ones first, starting immediately.
	didDoH := map[string]dohState{}
	for _, r := range resolvers {
		ipp, ok := r.IPPort()
		if !ok {
			continue
		}
		dohBase, dohOnly, ok := publicdns.DoHEndpointFromIP(ipp.Addr())
		if !ok || didDoH[dohBase] != 0 {
			continue
		}
		if dohOnly {
			didDoH[dohBase] = addedDoHAndDontAddUDP
		} else {
			didDoH[dohBase] = addedDoH
		}
		rr = append(rr, resolverAndDelay{name: &dnstype.Resolver{Addr: dohBase}})
	}

	type hostAndFam struct {
		host string // some arbitrary string representing the DNS host (currently the DoH base)
		bits uint8  // either 32 or 128, for the IPv4 vs IPv6 address family
	}
	done := map[hostAndFam]int{}
	for _, r := range resolvers {
		ipp, ok := r.IPPort()
		if !ok {
			// Pass non-IP ones through unchanged, without delay.
			// (e.g. DNS-over-ExitDNS when using an exit node)
			rr = append(rr, resolverAndDelay{name: r})
			continue
		}
		ip := ipp.Addr()
		var startDelay time.Duration
		if host, _, ok := publicdns.DoHEndpointFromIP(ip); ok {
			if didDoH[host] == addedDoHAndDontAddUDP {
				continue
			}
			// We already did the DoH query early. These
			// are for normal dns53 UDP queries.
			startDelay = dohHeadStart
			key := hostAndFam{host, uint8(ip.BitLen())}
			if done[key] > 0 {
				startDelay += wellKnownHostBackupDelay
			}
			done[key]++
		}
		rr = append(rr, resolverAndDelay{
			name:       r,
			startDelay: startDelay,
		})
	}
	return rr
}
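
// resolversWithDelaysSketch is a hypothetical illustration (not part of the
// upstream file) of the expansion above for Google's four public dns53
// addresses: one immediate DoH resolver for the base URL that publicdns
// knows for these IPs, followed by the four UDP resolvers with dohHeadStart
// (plus wellKnownHostBackupDelay for the second address of each family) as
// their start delays.
func resolversWithDelaysSketch() []resolverAndDelay {
	google := []*dnstype.Resolver{
		{Addr: "8.8.8.8"},
		{Addr: "8.8.4.4"},
		{Addr: "2001:4860:4860::8888"},
		{Addr: "2001:4860:4860::8844"},
	}
	return resolversWithDelays(google)
}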

var (
	cloudResolversOnce sync.Once
	cloudResolversLazy []resolverAndDelay
)

func cloudResolvers() []resolverAndDelay {
	cloudResolversOnce.Do(func() {
		if ip := cloudenv.Get().ResolverIP(); ip != "" {
			cloudResolver := []*dnstype.Resolver{{Addr: ip}}
			cloudResolversLazy = resolversWithDelays(cloudResolver)
		}
	})
	return cloudResolversLazy
}

// setRoutes sets the routes to use for DNS forwarding. It's called by
// Resolver.SetConfig on reconfig.
//
// The memory referenced by routesBySuffix should not be modified.
func (f *forwarder) setRoutes(routesBySuffix map[dnsname.FQDN][]*dnstype.Resolver) {
	routes := make([]route, 0, len(routesBySuffix))

	cloudHostFallback := cloudResolvers()
	for suffix, rs := range routesBySuffix {
		if suffix == "." && len(rs) == 0 && len(cloudHostFallback) > 0 {
			routes = append(routes, route{
				Suffix:    suffix,
				Resolvers: cloudHostFallback,
			})
		} else {
			routes = append(routes, route{
				Suffix:    suffix,
				Resolvers: resolversWithDelays(rs),
			})
		}
	}

	if cloudenv.Get().HasInternalTLD() && len(cloudHostFallback) > 0 {
		if _, ok := routesBySuffix["internal."]; !ok {
			routes = append(routes, route{
				Suffix:    "internal.",
				Resolvers: cloudHostFallback,
			})
		}
	}

	// Sort from longest prefix to shortest.
	sort.Slice(routes, func(i, j int) bool {
		return routes[i].Suffix.NumLabels() > routes[j].Suffix.NumLabels()
	})

	f.mu.Lock()
	defer f.mu.Unlock()
	f.routes = routes
	f.cloudHostFallback = cloudHostFallback
}
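
// setRoutesSketch is a hypothetical illustration (not part of the upstream
// file) of the input shape setRoutes expects: a catch-all "." route plus a
// more specific split-DNS suffix, both keyed by fully-qualified suffixes
// with trailing dots. After the sort above, the more specific
// "corp.example.com." route is consulted before the catch-all.
func (f *forwarder) setRoutesSketch() {
	f.setRoutes(map[dnsname.FQDN][]*dnstype.Resolver{
		".":                 {{Addr: "1.1.1.1"}},
		"corp.example.com.": {{Addr: "10.0.0.53"}},
	})
}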

var stdNetPacketListener nettype.PacketListenerWithNetIP = nettype.MakePacketListenerWithNetIP(new(net.ListenConfig))

func (f *forwarder) packetListener(ip netip.Addr) (nettype.PacketListenerWithNetIP, error) {
	if f.linkSel == nil || initListenConfig == nil {
		return stdNetPacketListener, nil
	}
	linkName := f.linkSel.PickLink(ip)
	if linkName == "" {
		return stdNetPacketListener, nil
	}
	lc := new(net.ListenConfig)
	if err := initListenConfig(lc, f.netMon, linkName); err != nil {
		return nil, err
	}
	return nettype.MakePacketListenerWithNetIP(lc), nil
}

// getKnownDoHClientForProvider returns an HTTP client for a specific DoH
// provider named by its DoH base URL (like "https://dns.google/dns-query").
//
// The returned client race/Happy Eyeballs dials all IPs for urlBase (usually
// 4), as statically known by the publicdns package.
func (f *forwarder) getKnownDoHClientForProvider(urlBase string) (c *http.Client, ok bool) {
	f.mu.Lock()
	defer f.mu.Unlock()
	if c, ok := f.dohClient[urlBase]; ok {
		return c, true
	}
	allIPs := publicdns.DoHIPsOfBase(urlBase)
	if len(allIPs) == 0 {
		return nil, false
	}
	dohURL, err := url.Parse(urlBase)
	if err != nil {
		return nil, false
	}

	dialer := dnscache.Dialer(f.getDialerType(), &dnscache.Resolver{
		SingleHost:             dohURL.Hostname(),
		SingleHostStaticResult: allIPs,
		Logf:                   f.logf,
	})
	tlsConfig := &tls.Config{
		// Enforce TLS 1.3, as all of our supported DNS-over-HTTPS servers are compatible with it
		// (see tailscale.com/net/dns/publicdns/publicdns.go).
		MinVersion: tls.VersionTLS13,
	}
	c = &http.Client{
		Transport: &http.Transport{
			ForceAttemptHTTP2: true,
			IdleConnTimeout:   dohIdleConnTimeout,
			// On mobile platforms TCP keep-alives are disabled in the dialer;
			// ensure that we time out if the connection appears to be hung.
			ResponseHeaderTimeout: 10 * time.Second,
			MaxIdleConnsPerHost:   1,
			DialContext: func(ctx context.Context, netw, addr string) (net.Conn, error) {
				if !strings.HasPrefix(netw, "tcp") {
					return nil, fmt.Errorf("unexpected network %q", netw)
				}
				return dialer(ctx, netw, addr)
			},
			TLSClientConfig: tlsConfig,
		},
	}
	if f.dohClient == nil {
		f.dohClient = map[string]*http.Client{}
	}
	f.dohClient[urlBase] = c
	return c, true
}
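
// knownDoHQuerySketch is a hypothetical illustration (not part of the
// upstream file) of how the cached client above is meant to be used: look
// up the shared *http.Client for a known DoH base URL and POST the raw DNS
// message through sendDoH below. Because the Transport keeps at most one
// idle connection (MaxIdleConnsPerHost: 1) and closes it after
// dohIdleConnTimeout, repeated queries reuse a single connection rather
// than opening one per query.
func (f *forwarder) knownDoHQuerySketch(ctx context.Context, dnsPacket []byte) ([]byte, error) {
	const urlBase = "https://dns.google/dns-query" // assumed to be a base URL known to publicdns
	hc, ok := f.getKnownDoHClientForProvider(urlBase)
	if !ok {
		return nil, errors.New("not a known DoH provider")
	}
	return f.sendDoH(ctx, urlBase, hc, dnsPacket)
}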

const dohType = "application/dns-message"

func (f *forwarder) sendDoH(ctx context.Context, urlBase string, c *http.Client, packet []byte) ([]byte, error) {
	ctx = sockstats.WithSockStats(ctx, sockstats.LabelDNSForwarderDoH, f.logf)
	metricDNSFwdDoH.Add(1)
	req, err := http.NewRequestWithContext(ctx, "POST", urlBase, bytes.NewReader(packet))
	if err != nil {
		return nil, err
	}
	req.Header.Set("Content-Type", dohType)
	req.Header.Set("Accept", dohType)
	req.Header.Set("User-Agent", "tailscaled/"+version.Long())

	hres, err := c.Do(req)
	if err != nil {
		metricDNSFwdDoHErrorTransport.Add(1)
		return nil, err
	}
	defer hres.Body.Close()
	if hres.StatusCode != 200 {
		metricDNSFwdDoHErrorStatus.Add(1)
		return nil, errors.New(hres.Status)
	}
	if ct := hres.Header.Get("Content-Type"); ct != dohType {
		metricDNSFwdDoHErrorCT.Add(1)
		return nil, fmt.Errorf("unexpected response Content-Type %q", ct)
	}
	res, err := io.ReadAll(hres.Body)
	if err != nil {
		metricDNSFwdDoHErrorBody.Add(1)
	}
	if truncatedFlagSet(res) {
		metricDNSFwdTruncated.Add(1)
	}
	return res, err
}

var (
	verboseDNSForward = envknob.RegisterBool("TS_DEBUG_DNS_FORWARD_SEND")
	skipTCPRetry      = envknob.RegisterBool("TS_DNS_FORWARD_SKIP_TCP_RETRY")

	// For correlating log messages in the send() function; only used when
	// verboseDNSForward() is true.
	forwarderCount atomic.Uint64
)

// send sends the query in fq to the upstream resolver rr. It is best effort.
//
// send expects the reply to have the same txid as the query's.
func (f *forwarder) send(ctx context.Context, fq *forwardQuery, rr resolverAndDelay) (ret []byte, err error) {
	if verboseDNSForward() {
		id := forwarderCount.Add(1)
		domain, typ, _ := nameFromQuery(fq.packet)
		f.logf("forwarder.send(%q, %d, %v, %d) [%d] ...", rr.name.Addr, fq.txid, typ, len(domain), id)
		defer func() {
			f.logf("forwarder.send(%q, %d, %v, %d) [%d] = %v, %v", rr.name.Addr, fq.txid, typ, len(domain), id, len(ret), err)
		}()
	}
	if strings.HasPrefix(rr.name.Addr, "http://") {
		return f.sendDoH(ctx, rr.name.Addr, f.dialer.PeerAPIHTTPClient(), fq.packet)
	}
	if strings.HasPrefix(rr.name.Addr, "https://") {
		// Only known DoH providers are supported currently. Specifically, we
		// only support DoH providers where we can TCP connect to them on port
		// 443 at the same IP address they serve normal UDP DNS from (1.1.1.1,
		// 8.8.8.8, 9.9.9.9, etc.) That's why OpenDNS and custom DoH providers
		// aren't currently supported. There's no backup DNS resolution path for
		// them.
		urlBase := rr.name.Addr
		if hc, ok := f.getKnownDoHClientForProvider(urlBase); ok {
			return f.sendDoH(ctx, urlBase, hc, fq.packet)
		}
		metricDNSFwdErrorType.Add(1)
		return nil, fmt.Errorf("arbitrary https:// resolvers not supported yet")
	}
	if strings.HasPrefix(rr.name.Addr, "tls://") {
		metricDNSFwdErrorType.Add(1)
		return nil, fmt.Errorf("tls:// resolvers not supported yet")
	}

	ctx, cancel := context.WithCancel(ctx)
	defer cancel()

	isUDPQuery := fq.family == "udp"
	skipTCP := skipTCPRetry() || (f.controlKnobs != nil && f.controlKnobs.DisableDNSForwarderTCPRetries.Load())

	// Print logs about retries if this was because of a truncated response.
	var explicitRetry atomic.Bool // true if truncated UDP response retried
	defer func() {
		if !explicitRetry.Load() {
			return
		}
		if err == nil {
			f.logf("forwarder.send(%q): successfully retried via TCP", rr.name.Addr)
		} else {
			f.logf("forwarder.send(%q): could not retry via TCP: %v", rr.name.Addr, err)
		}
	}()

	firstUDP := func(ctx context.Context) ([]byte, error) {
		resp, err := f.sendUDP(ctx, fq, rr)
		if err != nil {
			return nil, err
		}
		if !truncatedFlagSet(resp) {
			// Successful, non-truncated response; no retry.
			return resp, nil
		}

		// If this is a UDP query, return it regardless of whether the
		// response is truncated or not; the client can retry
		// communicating with tailscaled over TCP. There's no point
		// falling back to TCP for a truncated query if we can't return
		// the results to the client.
		if isUDPQuery {
			return resp, nil
		}

		if skipTCP {
			// Envknob or control knob disabled the TCP retry behaviour;
			// just return what we have.
			return resp, nil
		}

		// This is a TCP query from the client, and the UDP response
		// from the upstream DNS server is truncated; map this to an
		// error to cause our retry helper to immediately kick off the
		// TCP retry.
		explicitRetry.Store(true)
		return nil, truncatedResponseError{resp}
	}
	thenTCP := func(ctx context.Context) ([]byte, error) {
		// If we're skipping the TCP fallback, then wait until the
		// context is canceled and return that error (i.e. not
		// returning anything).
		if skipTCP {
			<-ctx.Done()
			return nil, ctx.Err()
		}

		return f.sendTCP(ctx, fq, rr)
	}

	// If the input query is TCP, then don't have a timeout between
	// starting UDP and TCP.
	timeout := udpRaceTimeout
	if !isUDPQuery {
		timeout = 0
	}

	// Kick off the race between the UDP and TCP queries.
	rh := race.New(timeout, firstUDP, thenTCP)
	resp, err := rh.Start(ctx)
	if err == nil {
		return resp, nil
	}

	// If we got a truncated UDP response, return that instead of an error.
	var trErr truncatedResponseError
	if errors.As(err, &trErr) {
		return trErr.res, nil
	}
	return nil, err
}

type truncatedResponseError struct {
	res []byte
}

func (tr truncatedResponseError) Error() string { return "response truncated" }

var errServerFailure = errors.New("response code indicates server issue")
var errTxIDMismatch = errors.New("txid doesn't match")

func (f *forwarder) sendUDP(ctx context.Context, fq *forwardQuery, rr resolverAndDelay) (ret []byte, err error) {
	ipp, ok := rr.name.IPPort()
	if !ok {
		metricDNSFwdErrorType.Add(1)
		return nil, fmt.Errorf("unrecognized resolver type %q", rr.name.Addr)
	}
	metricDNSFwdUDP.Add(1)
	ctx = sockstats.WithSockStats(ctx, sockstats.LabelDNSForwarderUDP, f.logf)

	ln, err := f.packetListener(ipp.Addr())
	if err != nil {
		return nil, err
	}

	// Specify the exact UDP family to work around https://github.com/golang/go/issues/52264
	udpFam := "udp4"
	if ipp.Addr().Is6() {
		udpFam = "udp6"
	}
	conn, err := ln.ListenPacket(ctx, udpFam, ":0")
	if err != nil {
		f.logf("ListenPacket failed: %v", err)
		return nil, err
	}
	defer conn.Close()

	fq.closeOnCtxDone.Add(conn)
	defer fq.closeOnCtxDone.Remove(conn)

	if _, err := conn.WriteToUDPAddrPort(fq.packet, ipp); err != nil {
		metricDNSFwdUDPErrorWrite.Add(1)
		if err := ctx.Err(); err != nil {
			return nil, err
		}
		return nil, err
	}
	metricDNSFwdUDPWrote.Add(1)

	// The 1 extra byte is to detect packet truncation.
	out := make([]byte, maxResponseBytes+1)
	n, _, err := conn.ReadFromUDPAddrPort(out)
	if err != nil {
		if err := ctx.Err(); err != nil {
			return nil, err
		}
		if neterror.PacketWasTruncated(err) {
			err = nil
		} else {
			metricDNSFwdUDPErrorRead.Add(1)
			return nil, err
		}
	}
	truncated := n > maxResponseBytes
	if truncated {
		n = maxResponseBytes
	}
	if n < headerBytes {
		f.logf("recv: packet too small (%d bytes)", n)
	}
	out = out[:n]
	txid := getTxID(out)
	if txid != fq.txid {
		metricDNSFwdUDPErrorTxID.Add(1)
		return nil, errTxIDMismatch
	}
	rcode := getRCode(out)
	// don't forward transient errors back to the client when the server fails
	if rcode == dns.RCodeServerFailure {
		f.logf("recv: response code indicating server failure: %d", rcode)
		metricDNSFwdUDPErrorServer.Add(1)
		return nil, errServerFailure
	}

	if truncated {
		// Set the truncated bit if it wasn't already.
		flags := binary.BigEndian.Uint16(out[2:4])
		flags |= dnsFlagTruncated
		binary.BigEndian.PutUint16(out[2:4], flags)

		// TODO(#2067): Remove any incomplete records? RFC 1035 section 6.2
		// states that truncation should head drop so that the authority
		// section can be preserved if possible. However, the UDP read with
		// a too-small buffer has already dropped the end, so that's the
		// best we can do.
	}

	if truncatedFlagSet(out) {
		metricDNSFwdTruncated.Add(1)
	}

	clampEDNSSize(out, maxResponseBytes)
	metricDNSFwdUDPSuccess.Add(1)
	return out, nil
}

func (f *forwarder) getDialerType() dnscache.DialContextFunc {
	if f.controlKnobs != nil && f.controlKnobs.UserDialUseRoutes.Load() {
		// It is safe to use UserDial as it dials external servers without going through Tailscale
		// and closes connections on interface change in the same way as SystemDial does,
		// thus preventing DNS resolution issues when switching between WiFi and cellular,
		// but can also dial an internal DNS server on the Tailnet or via a subnet router.
		//
		// TODO(nickkhyl): Update tsdial.Dialer to reuse the bart.Table we create in net/tstun.Wrapper
		// to avoid having two bart tables in memory, especially on iOS. Once that's done,
		// we can get rid of the nodeAttr/control knob and always use UserDial for DNS.
		//
		// See https://github.com/tailscale/tailscale/issues/12027.
		return f.dialer.UserDial
	}
	return f.dialer.SystemDial
}

func (f *forwarder) sendTCP(ctx context.Context, fq *forwardQuery, rr resolverAndDelay) (ret []byte, err error) {
	ipp, ok := rr.name.IPPort()
	if !ok {
		metricDNSFwdErrorType.Add(1)
		return nil, fmt.Errorf("unrecognized resolver type %q", rr.name.Addr)
	}
	metricDNSFwdTCP.Add(1)
	ctx = sockstats.WithSockStats(ctx, sockstats.LabelDNSForwarderTCP, f.logf)

	// Specify the exact family to work around https://github.com/golang/go/issues/52264
	tcpFam := "tcp4"
	if ipp.Addr().Is6() {
		tcpFam = "tcp6"
	}

	ctx, cancel := context.WithTimeout(ctx, tcpQueryTimeout)
	defer cancel()

	conn, err := f.getDialerType()(ctx, tcpFam, ipp.String())
	if err != nil {
		return nil, err
	}
	defer conn.Close()

	fq.closeOnCtxDone.Add(conn)
	defer fq.closeOnCtxDone.Remove(conn)

	ctxOrErr := func(err2 error) ([]byte, error) {
		if err := ctx.Err(); err != nil {
			return nil, err
		}
		return nil, err2
	}

	// Write the query to the server.
	query := make([]byte, len(fq.packet)+2)
	binary.BigEndian.PutUint16(query, uint16(len(fq.packet)))
	copy(query[2:], fq.packet)
	if _, err := conn.Write(query); err != nil {
		metricDNSFwdTCPErrorWrite.Add(1)
		return ctxOrErr(err)
	}

	metricDNSFwdTCPWrote.Add(1)

	// Read the 2-byte length header of the response back from the server.
	var length uint16
	if err := binary.Read(conn, binary.BigEndian, &length); err != nil {
		metricDNSFwdTCPErrorRead.Add(1)
		return ctxOrErr(err)
	}

	// Now read the response
	out := make([]byte, length)
	n, err := io.ReadFull(conn, out)
	if err != nil {
		metricDNSFwdTCPErrorRead.Add(1)
		return ctxOrErr(err)
	}

	if n < int(length) {
		f.logf("sendTCP: packet too small (%d bytes)", n)
		return nil, io.ErrUnexpectedEOF
	}
	out = out[:n]
	txid := getTxID(out)
	if txid != fq.txid {
		metricDNSFwdTCPErrorTxID.Add(1)
		return nil, errTxIDMismatch
	}

	rcode := getRCode(out)

	// don't forward transient errors back to the client when the server fails
	if rcode == dns.RCodeServerFailure {
		f.logf("sendTCP: response code indicating server failure: %d", rcode)
		metricDNSFwdTCPErrorServer.Add(1)
		return nil, errServerFailure
	}

	// TODO(andrew): do we need to do this?
	//clampEDNSSize(out, maxResponseBytes)
	metricDNSFwdTCPSuccess.Add(1)
	return out, nil
}

// resolvers returns the resolvers to use for domain.
func (f *forwarder) resolvers(domain dnsname.FQDN) []resolverAndDelay {
	f.mu.Lock()
	routes := f.routes
	cloudHostFallback := f.cloudHostFallback
	f.mu.Unlock()
	for _, route := range routes {
		if route.Suffix == "." || route.Suffix.Contains(domain) {
			return route.Resolvers
		}
	}
	return cloudHostFallback // or nil if no fallback
}

// GetUpstreamResolvers returns the resolvers that would be used to resolve
// the given FQDN.
func (f *forwarder) GetUpstreamResolvers(name dnsname.FQDN) []*dnstype.Resolver {
	resolvers := f.resolvers(name)
	upstreamResolvers := make([]*dnstype.Resolver, 0, len(resolvers))
	for _, r := range resolvers {
		upstreamResolvers = append(upstreamResolvers, r.name)
	}
	return upstreamResolvers
}

// forwardQuery is information and state about a forwarded DNS query that's
// being sent to 1 or more upstreams.
//
// In the case of racing against multiple equivalent upstreams
// (e.g. Google or CloudFlare's 4 DNS IPs: 2 IPv4 + 2 IPv6), this type
// handles racing them more intelligently than just blasting away 4
// queries at once.
type forwardQuery struct {
	txid   txid
	packet []byte
	family string // "tcp" or "udp"

	// closeOnCtxDone lets send register values to Close if the
	// caller's ctx expires. This avoids send having to allocate its
	// own waiting goroutine to interrupt the ReadFrom, as memory
	// is tight on iOS and we want the number of pending DNS
	// lookups to be bursty without too much associated
	// goroutine/memory cost.
	closeOnCtxDone *closePool

	// TODO(bradfitz): add race delay state:
	// mu sync.Mutex
	// ...
}

// forwardWithDestChan forwards the query to all upstream nameservers
// and waits for the first response.
//
// It either sends to responseChan and returns nil, or returns a
// non-nil error (without sending to the channel).
//
// If resolvers is non-empty, it's used explicitly (notably, for exit
// node DNS proxy queries), otherwise f.resolvers is used.
func (f *forwarder) forwardWithDestChan(ctx context.Context, query packet, responseChan chan<- packet, resolvers ...resolverAndDelay) error {
	metricDNSFwd.Add(1)
	domain, typ, err := nameFromQuery(query.bs)
	if err != nil {
		metricDNSFwdErrorName.Add(1)
		return err
	}

	// Guarantee that the ctx we use below is done when this function returns.
	ctx, cancel := context.WithCancel(ctx)
	defer cancel()

	// Drop DNS service discovery spam, primarily for battery life
	// on mobile. Things like Spotify on iOS generate this traffic,
	// when browsing for LAN devices. But even when filtering this
	// out, playing on Sonos still works.
	if hasRDNSBonjourPrefix(domain) {
		metricDNSFwdDropBonjour.Add(1)
		res, err := nxDomainResponse(query)
		if err != nil {
			f.logf("error parsing bonjour query: %v", err)
			// Returning an error will cause an internal retry, there is
			// nothing we can do if parsing failed. Just drop the packet.
			return nil
		}
		select {
		case <-ctx.Done():
			return fmt.Errorf("waiting to send NXDOMAIN: %w", ctx.Err())
		case responseChan <- res:
			return nil
		}
	}

	if fl := fwdLogAtomic.Load(); fl != nil {
		fl.addName(string(domain))
	}

	clampEDNSSize(query.bs, maxResponseBytes)

	if len(resolvers) == 0 {
		resolvers = f.resolvers(domain)
		if len(resolvers) == 0 {
			metricDNSFwdErrorNoUpstream.Add(1)
			f.health.SetUnhealthy(dnsForwarderFailing, health.Args{health.ArgDNSServers: ""})
			f.logf("no upstream resolvers set, returning SERVFAIL")

			// Attempt to recompile the DNS configuration
			// If we are being asked to forward queries and we have no
			// nameservers, the network is in a bad state.
			if f.missingUpstreamRecovery != nil {
				f.missingUpstreamRecovery()
			}

			res, err := servfailResponse(query)
			if err != nil {
				f.logf("building servfail response: %v", err)
				// Returning an error will cause an internal retry, there is
				// nothing we can do if parsing failed. Just drop the packet.
				return nil
			}
			select {
			case <-ctx.Done():
				return fmt.Errorf("waiting to send SERVFAIL: %w", ctx.Err())
			case responseChan <- res:
				return nil
			}
		} else {
			f.health.SetHealthy(dnsForwarderFailing)
		}
	}

	fq := &forwardQuery{
		txid:           getTxID(query.bs),
		packet:         query.bs,
		family:         query.family,
		closeOnCtxDone: new(closePool),
	}
	defer fq.closeOnCtxDone.Close()

	if verboseDNSForward() {
		domainSha256 := sha256.Sum256([]byte(domain))
		domainSig := base64.RawStdEncoding.EncodeToString(domainSha256[:3])
		f.logf("request(%d, %v, %d, %s) %d...", fq.txid, typ, len(domain), domainSig, len(fq.packet))
	}

	resc := make(chan []byte, 1) // it's fine buffered or not
	errc := make(chan error, 1)  // it's fine buffered or not too
	for i := range resolvers {
		go func(rr *resolverAndDelay) {
			if rr.startDelay > 0 {
				timer := time.NewTimer(rr.startDelay)
				select {
				case <-timer.C:
				case <-ctx.Done():
					timer.Stop()
					return
				}
			}
			resb, err := f.send(ctx, fq, *rr)
			if err != nil {
				err = fmt.Errorf("resolving using %q: %w", rr.name.Addr, err)
				select {
				case errc <- err:
				case <-ctx.Done():
				}
				return
			}
			select {
			case resc <- resb:
			case <-ctx.Done():
			}
		}(&resolvers[i])
	}

	var firstErr error
	var numErr int
	for {
		select {
		case v := <-resc:
			select {
			case <-ctx.Done():
				metricDNSFwdErrorContext.Add(1)
				return fmt.Errorf("waiting to send response: %w", ctx.Err())
			case responseChan <- packet{v, query.family, query.addr}:
				if verboseDNSForward() {
					f.logf("response(%d, %v, %d) = %d, nil", fq.txid, typ, len(domain), len(v))
				}
				metricDNSFwdSuccess.Add(1)
				f.health.SetHealthy(dnsForwarderFailing)
				return nil
			}
		case err := <-errc:
			if firstErr == nil {
				firstErr = err
			}
			numErr++
			if numErr == len(resolvers) {
				if errors.Is(firstErr, errServerFailure) {
					res, err := servfailResponse(query)
					if err != nil {
						f.logf("building servfail response: %v", err)
						return firstErr
					}

					select {
					case <-ctx.Done():
						metricDNSFwdErrorContext.Add(1)
						metricDNSFwdErrorContextGotError.Add(1)
						var resolverAddrs []string
						for _, rr := range resolvers {
							resolverAddrs = append(resolverAddrs, rr.name.Addr)
						}
						f.health.SetUnhealthy(dnsForwarderFailing, health.Args{health.ArgDNSServers: strings.Join(resolverAddrs, ",")})
					case responseChan <- res:
						if verboseDNSForward() {
							f.logf("forwarder response(%d, %v, %d) = %d, %v", fq.txid, typ, len(domain), len(res.bs), firstErr)
						}
					}
				}
				return firstErr
			}
		case <-ctx.Done():
			metricDNSFwdErrorContext.Add(1)
			if firstErr != nil {
				metricDNSFwdErrorContextGotError.Add(1)
				return firstErr
			}

			// If we haven't got an error or a successful response,
			// include all resolvers in the error message so we can
			// at least see what servers we're trying to query.
			var resolverAddrs []string
			for _, rr := range resolvers {
				resolverAddrs = append(resolverAddrs, rr.name.Addr)
			}
			f.health.SetUnhealthy(dnsForwarderFailing, health.Args{health.ArgDNSServers: strings.Join(resolverAddrs, ",")})
			return fmt.Errorf("waiting for response or error from %v: %w", resolverAddrs, ctx.Err())
		}
	}
}
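
// forwardOnceSketch is a hypothetical illustration (not part of the
// upstream file) of the forwardWithDestChan contract above: it either
// delivers exactly one response packet on the channel and returns nil, or
// returns a non-nil error without sending anything, so a buffered channel
// of one is enough to collect the answer.
func forwardOnceSketch(ctx context.Context, f *forwarder, query packet) (packet, error) {
	ch := make(chan packet, 1)
	if err := f.forwardWithDestChan(ctx, query, ch); err != nil {
		return packet{}, err
	}
	return <-ch, nil
}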

var initListenConfig func(_ *net.ListenConfig, _ *netmon.Monitor, tunName string) error

// nameFromQuery extracts the normalized query name from bs.
func nameFromQuery(bs []byte) (dnsname.FQDN, dns.Type, error) {
	var parser dns.Parser

	hdr, err := parser.Start(bs)
	if err != nil {
		return "", 0, err
	}
	if hdr.Response {
		return "", 0, errNotQuery
	}

	q, err := parser.Question()
	if err != nil {
		return "", 0, err
	}

	n := q.Name.Data[:q.Name.Length]
	fqdn, err := dnsname.ToFQDN(rawNameToLower(n))
	if err != nil {
		return "", 0, err
	}
	return fqdn, q.Type, nil
}

// nxDomainResponse returns an NXDomain DNS reply for the provided request.
func nxDomainResponse(req packet) (res packet, err error) {
	p := dnsParserPool.Get().(*dnsParser)
	defer dnsParserPool.Put(p)

	if err := p.parseQuery(req.bs); err != nil {
		return packet{}, err
	}

	h := p.Header
	h.Response = true
	h.RecursionAvailable = h.RecursionDesired
	h.RCode = dns.RCodeNameError
	b := dns.NewBuilder(nil, h)
	// TODO(bradfitz): should we add an SOA record in the Authority
	// section too? (for the nxdomain negative caching TTL)
	// For which zone? Does iOS care?
	b.StartQuestions()
	b.Question(p.Question)
	res.bs, err = b.Finish()
	res.addr = req.addr
	return res, err
}

// servfailResponse returns a SERVFAIL error reply for the provided request.
func servfailResponse(req packet) (res packet, err error) {
	p := dnsParserPool.Get().(*dnsParser)
	defer dnsParserPool.Put(p)

	if err := p.parseQuery(req.bs); err != nil {
		return packet{}, err
	}

	h := p.Header
	h.Response = true
	h.Authoritative = true
	h.RCode = dns.RCodeServerFailure
	b := dns.NewBuilder(nil, h)
	b.StartQuestions()
	b.Question(p.Question)
	res.bs, err = b.Finish()
	res.addr = req.addr
	return res, err
}

// closePool is a dynamic set of io.Closers to close as a group.
// It's intended to be Closed at most once.
//
// The zero value is ready for use.
type closePool struct {
	mu     sync.Mutex
	m      map[io.Closer]bool
	closed bool
}

func (p *closePool) Add(c io.Closer) {
	p.mu.Lock()
	defer p.mu.Unlock()
	if p.closed {
		c.Close()
		return
	}
	if p.m == nil {
		p.m = map[io.Closer]bool{}
	}
	p.m[c] = true
}

func (p *closePool) Remove(c io.Closer) {
	p.mu.Lock()
	defer p.mu.Unlock()
	if p.closed {
		return
	}
	delete(p.m, c)
}

func (p *closePool) Close() error {
	p.mu.Lock()
	defer p.mu.Unlock()
	if p.closed {
		return nil
	}
	p.closed = true
	for c := range p.m {
		c.Close()
	}
	return nil
}
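
// closePoolSketch is a hypothetical illustration (not part of the upstream
// file) of the pattern sendUDP and sendTCP follow: register the connection
// in the query's pool so that the single deferred Close in
// forwardWithDestChan can unblock any read still pending when the query
// finishes or its context expires.
func closePoolSketch(pool *closePool, conn net.Conn) error {
	pool.Add(conn)
	defer pool.Remove(conn)
	buf := make([]byte, maxResponseBytes)
	_, err := conn.Read(buf) // interrupted by pool.Close if the query is abandoned
	return err
}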