net/dns/resolver: unexport Packet, only use it internally.

Signed-off-by: David Anderson <danderson@tailscale.com>
This commit is contained in:
David Anderson 2021-03-31 23:06:47 -07:00
parent 5fb9e00ecf
commit f185d62dc8
4 changed files with 41 additions and 55 deletions

View File

@ -105,7 +105,7 @@ type forwarder struct {
logf logger.Logf logf logger.Logf
// responses is a channel by which responses are returned. // responses is a channel by which responses are returned.
responses chan Packet responses chan packet
// closed signals all goroutines to stop. // closed signals all goroutines to stop.
closed chan struct{} closed chan struct{}
// wg signals when all goroutines have stopped. // wg signals when all goroutines have stopped.
@ -126,7 +126,7 @@ func init() {
rand.Seed(time.Now().UnixNano()) rand.Seed(time.Now().UnixNano())
} }
func newForwarder(logf logger.Logf, responses chan Packet) *forwarder { func newForwarder(logf logger.Logf, responses chan packet) *forwarder {
return &forwarder{ return &forwarder{
logf: logger.WithPrefix(logf, "forward: "), logf: logger.WithPrefix(logf, "forward: "),
responses: responses, responses: responses,
@ -218,14 +218,11 @@ func (f *forwarder) recv(conn *fwdConn) {
f.mu.Unlock() f.mu.Unlock()
packet := Packet{ pkt := packet{out, record.src}
Payload: out,
Addr: record.src,
}
select { select {
case <-f.closed: case <-f.closed:
return return
case f.responses <- packet: case f.responses <- pkt:
// continue // continue
} }
} }
@ -258,8 +255,8 @@ func (f *forwarder) cleanMap() {
} }
// forward forwards the query to all upstream nameservers and returns the first response. // forward forwards the query to all upstream nameservers and returns the first response.
func (f *forwarder) forward(query Packet) error { func (f *forwarder) forward(query packet) error {
txid := getTxID(query.Payload) txid := getTxID(query.bs)
f.mu.Lock() f.mu.Lock()
@ -269,14 +266,14 @@ func (f *forwarder) forward(query Packet) error {
return errNoUpstreams return errNoUpstreams
} }
f.txMap[txid] = forwardingRecord{ f.txMap[txid] = forwardingRecord{
src: query.Addr, src: query.addr,
createdAt: time.Now(), createdAt: time.Now(),
} }
f.mu.Unlock() f.mu.Unlock()
for _, upstream := range upstreams { for _, upstream := range upstreams {
f.send(query.Payload, upstream) f.send(query.bs, upstream)
} }
return nil return nil

View File

@ -44,14 +44,9 @@
errNotOurName = errors.New("not a Tailscale DNS name") errNotOurName = errors.New("not a Tailscale DNS name")
) )
// Packet represents a DNS payload together with the address of its origin. type packet struct {
type Packet struct { bs []byte
// Payload is the application layer DNS payload. addr netaddr.IPPort // src for a request, dst for a response
// Resolver assumes ownership of the request payload when it is enqueued
// and cedes ownership of the response payload when it is returned from NextResponse.
Payload []byte
// Addr is the source address for a request and the destination address for a response.
Addr netaddr.IPPort
} }
// Resolver is a DNS resolver for nodes on the Tailscale network, // Resolver is a DNS resolver for nodes on the Tailscale network,
@ -66,9 +61,9 @@ type Resolver struct {
forwarder *forwarder forwarder *forwarder
// queue is a buffered channel holding DNS requests queued for resolution. // queue is a buffered channel holding DNS requests queued for resolution.
queue chan Packet queue chan packet
// responses is an unbuffered channel to which responses are returned. // responses is an unbuffered channel to which responses are returned.
responses chan Packet responses chan packet
// errors is an unbuffered channel to which errors are returned. // errors is an unbuffered channel to which errors are returned.
errors chan error errors chan error
// closed signals all goroutines to stop. // closed signals all goroutines to stop.
@ -88,8 +83,8 @@ func New(logf logger.Logf, linkMon *monitor.Mon) (*Resolver, error) {
r := &Resolver{ r := &Resolver{
logf: logger.WithPrefix(logf, "dns: "), logf: logger.WithPrefix(logf, "dns: "),
linkMon: linkMon, linkMon: linkMon,
queue: make(chan Packet, queueSize), queue: make(chan packet, queueSize),
responses: make(chan Packet), responses: make(chan packet),
errors: make(chan error), errors: make(chan error),
closed: make(chan struct{}), closed: make(chan struct{}),
} }
@ -153,11 +148,11 @@ func (r *Resolver) SetUpstreams(upstreams []net.Addr) {
// EnqueueRequest places the given DNS request in the resolver's queue. // EnqueueRequest places the given DNS request in the resolver's queue.
// It takes ownership of the payload and does not block. // It takes ownership of the payload and does not block.
// If the queue is full, the request will be dropped and an error will be returned. // If the queue is full, the request will be dropped and an error will be returned.
func (r *Resolver) EnqueueRequest(request Packet) error { func (r *Resolver) EnqueueRequest(bs []byte, from netaddr.IPPort) error {
select { select {
case <-r.closed: case <-r.closed:
return ErrClosed return ErrClosed
case r.queue <- request: case r.queue <- packet{bs, from}:
return nil return nil
default: default:
return errFullQueue return errFullQueue
@ -166,14 +161,14 @@ func (r *Resolver) EnqueueRequest(request Packet) error {
// NextResponse returns a DNS response to a previously enqueued request. // NextResponse returns a DNS response to a previously enqueued request.
// It blocks until a response is available and gives up ownership of the response payload. // It blocks until a response is available and gives up ownership of the response payload.
func (r *Resolver) NextResponse() (Packet, error) { func (r *Resolver) NextResponse() (packet []byte, to netaddr.IPPort, err error) {
select { select {
case <-r.closed: case <-r.closed:
return Packet{}, ErrClosed return nil, netaddr.IPPort{}, ErrClosed
case resp := <-r.responses: case resp := <-r.responses:
return resp, nil return resp.bs, resp.addr, nil
case err := <-r.errors: case err := <-r.errors:
return Packet{}, err return nil, netaddr.IPPort{}, err
} }
} }
@ -266,19 +261,19 @@ func (r *Resolver) ResolveReverse(ip netaddr.IP) (string, dns.RCode, error) {
func (r *Resolver) poll() { func (r *Resolver) poll() {
defer r.wg.Done() defer r.wg.Done()
var packet Packet var pkt packet
for { for {
select { select {
case <-r.closed: case <-r.closed:
return return
case packet = <-r.queue: case pkt = <-r.queue:
// continue // continue
} }
out, err := r.respond(packet.Payload) out, err := r.respond(pkt.bs)
if err == errNotOurName { if err == errNotOurName {
err = r.forwarder.forward(packet) err = r.forwarder.forward(pkt)
if err == nil { if err == nil {
// forward will send response into r.responses, nothing to do. // forward will send response into r.responses, nothing to do.
continue continue
@ -293,11 +288,11 @@ func (r *Resolver) poll() {
// continue // continue
} }
} else { } else {
packet.Payload = out pkt.bs = out
select { select {
case <-r.closed: case <-r.closed:
return return
case r.responses <- packet: case r.responses <- pkt:
// continue // continue
} }
} }

View File

@ -109,10 +109,9 @@ func unpackResponse(payload []byte) (dnsResponse, error) {
} }
func syncRespond(r *Resolver, query []byte) ([]byte, error) { func syncRespond(r *Resolver, query []byte) ([]byte, error) {
request := Packet{Payload: query} r.EnqueueRequest(query, netaddr.IPPort{})
r.EnqueueRequest(request) payload, _, err := r.NextResponse()
resp, err := r.NextResponse() return payload, err
return resp.Payload, err
} }
func mustIP(str string) netaddr.IP { func mustIP(str string) netaddr.IP {
@ -418,21 +417,20 @@ func TestDelegateCollision(t *testing.T) {
// packets will have the same dns txid. // packets will have the same dns txid.
for _, p := range packets { for _, p := range packets {
payload := dnspacket(p.qname, p.qtype) payload := dnspacket(p.qname, p.qtype)
req := Packet{Payload: payload, Addr: p.addr} err := r.EnqueueRequest(payload, p.addr)
err := r.EnqueueRequest(req)
if err != nil { if err != nil {
t.Error(err) t.Error(err)
} }
} }
// Despite the txid collision, the answer(s) should still match the query. // Despite the txid collision, the answer(s) should still match the query.
resp, err := r.NextResponse() resp, addr, err := r.NextResponse()
if err != nil { if err != nil {
t.Error(err) t.Error(err)
} }
var p dns.Parser var p dns.Parser
_, err = p.Start(resp.Payload) _, err = p.Start(resp)
if err != nil { if err != nil {
t.Error(err) t.Error(err)
} }
@ -456,8 +454,8 @@ func TestDelegateCollision(t *testing.T) {
} }
for _, p := range packets { for _, p := range packets {
if p.qtype == wantType && p.addr != resp.Addr { if p.qtype == wantType && p.addr != addr {
t.Errorf("addr = %v; want %v", resp.Addr, p.addr) t.Errorf("addr = %v; want %v", addr, p.addr)
} }
} }
} }

View File

@ -433,11 +433,7 @@ func (e *userspaceEngine) handleLocalPackets(p *packet.Parsed, t *tstun.Wrapper)
// handleDNS is an outbound pre-filter resolving Tailscale domains. // handleDNS is an outbound pre-filter resolving Tailscale domains.
func (e *userspaceEngine) handleDNS(p *packet.Parsed, t *tstun.Wrapper) filter.Response { func (e *userspaceEngine) handleDNS(p *packet.Parsed, t *tstun.Wrapper) filter.Response {
if p.Dst.IP == magicDNSIP && p.Dst.Port == magicDNSPort && p.IPProto == ipproto.UDP { if p.Dst.IP == magicDNSIP && p.Dst.Port == magicDNSPort && p.IPProto == ipproto.UDP {
request := resolver.Packet{ err := e.resolver.EnqueueRequest(append([]byte(nil), p.Payload()...), p.Src)
Payload: append([]byte(nil), p.Payload()...),
Addr: netaddr.IPPort{IP: p.Src.IP, Port: p.Src.Port},
}
err := e.resolver.EnqueueRequest(request)
if err != nil { if err != nil {
e.logf("dns: enqueue: %v", err) e.logf("dns: enqueue: %v", err)
} }
@ -449,7 +445,7 @@ func (e *userspaceEngine) handleDNS(p *packet.Parsed, t *tstun.Wrapper) filter.R
// pollResolver reads responses from the DNS resolver and injects them inbound. // pollResolver reads responses from the DNS resolver and injects them inbound.
func (e *userspaceEngine) pollResolver() { func (e *userspaceEngine) pollResolver() {
for { for {
resp, err := e.resolver.NextResponse() bs, to, err := e.resolver.NextResponse()
if err == resolver.ErrClosed { if err == resolver.ErrClosed {
return return
} }
@ -461,17 +457,17 @@ func (e *userspaceEngine) pollResolver() {
h := packet.UDP4Header{ h := packet.UDP4Header{
IP4Header: packet.IP4Header{ IP4Header: packet.IP4Header{
Src: magicDNSIP, Src: magicDNSIP,
Dst: resp.Addr.IP, Dst: to.IP,
}, },
SrcPort: magicDNSPort, SrcPort: magicDNSPort,
DstPort: resp.Addr.Port, DstPort: to.Port,
} }
hlen := h.Len() hlen := h.Len()
// TODO(dmytro): avoid this allocation without importing tstun quirks into dns. // TODO(dmytro): avoid this allocation without importing tstun quirks into dns.
const offset = tstun.PacketStartOffset const offset = tstun.PacketStartOffset
buf := make([]byte, offset+hlen+len(resp.Payload)) buf := make([]byte, offset+hlen+len(bs))
copy(buf[offset+hlen:], resp.Payload) copy(buf[offset+hlen:], bs)
h.Marshal(buf[offset:]) h.Marshal(buf[offset:])
e.tundev.InjectInboundDirect(buf, offset) e.tundev.InjectInboundDirect(buf, offset)