From e3a49525277b4ea40780e6d120d17d60a30925d8 Mon Sep 17 00:00:00 2001 From: Brad Fitzpatrick Date: Tue, 19 Apr 2022 10:58:52 -0700 Subject: [PATCH] net/dns/resolver: count errors when racing DNS queries, fail earlier If all N queries failed, we waited until context timeout (in 5 seconds) to return. This makes (*forwarder).forward fail fast when the network's unavailable. Change-Id: Ibbb3efea7ed34acd3f3b29b5fee00ba8c7492569 Signed-off-by: Brad Fitzpatrick --- net/dns/resolver/forwarder.go | 59 ++++++++++++++++++++--------------- 1 file changed, 34 insertions(+), 25 deletions(-) diff --git a/net/dns/resolver/forwarder.go b/net/dns/resolver/forwarder.go index addfea4e0..2741d6349 100644 --- a/net/dns/resolver/forwarder.go +++ b/net/dns/resolver/forwarder.go @@ -615,6 +615,10 @@ func (f *forwarder) forwardWithDestChan(ctx context.Context, query packet, respo return err } + // Guarantee that the ctx we use below is done when this function returns. + ctx, cancel := context.WithCancel(ctx) + defer cancel() + // Drop DNS service discovery spam, primarily for battery life // on mobile. Things like Spotify on iOS generate this traffic, // when browsing for LAN devices. But even when filtering this @@ -655,12 +659,8 @@ func (f *forwarder) forwardWithDestChan(ctx context.Context, query packet, respo } defer fq.closeOnCtxDone.Close() - resc := make(chan []byte, 1) - var ( - mu sync.Mutex - firstErr error - ) - + resc := make(chan []byte, 1) // it's fine buffered or not + errc := make(chan error, 1) // it's fine buffered or not too for i := range resolvers { go func(rr *resolverAndDelay) { if rr.startDelay > 0 { @@ -674,39 +674,48 @@ func (f *forwarder) forwardWithDestChan(ctx context.Context, query packet, respo } resb, err := f.send(ctx, fq, *rr) if err != nil { - mu.Lock() - defer mu.Unlock() - if firstErr == nil { - firstErr = err + select { + case errc <- err: + case <-ctx.Done(): } return } select { case resc <- resb: - default: + case <-ctx.Done(): } }(&resolvers[i]) } - select { - case v := <-resc: + var firstErr error + var numErr int + for { select { + case v := <-resc: + select { + case <-ctx.Done(): + metricDNSFwdErrorContext.Add(1) + return ctx.Err() + case responseChan <- packet{v, query.addr}: + metricDNSFwdSuccess.Add(1) + return nil + } + case err := <-errc: + if firstErr == nil { + firstErr = err + } + numErr++ + if numErr == len(resolvers) { + return firstErr + } case <-ctx.Done(): metricDNSFwdErrorContext.Add(1) + if firstErr != nil { + metricDNSFwdErrorContextGotError.Add(1) + return firstErr + } return ctx.Err() - case responseChan <- packet{v, query.addr}: - metricDNSFwdSuccess.Add(1) - return nil } - case <-ctx.Done(): - mu.Lock() - defer mu.Unlock() - metricDNSFwdErrorContext.Add(1) - if firstErr != nil { - metricDNSFwdErrorContextGotError.Add(1) - return firstErr - } - return ctx.Err() } }