diff --git a/net/dns/resolver/forwarder.go b/net/dns/resolver/forwarder.go index addfea4e0..2741d6349 100644 --- a/net/dns/resolver/forwarder.go +++ b/net/dns/resolver/forwarder.go @@ -615,6 +615,10 @@ func (f *forwarder) forwardWithDestChan(ctx context.Context, query packet, respo return err } + // Guarantee that the ctx we use below is done when this function returns. + ctx, cancel := context.WithCancel(ctx) + defer cancel() + // Drop DNS service discovery spam, primarily for battery life // on mobile. Things like Spotify on iOS generate this traffic, // when browsing for LAN devices. But even when filtering this @@ -655,12 +659,8 @@ func (f *forwarder) forwardWithDestChan(ctx context.Context, query packet, respo } defer fq.closeOnCtxDone.Close() - resc := make(chan []byte, 1) - var ( - mu sync.Mutex - firstErr error - ) - + resc := make(chan []byte, 1) // it's fine buffered or not + errc := make(chan error, 1) // it's fine buffered or not too for i := range resolvers { go func(rr *resolverAndDelay) { if rr.startDelay > 0 { @@ -674,39 +674,48 @@ func (f *forwarder) forwardWithDestChan(ctx context.Context, query packet, respo } resb, err := f.send(ctx, fq, *rr) if err != nil { - mu.Lock() - defer mu.Unlock() - if firstErr == nil { - firstErr = err + select { + case errc <- err: + case <-ctx.Done(): } return } select { case resc <- resb: - default: + case <-ctx.Done(): } }(&resolvers[i]) } - select { - case v := <-resc: + var firstErr error + var numErr int + for { select { + case v := <-resc: + select { + case <-ctx.Done(): + metricDNSFwdErrorContext.Add(1) + return ctx.Err() + case responseChan <- packet{v, query.addr}: + metricDNSFwdSuccess.Add(1) + return nil + } + case err := <-errc: + if firstErr == nil { + firstErr = err + } + numErr++ + if numErr == len(resolvers) { + return firstErr + } case <-ctx.Done(): metricDNSFwdErrorContext.Add(1) + if firstErr != nil { + metricDNSFwdErrorContextGotError.Add(1) + return firstErr + } return ctx.Err() - case responseChan <- packet{v, query.addr}: - metricDNSFwdSuccess.Add(1) - return nil } - case <-ctx.Done(): - mu.Lock() - defer mu.Unlock() - metricDNSFwdErrorContext.Add(1) - if firstErr != nil { - metricDNSFwdErrorContextGotError.Add(1) - return firstErr - } - return ctx.Err() } }