diff --git a/wgengine/tsdns/forwarder.go b/wgengine/tsdns/forwarder.go new file mode 100644 index 000000000..c43873a2f --- /dev/null +++ b/wgengine/tsdns/forwarder.go @@ -0,0 +1,325 @@ +// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package tsdns + +import ( + "bytes" + "encoding/binary" + "errors" + "hash/crc32" + "math/rand" + "net" + "os" + "sync" + "time" + + "inet.af/netaddr" + "tailscale.com/types/logger" +) + +// headerBytes is the number of bytes in a DNS message header. +const headerBytes = 12 + +// forwardQueueSize is the maximal number of requests that can be pending delegation. +// Note that this is distinct from the number of requests that are pending a response, +// which is not limited (except by txid collisions). +const forwardQueueSize = 64 + +// connCount is the number of UDP connections to use for forwarding. +const connCount = 32 + +const ( + // cleanupInterval is the interval between purged of timed-out entries from txMap. + cleanupInterval = 30 * time.Second + // responseTimeout is the maximal amount of time to wait for a DNS response. + responseTimeout = 5 * time.Second +) + +var errNoUpstreams = errors.New("upstream nameservers not set") + +var aLongTimeAgo = time.Unix(0, 1) + +type forwardedPacket struct { + payload []byte + dst net.Addr +} + +type forwardingRecord struct { + src netaddr.IPPort + createdAt time.Time +} + +// txid identifies a DNS transaction. +// +// As the standard DNS Request ID is only 16 bits, we extend it: +// the lower 32 bits are the zero-extended bits of the DNS Request ID; +// the upper 32 bits are the CRC32 checksum of the first question in the request. +// This makes probability of txid collision negligible. +type txid uint64 + +// getTxID computes the txid of the given DNS packet. +func getTxID(packet []byte) txid { + if len(packet) < headerBytes { + return 0 + } + + dnsid := binary.BigEndian.Uint16(packet[0:2]) + qcount := binary.BigEndian.Uint16(packet[4:6]) + if qcount == 0 { + return txid(dnsid) + } + + offset := headerBytes + for i := uint16(0); i < qcount; i++ { + // Note: this relies on the fact that names are not compressed in questions, + // so they are guaranteed to end with a NUL byte. + // + // Justification: + // RFC 1035 doesn't seem to explicitly prohibit compressing names in questions, + // but this is exceedingly unlikely to be done in practice. A DNS request + // with multiple questions is ill-defined (which questions do the header flags apply to?) + // and a single question would have to contain a pointer to an *answer*, + // which would be excessively smart, pointless (an answer can just as well refer to the question) + // and perhaps even prohibited: a draft RFC (draft-ietf-dnsind-local-compression-05) states: + // + // > It is important that these pointers always point backwards. + // + // This is said in summarizing RFC 1035, although that phrase does not appear in the original RFC. + // Additionally, (https://cr.yp.to/djbdns/notes.html) states: + // + // > The precise rule is that a name can be compressed if it is a response owner name, + // > the name in NS data, the name in CNAME data, the name in PTR data, the name in MX data, + // > or one of the names in SOA data. + namebytes := bytes.IndexByte(packet[offset:], 0) + // ... | name | NUL | type | class + // ?? 1 2 2 + offset = offset + namebytes + 5 + if len(packet) < offset { + // Corrupt packet; don't crash. + return txid(dnsid) + } + } + + hash := crc32.ChecksumIEEE(packet[headerBytes:offset]) + return (txid(hash) << 32) | txid(dnsid) +} + +// forwarder forwards DNS packets to a number of upstream nameservers. +type forwarder struct { + logf logger.Logf + + // queue is the queue for delegated packets. + queue chan forwardedPacket + // responses is a channel by which responses are returned. + responses chan Packet + // closed signals all goroutines to stop. + closed chan struct{} + // wg signals when all goroutines have stopped. + wg sync.WaitGroup + + // conns are the UDP connections used for forwarding. + // A random one is selected for each request, regardless of the target upstream. + conns []*net.UDPConn + + mu sync.Mutex + // upstreams are the nameserver addresses that should be used for forwarding. + upstreams []net.Addr + // txMap maps DNS txids to active forwarding records. + txMap map[txid]forwardingRecord +} + +func init() { + rand.Seed(time.Now().UnixNano()) +} + +func newForwarder(logf logger.Logf, responses chan Packet) *forwarder { + return &forwarder{ + logf: logger.WithPrefix(logf, "forward: "), + responses: responses, + queue: make(chan forwardedPacket, forwardQueueSize), + closed: make(chan struct{}), + conns: make([]*net.UDPConn, connCount), + txMap: make(map[txid]forwardingRecord), + } +} + +func (f *forwarder) Start() error { + var err error + + for i := range f.conns { + f.conns[i], err = net.ListenUDP("udp", nil) + if err != nil { + return err + } + } + + f.wg.Add(connCount + 2) + for idx, conn := range f.conns { + go f.recv(uint16(idx), conn) + } + go f.send() + go f.cleanMap() + + return nil +} + +func (f *forwarder) Close() { + select { + case <-f.closed: + return + default: + // continue + } + close(f.closed) + + for _, conn := range f.conns { + conn.SetDeadline(aLongTimeAgo) + } + + f.wg.Wait() + + for _, conn := range f.conns { + conn.Close() + } +} + +func (f *forwarder) setUpstreams(upstreams []net.Addr) { + f.mu.Lock() + f.upstreams = upstreams + f.mu.Unlock() +} + +func (f *forwarder) send() { + defer f.wg.Done() + + var packet forwardedPacket + for { + select { + case <-f.closed: + return + case packet = <-f.queue: + // continue + } + + connIdx := rand.Intn(connCount) + conn := f.conns[connIdx] + _, err := conn.WriteTo(packet.payload, packet.dst) + if err != nil { + // Do not log errors due to expired deadline. + if !errors.Is(err, os.ErrDeadlineExceeded) { + f.logf("send: %v", err) + } + return + } + } +} + +func (f *forwarder) recv(connIdx uint16, conn *net.UDPConn) { + defer f.wg.Done() + + for { + out := make([]byte, maxResponseBytes) + n, err := conn.Read(out) + + if err != nil { + // Do not log errors due to expired deadline. + if !errors.Is(err, os.ErrDeadlineExceeded) { + f.logf("recv: %v", err) + } + return + } + + if n < headerBytes { + f.logf("recv: packet too small (%d bytes)", n) + } + + out = out[:n] + txid := getTxID(out) + + f.mu.Lock() + + record, found := f.txMap[txid] + // At most one nameserver will return a response: + // the first one to do so will delete txid from the map. + if !found { + f.mu.Unlock() + continue + } + delete(f.txMap, txid) + + f.mu.Unlock() + + packet := Packet{ + Payload: out, + Addr: record.src, + } + select { + case <-f.closed: + return + case f.responses <- packet: + // continue + } + } +} + +// cleanMap periodically deletes timed-out forwarding records from f.txMap to bound growth. +func (f *forwarder) cleanMap() { + defer f.wg.Done() + + t := time.NewTicker(cleanupInterval) + defer t.Stop() + + var now time.Time + for { + select { + case <-f.closed: + return + case now = <-t.C: + // continue + } + + f.mu.Lock() + for k, v := range f.txMap { + if now.Sub(v.createdAt) > responseTimeout { + delete(f.txMap, k) + } + } + f.mu.Unlock() + } +} + +// forward forwards the query to all upstream nameservers and returns the first response. +func (f *forwarder) forward(query Packet) error { + txid := getTxID(query.Payload) + + f.mu.Lock() + + upstreams := f.upstreams + if len(upstreams) == 0 { + f.mu.Unlock() + return errNoUpstreams + } + f.txMap[txid] = forwardingRecord{ + src: query.Addr, + createdAt: time.Now(), + } + + f.mu.Unlock() + + packet := forwardedPacket{ + payload: query.Payload, + } + for _, upstream := range upstreams { + packet.dst = upstream + select { + case <-f.closed: + return ErrClosed + case f.queue <- packet: + // continue + } + } + + return nil +} diff --git a/wgengine/tsdns/tsdns.go b/wgengine/tsdns/tsdns.go index 54ed11b4d..e4ecf90f6 100644 --- a/wgengine/tsdns/tsdns.go +++ b/wgengine/tsdns/tsdns.go @@ -8,29 +8,24 @@ package tsdns import ( "bytes" - "context" "encoding/hex" "errors" + "net" "sync" "time" dns "golang.org/x/net/dns/dnsmessage" "inet.af/netaddr" - "tailscale.com/net/netns" "tailscale.com/types/logger" ) -// maxResponseSize is the maximum size of a response from a Resolver. -const maxResponseSize = 512 +// maxResponseBytes is the maximum size of a response from a Resolver. +const maxResponseBytes = 512 -// queueSize is the maximal number of DNS requests that can be pending at a time. +// pendingQueueSize is the maximal number of DNS requests that can await polling. // If EnqueueRequest is called when this many requests are already pending, // the request will be dropped to avoid blocking the caller. -const queueSize = 8 - -// delegateTimeout is the maximal amount of time Resolver will wait -// for upstream nameservers to process a query. -const delegateTimeout = 5 * time.Second +const pendingQueueSize = 64 // defaultTTL is the TTL of all responses from Resolver. const defaultTTL = 600 * time.Second @@ -39,12 +34,12 @@ const defaultTTL = 600 * time.Second var ErrClosed = errors.New("closed") var ( - errAllFailed = errors.New("all upstream nameservers failed") errFullQueue = errors.New("request queue full") - errNoNameservers = errors.New("no upstream nameservers set") errMapNotSet = errors.New("domain map not set") + errNotForwarding = errors.New("forwarding disabled") errNotImplemented = errors.New("query type not implemented") errNotQuery = errors.New("not a DNS query") + errNotOurName = errors.New("not a Tailscale DNS name") ) // Packet represents a DNS payload together with the address of its origin. @@ -63,57 +58,69 @@ type Packet struct { // it delegates to upstream nameservers if any are set. type Resolver struct { logf logger.Logf - - // The asynchronous interface is due to the fact that resolution may potentially - // block for a long time (if the upstream nameserver is slow to reach). + // rootDomain is in ... + rootDomain []byte + // forwarder is + forwarder *forwarder // queue is a buffered channel holding DNS requests queued for resolution. queue chan Packet - // responses is an unbuffered channel to which responses are sent. + // responses is an unbuffered channel to which responses are returned. responses chan Packet - // errors is an unbuffered channel to which errors are sent. + // errors is an unbuffered channel to which errors are returned. errors chan error - // closed notifies the poll goroutines to stop. + // closed signals all goroutines to stop. closed chan struct{} - // pollGroup signals when all poll goroutines have stopped. - pollGroup sync.WaitGroup - - // rootDomain is in ... - rootDomain []byte - - // dialer is the netns.Dialer used for delegation. - dialer netns.Dialer + // wg signals when all goroutines have stopped. + wg sync.WaitGroup // mu guards the following fields from being updated while used. mu sync.Mutex // dnsMap is the map most recently received from the control server. dnsMap *Map - // nameservers is the list of nameserver addresses that should be used - // if the received query is not for a Tailscale node. - // The addresses are strings of the form ip:port, as expected by Dial. - nameservers []string +} + +// ResolverConfig is the set of configuration options for a Resolver. +type ResolverConfig struct { + // Logf is the logger to use throughout the Resolver. + Logf logger.Logf + // RootDomain is the domain whose subdomains will be resolved locally as Tailscale nodes. + RootDomain string + // Forward determines whether the resolver will forward packets to + // nameservers set with SetUpstreams if the domain name is not of a Tailscale node. + Forward bool } // NewResolver constructs a resolver associated with the given root domain. // The root domain must be in canonical form (with a trailing period). -func NewResolver(logf logger.Logf, rootDomain string) *Resolver { +func NewResolver(config ResolverConfig) *Resolver { r := &Resolver{ - logf: logger.WithPrefix(logf, "tsdns: "), - queue: make(chan Packet, queueSize), + logf: logger.WithPrefix(config.Logf, "tsdns: "), + queue: make(chan Packet, pendingQueueSize), responses: make(chan Packet), errors: make(chan error), closed: make(chan struct{}), - rootDomain: []byte(rootDomain), - dialer: netns.NewDialer(), + rootDomain: []byte(config.RootDomain), + } + + if config.Forward { + r.forwarder = newForwarder(r.logf, r.responses) } return r } -func (r *Resolver) Start() { - // TODO(dmytro): spawn more than one goroutine? They block on delegation. - r.pollGroup.Add(1) +func (r *Resolver) Start() error { + if r.forwarder != nil { + if err := r.forwarder.Start(); err != nil { + return err + } + } + + r.wg.Add(1) go r.poll() + + return nil } // Close shuts down the resolver and ensures poll goroutines have exited. @@ -126,7 +133,12 @@ func (r *Resolver) Close() { // continue } close(r.closed) - r.pollGroup.Wait() + + if r.forwarder != nil { + r.forwarder.Close() + } + + r.wg.Wait() } // SetMap sets the resolver's DNS map, taking ownership of it. @@ -138,14 +150,12 @@ func (r *Resolver) SetMap(m *Map) { r.logf("map diff:\n%s", m.PrettyDiffFrom(oldMap)) } -// SetUpstreamNameservers sets the addresses of the resolver's +// SetUpstreams sets the addresses of the resolver's // upstream nameservers, taking ownership of the argument. -// The addresses should be strings of the form ip:port, -// matching what Dial("udp", addr) expects as addr. -func (r *Resolver) SetNameservers(nameservers []string) { - r.mu.Lock() - r.nameservers = nameservers - r.mu.Unlock() +func (r *Resolver) SetUpstreams(upstreams []net.Addr) { + if r.forwarder != nil { + r.forwarder.setUpstreams(upstreams) + } } // EnqueueRequest places the given DNS request in the resolver's queue. @@ -153,6 +163,8 @@ func (r *Resolver) SetNameservers(nameservers []string) { // If the queue is full, the request will be dropped and an error will be returned. func (r *Resolver) EnqueueRequest(request Packet) error { select { + case <-r.closed: + return ErrClosed case r.queue <- request: return nil default: @@ -164,12 +176,12 @@ func (r *Resolver) EnqueueRequest(request Packet) error { // It blocks until a response is available and gives up ownership of the response payload. func (r *Resolver) NextResponse() (Packet, error) { select { + case <-r.closed: + return Packet{}, ErrClosed case resp := <-r.responses: return resp, nil case err := <-r.errors: return Packet{}, err - case <-r.closed: - return Packet{}, ErrClosed } } @@ -209,114 +221,50 @@ func (r *Resolver) ResolveReverse(ip netaddr.IP) (string, dns.RCode, error) { } func (r *Resolver) poll() { - defer r.pollGroup.Done() + defer r.wg.Done() - var ( - packet Packet - err error - ) + var packet Packet for { select { - case packet = <-r.queue: - // continue case <-r.closed: return + case packet = <-r.queue: + // continue + } + + out, err := r.respond(packet.Payload) + + if err == errNotOurName { + if r.forwarder != nil { + err = r.forwarder.forward(packet) + if err == nil { + // forward will send response into r.responses, nothing to do. + continue + } + } else { + err = errNotForwarding + } } - packet.Payload, err = r.respond(packet.Payload) if err != nil { select { + case <-r.closed: + return case r.errors <- err: // continue - case <-r.closed: - return } } else { + packet.Payload = out select { - case r.responses <- packet: - // continue case <-r.closed: return + case r.responses <- packet: + // continue } } } } -// queryServer obtains a DNS response by querying the given server. -func (r *Resolver) queryServer(ctx context.Context, server string, query []byte) ([]byte, error) { - conn, err := r.dialer.DialContext(ctx, "udp", server) - if err != nil { - return nil, err - } - defer conn.Close() - - // Interrupt the current operation when the context is cancelled. - go func() { - <-ctx.Done() - conn.SetDeadline(time.Unix(1, 0)) - }() - - _, err = conn.Write(query) - if err != nil { - return nil, err - } - - out := make([]byte, maxResponseSize) - n, err := conn.Read(out) - if err != nil { - return nil, err - } - - return out[:n], nil -} - -// delegate forwards the query to all upstream nameservers and returns the first response. -func (r *Resolver) delegate(query []byte) ([]byte, error) { - r.mu.Lock() - nameservers := r.nameservers - r.mu.Unlock() - - if len(nameservers) == 0 { - return nil, errNoNameservers - } - - ctx, cancel := context.WithTimeout(context.Background(), delegateTimeout) - defer cancel() - - // Common case, don't spawn goroutines. - if len(nameservers) == 1 { - return r.queryServer(ctx, nameservers[0], query) - } - - datach := make(chan []byte) - for _, server := range nameservers { - go func(s string) { - resp, err := r.queryServer(ctx, s, query) - // Only print errors not due to cancelation after first response. - if err != nil && ctx.Err() != context.Canceled { - r.logf("querying %s: %v", s, err) - } - - datach <- resp - }(server) - } - - var response []byte - for range nameservers { - cur := <-datach - if cur != nil && response == nil { - // Received first successful response - response = cur - cancel() - } - } - - if response == nil { - return nil, errAllFailed - } - return response, nil -} - type response struct { Header dns.Header Question dns.Question @@ -517,8 +465,6 @@ func rdnsNameToIPv6(name []byte) (ip netaddr.IP, ok bool) { func (r *Resolver) respondReverse(query []byte, resp *response) ([]byte, error) { name := resp.Question.Name.Data[:resp.Question.Name.Length] - shouldDelegate := false - var ip netaddr.IP var ok bool var err error @@ -528,7 +474,7 @@ func (r *Resolver) respondReverse(query []byte, resp *response) ([]byte, error) case bytes.HasSuffix(name, rdnsv6Suffix): ip, ok = rdnsNameToIPv6(name) default: - shouldDelegate = true + return nil, errNotOurName } // It is more likely that we failed in parsing the name than that it is actually malformed. @@ -536,25 +482,15 @@ func (r *Resolver) respondReverse(query []byte, resp *response) ([]byte, error) if !ok { // Without this conversion, escape analysis rules that resp escapes. r.logf("parsing rdns: malformed name: %s", resp.Question.Name.String()) - shouldDelegate = true + return nil, errNotOurName } - if !shouldDelegate { - resp.Name, resp.Header.RCode, err = r.ResolveReverse(ip) - if err != nil { - r.logf("resolving rdns: %v", ip, err) - } - shouldDelegate = (resp.Header.RCode == dns.RCodeNameError) + resp.Name, resp.Header.RCode, err = r.ResolveReverse(ip) + if err != nil { + r.logf("resolving rdns: %v", ip, err) } - - if shouldDelegate { - out, err := r.delegate(query) - if err != nil { - r.logf("delegating rdns: %v", err) - resp.Header.RCode = dns.RCodeServerFailure - return marshalResponse(resp) - } - return out, nil + if resp.Header.RCode == dns.RCodeNameError { + return nil, errNotOurName } return marshalResponse(resp) @@ -586,13 +522,7 @@ func (r *Resolver) respond(query []byte) ([]byte, error) { // We do this on bytes because Name.String() allocates. rawName := resp.Question.Name.Data[:resp.Question.Name.Length] if !bytes.HasSuffix(rawName, r.rootDomain) { - out, err := r.delegate(query) - if err != nil { - r.logf("delegating: %v", err) - resp.Header.RCode = dns.RCodeServerFailure - return marshalResponse(resp) - } - return out, nil + return nil, errNotOurName } switch resp.Question.Type { diff --git a/wgengine/tsdns/tsdns_test.go b/wgengine/tsdns/tsdns_test.go index 6844f78a4..72e4031bd 100644 --- a/wgengine/tsdns/tsdns_test.go +++ b/wgengine/tsdns/tsdns_test.go @@ -7,11 +7,13 @@ package tsdns import ( "bytes" "errors" + "net" "sync" "testing" dns "golang.org/x/net/dns/dnsmessage" "inet.af/netaddr" + "tailscale.com/tstest" ) var testipv4 = netaddr.IPv4(1, 2, 3, 4) @@ -178,9 +180,13 @@ func TestRDNSNameToIPv6(t *testing.T) { } func TestResolve(t *testing.T) { - r := NewResolver(t.Logf, "ipn.dev") + r := NewResolver(ResolverConfig{Logf: t.Logf, RootDomain: "ipn.dev.", Forward: false}) r.SetMap(dnsMap) - r.Start() + + if err := r.Start(); err != nil { + t.Fatalf("start: %v", err) + } + defer r.Close() tests := []struct { name string @@ -212,9 +218,13 @@ func TestResolve(t *testing.T) { } func TestResolveReverse(t *testing.T) { - r := NewResolver(t.Logf, "ipn.dev") + r := NewResolver(ResolverConfig{Logf: t.Logf, RootDomain: "ipn.dev.", Forward: false}) r.SetMap(dnsMap) - r.Start() + + if err := r.Start(); err != nil { + t.Fatalf("start: %v", err) + } + defer r.Close() tests := []struct { name string @@ -244,6 +254,9 @@ func TestResolveReverse(t *testing.T) { } func TestDelegate(t *testing.T) { + rc := tstest.NewResourceCheck() + defer rc.Assert(t) + dnsHandleFunc("test.site.", resolveToIP(testipv4, testipv6)) dnsHandleFunc("nxdomain.site.", resolveToNXDOMAIN) @@ -271,12 +284,16 @@ func TestDelegate(t *testing.T) { return } - r := NewResolver(t.Logf, "ipn.dev") - r.SetNameservers([]string{ - v4server.PacketConn.LocalAddr().String(), - v6server.PacketConn.LocalAddr().String(), + r := NewResolver(ResolverConfig{Logf: t.Logf, RootDomain: "ipn.dev.", Forward: true}) + r.SetUpstreams([]net.Addr{ + v4server.PacketConn.LocalAddr(), + v6server.PacketConn.LocalAddr(), }) - r.Start() + + if err := r.Start(); err != nil { + t.Fatalf("start: %v", err) + } + defer r.Close() tests := []struct { name string @@ -311,9 +328,92 @@ func TestDelegate(t *testing.T) { } } +func TestDelegateCollision(t *testing.T) { + dnsHandleFunc("test.site.", resolveToIP(testipv4, testipv6)) + + server, errch := serveDNS("127.0.0.1:0") + defer func() { + if err := <-errch; err != nil { + t.Errorf("server error: %v", err) + } + }() + + if server == nil { + return + } + defer server.Shutdown() + + r := NewResolver(ResolverConfig{Logf: t.Logf, RootDomain: "ipn.dev.", Forward: true}) + r.SetUpstreams([]net.Addr{server.PacketConn.LocalAddr()}) + + if err := r.Start(); err != nil { + t.Fatalf("start: %v", err) + } + defer r.Close() + + packets := []struct { + qname string + qtype dns.Type + addr netaddr.IPPort + }{ + {"test.site.", dns.TypeA, netaddr.IPPort{IP: netaddr.IPv4(1, 1, 1, 1), Port: 1001}}, + {"test.site.", dns.TypeAAAA, netaddr.IPPort{IP: netaddr.IPv4(1, 1, 1, 1), Port: 1002}}, + } + + // packets will have the same dns txid. + for _, p := range packets { + payload := dnspacket(p.qname, p.qtype) + req := Packet{Payload: payload, Addr: p.addr} + err := r.EnqueueRequest(req) + if err != nil { + t.Error(err) + } + } + + // Despite the txid collision, the answer(s) should still match the query. + resp, err := r.NextResponse() + if err != nil { + t.Error(err) + } + + var p dns.Parser + _, err = p.Start(resp.Payload) + if err != nil { + t.Error(err) + } + err = p.SkipAllQuestions() + if err != nil { + t.Error(err) + } + ans, err := p.AllAnswers() + if err != nil { + t.Error(err) + } + + var wantType dns.Type + switch ans[0].Body.(type) { + case *dns.AResource: + wantType = dns.TypeA + case *dns.AAAAResource: + wantType = dns.TypeAAAA + default: + t.Errorf("unexpected answer type: %T", ans[0].Body) + } + + for _, p := range packets { + if p.qtype == wantType && p.addr != resp.Addr { + t.Errorf("addr = %v; want %v", resp.Addr, p.addr) + } + } +} + func TestConcurrentSetMap(t *testing.T) { - r := NewResolver(t.Logf, "ipn.dev.") - r.Start() + r := NewResolver(ResolverConfig{Logf: t.Logf, RootDomain: "ipn.dev.", Forward: false}) + + if err := r.Start(); err != nil { + t.Fatalf("start: %v", err) + } + defer r.Close() // This is purely to ensure that Resolve does not race with SetMap. var wg sync.WaitGroup @@ -329,17 +429,36 @@ func TestConcurrentSetMap(t *testing.T) { wg.Wait() } -func TestConcurrentSetNameservers(t *testing.T) { - r := NewResolver(t.Logf, "ipn.dev.") - r.Start() - packet := dnspacket("google.com.", dns.TypeA) +func TestConcurrentSetUpstreams(t *testing.T) { + dnsHandleFunc("test.site.", resolveToIP(testipv4, testipv6)) - // This is purely to ensure that delegation does not race with SetNameservers. + server, errch := serveDNS("127.0.0.1:0") + defer func() { + if err := <-errch; err != nil { + t.Errorf("server error: %v", err) + } + }() + + if server == nil { + return + } + defer server.Shutdown() + + r := NewResolver(ResolverConfig{Logf: t.Logf, RootDomain: "ipn.dev.", Forward: true}) + r.SetMap(dnsMap) + + if err := r.Start(); err != nil { + t.Fatalf("start: %v", err) + } + defer r.Close() + + packet := dnspacket("test.site.", dns.TypeA) + // This is purely to ensure that delegation does not race with SetUpstreams. var wg sync.WaitGroup wg.Add(2) go func() { defer wg.Done() - r.SetNameservers([]string{"9.9.9.9:53"}) + r.SetUpstreams([]net.Addr{server.PacketConn.LocalAddr()}) }() go func() { defer wg.Done() @@ -415,9 +534,13 @@ var nxdomainResponse = []byte{ } func TestFull(t *testing.T) { - r := NewResolver(t.Logf, "ipn.dev.") + r := NewResolver(ResolverConfig{Logf: t.Logf, RootDomain: "ipn.dev.", Forward: false}) r.SetMap(dnsMap) - r.Start() + + if err := r.Start(); err != nil { + t.Fatalf("start: %v", err) + } + defer r.Close() // One full packet and one error packet tests := []struct { @@ -445,9 +568,13 @@ func TestFull(t *testing.T) { } func TestAllocs(t *testing.T) { - r := NewResolver(t.Logf, "ipn.dev.") + r := NewResolver(ResolverConfig{Logf: t.Logf, RootDomain: "ipn.dev.", Forward: false}) r.SetMap(dnsMap) - r.Start() + + if err := r.Start(); err != nil { + t.Fatalf("start: %v", err) + } + defer r.Close() // It is seemingly pointless to test allocs in the delegate path, // as dialer.Dial -> Read -> Write alone comprise 12 allocs. @@ -473,9 +600,28 @@ func TestAllocs(t *testing.T) { } func BenchmarkFull(b *testing.B) { - r := NewResolver(b.Logf, "ipn.dev.") + dnsHandleFunc("test.site.", resolveToIP(testipv4, testipv6)) + + server, errch := serveDNS("127.0.0.1:0") + defer func() { + if err := <-errch; err != nil { + b.Errorf("server error: %v", err) + } + }() + + if server == nil { + return + } + defer server.Shutdown() + + r := NewResolver(ResolverConfig{Logf: b.Logf, RootDomain: "ipn.dev.", Forward: true}) r.SetMap(dnsMap) - r.Start() + r.SetUpstreams([]net.Addr{server.PacketConn.LocalAddr()}) + + if err := r.Start(); err != nil { + b.Fatalf("start: %v", err) + } + defer r.Close() tests := []struct { name string @@ -483,7 +629,7 @@ func BenchmarkFull(b *testing.B) { }{ {"forward", dnspacket("test1.ipn.dev.", dns.TypeA)}, {"reverse", dnspacket("4.3.2.1.in-addr.arpa.", dns.TypePTR)}, - {"nxdomain", dnspacket("test3.ipn.dev.", dns.TypeA)}, + {"delegated", dnspacket("test.site.", dns.TypeA)}, } for _, tt := range tests { diff --git a/wgengine/userspace.go b/wgengine/userspace.go index b775b7cb1..a11fa90cd 100644 --- a/wgengine/userspace.go +++ b/wgengine/userspace.go @@ -201,13 +201,18 @@ func NewUserspaceEngineAdvanced(conf EngineConfig) (Engine, error) { func newUserspaceEngineAdvanced(conf EngineConfig) (_ Engine, reterr error) { logf := conf.Logf + rconf := tsdns.ResolverConfig{ + Logf: conf.Logf, + RootDomain: magicDNSDomain, + Forward: true, + } e := &userspaceEngine{ timeNow: time.Now, logf: logf, reqCh: make(chan struct{}, 1), waitCh: make(chan struct{}), tundev: tstun.WrapTUN(logf, conf.TUN), - resolver: tsdns.NewResolver(logf, magicDNSDomain), + resolver: tsdns.NewResolver(rconf), pingers: make(map[wgcfg.Key]*pinger), } e.localAddrs.Store(map[packet.IP]bool{}) @@ -849,11 +854,16 @@ func (e *userspaceEngine) Reconfig(cfg *wgcfg.Config, routerCfg *router.Config) if routerChanged { if routerCfg.DNS.Proxied { ips := routerCfg.DNS.Nameservers - nameservers := make([]string, len(ips)) + upstreams := make([]net.Addr, len(ips)) for i, ip := range ips { - nameservers[i] = net.JoinHostPort(ip.String(), "53") + stdIP := ip.IPAddr() + upstreams[i] = &net.UDPAddr{ + IP: stdIP.IP, + Port: 53, + Zone: stdIP.Zone, + } } - e.resolver.SetNameservers(nameservers) + e.resolver.SetUpstreams(upstreams) routerCfg.DNS.Nameservers = []netaddr.IP{tsaddr.TailscaleServiceIP()} } e.logf("wgengine: Reconfig: configuring router")