diff --git a/cmd/tailscale/depaware.txt b/cmd/tailscale/depaware.txt index 34ba112f9..f3d5d36d7 100644 --- a/cmd/tailscale/depaware.txt +++ b/cmd/tailscale/depaware.txt @@ -3,6 +3,7 @@ tailscale.com/cmd/tailscale dependencies: (generated by github.com/tailscale/dep W 💣 github.com/alexbrainman/sspi from github.com/alexbrainman/sspi/negotiate+ W github.com/alexbrainman/sspi/internal/common from github.com/alexbrainman/sspi/negotiate W 💣 github.com/alexbrainman/sspi/negotiate from tailscale.com/net/tshttpproxy + github.com/golang/groupcache/lru from tailscale.com/net/dnscache github.com/kballard/go-shellquote from tailscale.com/cmd/tailscale/cli L github.com/klauspost/compress/flate from nhooyr.io/websocket 💣 github.com/mitchellh/go-ps from tailscale.com/cmd/tailscale/cli+ @@ -91,7 +92,7 @@ tailscale.com/cmd/tailscale dependencies: (generated by github.com/tailscale/dep golang.org/x/crypto/nacl/secretbox from golang.org/x/crypto/nacl/box golang.org/x/crypto/poly1305 from golang.org/x/crypto/chacha20poly1305 golang.org/x/crypto/salsa20/salsa from golang.org/x/crypto/nacl/box+ - golang.org/x/net/dns/dnsmessage from net + golang.org/x/net/dns/dnsmessage from net+ golang.org/x/net/http/httpguts from net/http+ golang.org/x/net/http/httpproxy from net/http golang.org/x/net/http2/hpack from net/http diff --git a/cmd/tailscaled/depaware.txt b/cmd/tailscaled/depaware.txt index 5b72d243a..d0951910a 100644 --- a/cmd/tailscaled/depaware.txt +++ b/cmd/tailscaled/depaware.txt @@ -63,6 +63,7 @@ tailscale.com/cmd/tailscaled dependencies: (generated by github.com/tailscale/de W 💣 github.com/go-ole/go-ole from github.com/go-ole/go-ole/oleutil+ W 💣 github.com/go-ole/go-ole/oleutil from tailscale.com/wgengine/winnet L 💣 github.com/godbus/dbus/v5 from tailscale.com/net/dns + github.com/golang/groupcache/lru from tailscale.com/net/dnscache github.com/google/btree from inet.af/netstack/tcpip/header+ L github.com/insomniacslk/dhcp/dhcpv4 from tailscale.com/net/tstun L github.com/insomniacslk/dhcp/iana from github.com/insomniacslk/dhcp/dhcpv4 diff --git a/go.mod b/go.mod index fe623756a..1632f98e1 100644 --- a/go.mod +++ b/go.mod @@ -19,6 +19,7 @@ require ( github.com/gliderlabs/ssh v0.3.3 github.com/go-ole/go-ole v1.2.6 github.com/godbus/dbus/v5 v5.0.6 + github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da github.com/google/go-cmp v0.5.6 github.com/google/uuid v1.3.0 github.com/goreleaser/nfpm v1.10.3 diff --git a/go.sum b/go.sum index 34e7f9ee6..1b920c692 100644 --- a/go.sum +++ b/go.sum @@ -387,6 +387,8 @@ github.com/golang/groupcache v0.0.0-20190129154638-5b532d6fd5ef/go.mod h1:cIg4er github.com/golang/groupcache v0.0.0-20190702054246-869f871628b6/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc= github.com/golang/groupcache v0.0.0-20191227052852-215e87163ea7/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc= github.com/golang/groupcache v0.0.0-20200121045136-8c9f03a8e57e/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc= +github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da h1:oI5xCqsCo564l8iNU+DwB5epxmsaqB+rhGL0m5jtYqE= +github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc= github.com/golang/mock v1.1.1/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A= github.com/golang/mock v1.1.1/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A= github.com/golang/mock v1.2.0/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A= diff --git a/net/dnscache/messagecache.go b/net/dnscache/messagecache.go new file mode 100644 index 000000000..cfdfc8a0f --- /dev/null +++ b/net/dnscache/messagecache.go @@ -0,0 +1,314 @@ +// Copyright (c) 2021 Tailscale Inc & AUTHORS All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package dnscache + +import ( + "encoding/binary" + "errors" + "fmt" + "io" + "sync" + "time" + + "github.com/golang/groupcache/lru" + "golang.org/x/net/dns/dnsmessage" +) + +// MessageCache is a cache that works at the DNS message layer, +// with its cache keyed on a DNS wire-level question, and capable +// of replying to DNS messages. +// +// Its zero value is ready for use with a default cache size. +// Use SetMaxCacheSize to specify the cache size. +// +// It's safe for concurrent use. +type MessageCache struct { + // Clock is a clock, for testing. + // If nil, time.Now is used. + Clock func() time.Time + + mu sync.Mutex + cacheSizeSet int // 0 means default + cache lru.Cache // msgQ => *msgCacheValue +} + +func (c *MessageCache) now() time.Time { + if c.Clock != nil { + return c.Clock() + } + return time.Now() +} + +// SetMaxCacheSize sets the maximum number of DNS cache entries that +// can be stored. +func (c *MessageCache) SetMaxCacheSize(n int) { + c.mu.Lock() + defer c.mu.Unlock() + c.cacheSizeSet = n + c.pruneLocked() +} + +// Flush clears the cache. +func (c *MessageCache) Flush() { + c.mu.Lock() + defer c.mu.Unlock() + c.cache.Clear() +} + +// pruneLocked prunes down the cache size to the configured (or +// default) max size. +func (c *MessageCache) pruneLocked() { + max := c.cacheSizeSet + if max == 0 { + max = 500 + } + for c.cache.Len() > max { + c.cache.RemoveOldest() + } +} + +// msgQ is the MessageCache cache key. +// +// It's basically a golang.org/x/net/dns/dnsmessage#Question but the +// Class is omitted (we only cache ClassINET) and we store a Go string +// instead of a 256 byte dnsmessage.Name array. +type msgQ struct { + Name string + Type dnsmessage.Type // A, AAAA, MX, etc +} + +// A *msgCacheValue is the cached value for a msgQ (question) key. +// +// Despite using pointers for storage and methods, the value is +// immutable once placed in the cache. +type msgCacheValue struct { + Expires time.Time + + // Answers are the minimum data to reconstruct a DNS response + // message. TTLs are added later when converting to a + // dnsmessage.Resource. + Answers []msgResource +} + +type msgResource struct { + Name string + Type dnsmessage.Type // dnsmessage.UnknownResource.Type + Data []byte // dnsmessage.UnknownResource.Data +} + +// ErrCacheMiss is a sentinel error returned by MessageCache.ReplyFromCache +// when the request can not be satisified from cache. +var ErrCacheMiss = errors.New("cache miss") + +var parserPool = &sync.Pool{ + New: func() interface{} { return new(dnsmessage.Parser) }, +} + +// ReplyFromCache writes a DNS reply to w for the provided DNS query message, +// which must begin with the two ID bytes of a DNS message. +// +// If there's a cache miss, the message is invalid or unexpected, +// ErrCacheMiss is returned. On cache hit, either nil or an error from +// a w.Write call is returned. +func (c *MessageCache) ReplyFromCache(w io.Writer, dnsQueryMessage []byte) error { + cacheKey, txID, ok := getDNSQueryCacheKey(dnsQueryMessage) + if !ok { + return ErrCacheMiss + } + now := c.now() + + c.mu.Lock() + cacheEntI, _ := c.cache.Get(cacheKey) + v, ok := cacheEntI.(*msgCacheValue) + if ok && now.After(v.Expires) { + c.cache.Remove(cacheKey) + ok = false + } + c.mu.Unlock() + + if !ok { + return ErrCacheMiss + } + + ttl := uint32(v.Expires.Sub(now).Seconds()) + + packedRes, err := packDNSResponse(cacheKey, txID, ttl, v.Answers) + if err != nil { + return ErrCacheMiss + } + _, err = w.Write(packedRes) + return err +} + +var ( + errNotCacheable = errors.New("question not cacheable") +) + +// AddCacheEntry adds a cache entry to the cache. +// It returns an error if the entry could not be cached. +func (c *MessageCache) AddCacheEntry(qPacket, res []byte) error { + cacheKey, qID, ok := getDNSQueryCacheKey(qPacket) + if !ok { + return errNotCacheable + } + now := c.now() + v := &msgCacheValue{} + + p := parserPool.Get().(*dnsmessage.Parser) + defer parserPool.Put(p) + + resh, err := p.Start(res) + if err != nil { + return fmt.Errorf("reading header in response: %w", err) + } + if resh.ID != qID { + return fmt.Errorf("response ID doesn't match query ID") + } + q, err := p.Question() + if err != nil { + return fmt.Errorf("reading 1st question in response: %w", err) + } + if _, err := p.Question(); err != dnsmessage.ErrSectionDone { + if err == nil { + return errors.New("unexpected 2nd question in response") + } + return fmt.Errorf("after reading 1st question in response: %w", err) + } + if resName := asciiLowerName(q.Name).String(); resName != cacheKey.Name { + return fmt.Errorf("response question name %q != question name %q", resName, cacheKey.Name) + } + for { + rh, err := p.AnswerHeader() + if err == dnsmessage.ErrSectionDone { + break + } + if err != nil { + return fmt.Errorf("reading answer: %w", err) + } + res, err := p.UnknownResource() + if err != nil { + return fmt.Errorf("reading resource: %w", err) + } + if rh.Class != dnsmessage.ClassINET { + continue + } + + // Set the cache entry's expiration to the soonest + // we've seen. (They should all be the same, though) + expires := now.Add(time.Duration(rh.TTL) * time.Second) + if v.Expires.IsZero() || expires.Before(v.Expires) { + v.Expires = expires + } + v.Answers = append(v.Answers, msgResource{ + Name: rh.Name.String(), + Type: rh.Type, + Data: res.Data, // doesn't alias; a copy from dnsmessage.unpackUnknownResource + }) + } + c.addCacheValue(cacheKey, v) + return nil +} + +func (c *MessageCache) addCacheValue(cacheKey msgQ, v *msgCacheValue) { + c.mu.Lock() + defer c.mu.Unlock() + c.cache.Add(cacheKey, v) + c.pruneLocked() +} + +func getDNSQueryCacheKey(msg []byte) (cacheKey msgQ, txID uint16, ok bool) { + p := parserPool.Get().(*dnsmessage.Parser) + defer parserPool.Put(p) + h, err := p.Start(msg) + const dnsHeaderSize = 12 + if err != nil || h.OpCode != 0 || h.Response || h.Truncated || + len(msg) < dnsHeaderSize { // p.Start checks this anyway, but to be explicit for slicing below + return cacheKey, 0, false + } + var ( + numQ = binary.BigEndian.Uint16(msg[4:6]) + numAns = binary.BigEndian.Uint16(msg[6:8]) + numAuth = binary.BigEndian.Uint16(msg[8:10]) + numAddn = binary.BigEndian.Uint16(msg[10:12]) + ) + _ = numAddn // ignore this for now; do client OSes send EDNS additional? assume so, ignore. + if !(numQ == 1 && numAns == 0 && numAuth == 0) { + // Something weird. We don't want to deal with it. + return cacheKey, 0, false + } + q, err := p.Question() + if err != nil { + // Already verified numQ == 1 so shouldn't happen, but: + return cacheKey, 0, false + } + if q.Class != dnsmessage.ClassINET { + // We only cache the Internet class. + return cacheKey, 0, false + } + return msgQ{Name: asciiLowerName(q.Name).String(), Type: q.Type}, h.ID, true +} + +func asciiLowerName(n dnsmessage.Name) dnsmessage.Name { + nb := n.Data[:] + if int(n.Length) < len(n.Data) { + nb = nb[:n.Length] + } + for i, b := range nb { + if 'A' <= b && b <= 'Z' { + n.Data[i] += 0x20 + } + } + return n +} + +// packDNSResponse builds a DNS response for the given question and +// transaction ID. The response resource records will have have the +// same provided TTL. +func packDNSResponse(q msgQ, txID uint16, ttl uint32, answers []msgResource) ([]byte, error) { + var baseMem []byte // TODO: guess a max size based on looping over answers? + b := dnsmessage.NewBuilder(baseMem, dnsmessage.Header{ + ID: txID, + Response: true, + OpCode: 0, + Authoritative: false, + Truncated: false, + RCode: dnsmessage.RCodeSuccess, + }) + name, err := dnsmessage.NewName(q.Name) + if err != nil { + return nil, err + } + if err := b.StartQuestions(); err != nil { + return nil, err + } + if err := b.Question(dnsmessage.Question{ + Name: name, + Type: q.Type, + Class: dnsmessage.ClassINET, + }); err != nil { + return nil, err + } + if err := b.StartAnswers(); err != nil { + return nil, err + } + for _, r := range answers { + name, err := dnsmessage.NewName(r.Name) + if err != nil { + return nil, err + } + if err := b.UnknownResource(dnsmessage.ResourceHeader{ + Name: name, + Type: r.Type, + Class: dnsmessage.ClassINET, + TTL: ttl, + }, dnsmessage.UnknownResource{ + Type: r.Type, + Data: r.Data, + }); err != nil { + return nil, err + } + } + return b.Finish() +} diff --git a/net/dnscache/messagecache_test.go b/net/dnscache/messagecache_test.go new file mode 100644 index 000000000..be56aef6e --- /dev/null +++ b/net/dnscache/messagecache_test.go @@ -0,0 +1,292 @@ +// Copyright (c) 2021 Tailscale Inc & AUTHORS All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package dnscache + +import ( + "bytes" + "context" + "errors" + "fmt" + "net" + "runtime" + "testing" + "time" + + "golang.org/x/net/dns/dnsmessage" + "tailscale.com/tstest" +) + +func TestMessageCache(t *testing.T) { + clock := &tstest.Clock{ + Start: time.Date(1987, 11, 1, 0, 0, 0, 0, time.UTC), + } + mc := &MessageCache{Clock: clock.Now} + mc.SetMaxCacheSize(2) + clock.Advance(time.Second) + + var out bytes.Buffer + if err := mc.ReplyFromCache(&out, makeQ(1, "foo.com.")); err != ErrCacheMiss { + t.Fatalf("unexpected error: %v", err) + } + + if err := mc.AddCacheEntry( + makeQ(2, "foo.com."), + makeRes(2, "FOO.COM.", ttlOpt(10), + &dnsmessage.AResource{A: [4]byte{127, 0, 0, 1}}, + &dnsmessage.AResource{A: [4]byte{127, 0, 0, 2}})); err != nil { + t.Fatal(err) + } + + // Expect cache hit, with 10 seconds remaining. + out.Reset() + if err := mc.ReplyFromCache(&out, makeQ(3, "foo.com.")); err != nil { + t.Fatalf("expected cache hit; got: %v", err) + } + if p := mustParseResponse(t, out.Bytes()); p.TxID != 3 { + t.Errorf("TxID = %v; want %v", p.TxID, 3) + } else if p.TTL != 10 { + t.Errorf("TTL = %v; want 10", p.TTL) + } + + // One second elapses, expect a cache hit, with 9 seconds + // remaining. + clock.Advance(time.Second) + out.Reset() + if err := mc.ReplyFromCache(&out, makeQ(4, "foo.com.")); err != nil { + t.Fatalf("expected cache hit; got: %v", err) + } + if p := mustParseResponse(t, out.Bytes()); p.TxID != 4 { + t.Errorf("TxID = %v; want %v", p.TxID, 4) + } else if p.TTL != 9 { + t.Errorf("TTL = %v; want 9", p.TTL) + } + + // Expect cache miss on MX record. + if err := mc.ReplyFromCache(&out, makeQ(4, "foo.com.", dnsmessage.TypeMX)); err != ErrCacheMiss { + t.Fatalf("expected cache miss on MX; got: %v", err) + } + // Expect cache miss on CHAOS class. + if err := mc.ReplyFromCache(&out, makeQ(4, "foo.com.", dnsmessage.ClassCHAOS)); err != ErrCacheMiss { + t.Fatalf("expected cache miss on CHAOS; got: %v", err) + } + + // Ten seconds elapses; expect a cache miss. + clock.Advance(10 * time.Second) + if err := mc.ReplyFromCache(&out, makeQ(5, "foo.com.")); err != ErrCacheMiss { + t.Fatalf("expected cache miss, got: %v", err) + } +} + +type parsedMeta struct { + TxID uint16 + TTL uint32 +} + +func mustParseResponse(t testing.TB, r []byte) (ret parsedMeta) { + t.Helper() + var p dnsmessage.Parser + h, err := p.Start(r) + if err != nil { + t.Fatal(err) + } + ret.TxID = h.ID + qq, err := p.AllQuestions() + if err != nil { + t.Fatalf("AllQuestions: %v", err) + } + if len(qq) != 1 { + t.Fatalf("num questions = %v; want 1", len(qq)) + } + aa, err := p.AllAnswers() + if err != nil { + t.Fatalf("AllAnswers: %v", err) + } + for _, r := range aa { + if ret.TTL == 0 { + ret.TTL = r.Header.TTL + } + if ret.TTL != r.Header.TTL { + t.Fatal("mixed TTLs") + } + } + return ret +} + +type responseOpt bool + +type ttlOpt uint32 + +func makeQ(txID uint16, name string, opt ...interface{}) []byte { + opt = append(opt, responseOpt(false)) + return makeDNSPkt(txID, name, opt...) +} + +func makeRes(txID uint16, name string, opt ...interface{}) []byte { + opt = append(opt, responseOpt(true)) + return makeDNSPkt(txID, name, opt...) +} + +func makeDNSPkt(txID uint16, name string, opt ...interface{}) []byte { + typ := dnsmessage.TypeA + class := dnsmessage.ClassINET + var response bool + var answers []dnsmessage.ResourceBody + var ttl uint32 = 1 // one second by default + for _, o := range opt { + switch o := o.(type) { + case dnsmessage.Type: + typ = o + case dnsmessage.Class: + class = o + case responseOpt: + response = bool(o) + case dnsmessage.ResourceBody: + answers = append(answers, o) + case ttlOpt: + ttl = uint32(o) + default: + panic(fmt.Sprintf("unknown opt type %T", o)) + } + } + qname := dnsmessage.MustNewName(name) + msg := dnsmessage.Message{ + Header: dnsmessage.Header{ID: txID, Response: response}, + Questions: []dnsmessage.Question{ + { + Name: qname, + Type: typ, + Class: class, + }, + }, + } + for _, rb := range answers { + msg.Answers = append(msg.Answers, dnsmessage.Resource{ + Header: dnsmessage.ResourceHeader{ + Name: qname, + Type: typ, + Class: class, + TTL: ttl, + }, + Body: rb, + }) + } + buf, err := msg.Pack() + if err != nil { + panic(err) + } + return buf +} + +func TestASCIILowerName(t *testing.T) { + n := asciiLowerName(dnsmessage.MustNewName("Foo.COM.")) + if got, want := n.String(), "foo.com."; got != want { + t.Errorf("got = %q; want %q", got, want) + } +} + +func TestGetDNSQueryCacheKey(t *testing.T) { + tests := []struct { + name string + pkt []byte + want msgQ + txID uint16 + anyTX bool + }{ + { + name: "empty", + }, + { + name: "a", + pkt: makeQ(123, "foo.com."), + want: msgQ{"foo.com.", dnsmessage.TypeA}, + txID: 123, + }, + { + name: "aaaa", + pkt: makeQ(6, "foo.com.", dnsmessage.TypeAAAA), + want: msgQ{"foo.com.", dnsmessage.TypeAAAA}, + txID: 6, + }, + { + name: "normalize_case", + pkt: makeQ(123, "FoO.CoM."), + want: msgQ{"foo.com.", dnsmessage.TypeA}, + txID: 123, + }, + { + name: "ignore_response", + pkt: makeRes(123, "foo.com."), + }, + { + name: "ignore_question_with_answers", + pkt: makeQ(2, "foo.com.", &dnsmessage.AResource{A: [4]byte{127, 0, 0, 1}}), + }, + { + name: "whatever_go_generates", // in case Go's net package grows functionality we don't handle + pkt: getGoNetPacketDNSQuery("from-go.foo."), + want: msgQ{"from-go.foo.", dnsmessage.TypeA}, + anyTX: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, gotTX, ok := getDNSQueryCacheKey(tt.pkt) + if !ok { + if tt.txID == 0 && got == (msgQ{}) { + return + } + t.Fatal("failed") + } + if got != tt.want { + t.Errorf("got %+v, want %+v", got, tt.want) + } + if gotTX != tt.txID && !tt.anyTX { + t.Errorf("got tx %v, want %v", gotTX, tt.txID) + } + }) + } +} + +func getGoNetPacketDNSQuery(name string) []byte { + if runtime.GOOS == "windows" { + // On Windows, Go's net.Resolver doesn't use the DNS client. + // See https://github.com/golang/go/issues/33097 which + // was approved but not yet implemented. + // For now just pretend it's implemented to make this test + // pass on Windows with complicated the caller. + return makeQ(123, name) + } + res := make(chan []byte, 1) + r := &net.Resolver{ + PreferGo: true, + Dial: func(ctx context.Context, network, address string) (net.Conn, error) { + return goResolverConn(res), nil + }, + } + r.LookupIP(context.Background(), "ip4", name) + return <-res +} + +type goResolverConn chan<- []byte + +func (goResolverConn) Close() error { return nil } +func (goResolverConn) LocalAddr() net.Addr { return todoAddr{} } +func (goResolverConn) RemoteAddr() net.Addr { return todoAddr{} } +func (goResolverConn) SetDeadline(t time.Time) error { return nil } +func (goResolverConn) SetReadDeadline(t time.Time) error { return nil } +func (goResolverConn) SetWriteDeadline(t time.Time) error { return nil } +func (goResolverConn) Read([]byte) (int, error) { return 0, errors.New("boom") } +func (c goResolverConn) Write(p []byte) (int, error) { + select { + case c <- p[2:]: // skip 2 byte length for TCP mode DNS query + default: + } + return 0, errors.New("boom") +} + +type todoAddr struct{} + +func (todoAddr) Network() string { return "unused" } +func (todoAddr) String() string { return "unused-todoAddr" } diff --git a/net/tsdial/dohclient.go b/net/tsdial/dohclient.go index 3ff11d4a9..352fdb4cd 100644 --- a/net/tsdial/dohclient.go +++ b/net/tsdial/dohclient.go @@ -13,15 +13,18 @@ "net" "net/http" "time" + + "tailscale.com/net/dnscache" ) // dohConn is a net.PacketConn suitable for returning from // net.Dialer.Dial to send DNS queries over PeerAPI to exit nodes' // ExitDNS DoH proxy service. type dohConn struct { - ctx context.Context - baseURL string - hc *http.Client // if nil, default is used + ctx context.Context + baseURL string + hc *http.Client // if nil, default is used + dnsCache *dnscache.MessageCache rbuf bytes.Buffer } @@ -52,6 +55,15 @@ func (c *dohConn) Read(p []byte) (n int, err error) { } func (c *dohConn) Write(packet []byte) (n int, err error) { + if c.dnsCache != nil { + err := c.dnsCache.ReplyFromCache(&c.rbuf, packet) + if err == nil { + // Cache hit. + // TODO(bradfitz): add clientmetric + return len(packet), nil + } + c.rbuf.Reset() + } req, err := http.NewRequestWithContext(c.ctx, "POST", c.baseURL, bytes.NewReader(packet)) if err != nil { return 0, err @@ -77,6 +89,9 @@ func (c *dohConn) Write(packet []byte) (n int, err error) { if err != nil { return 0, err } + if c.dnsCache != nil { + c.dnsCache.AddCacheEntry(packet, c.rbuf.Bytes()) + } return len(packet), nil } diff --git a/net/tsdial/tsdial.go b/net/tsdial/tsdial.go index 6e9d237b8..961567232 100644 --- a/net/tsdial/tsdial.go +++ b/net/tsdial/tsdial.go @@ -11,6 +11,7 @@ "fmt" "net" "net/http" + "runtime" "strings" "sync" "sync/atomic" @@ -18,6 +19,7 @@ "time" "inet.af/netaddr" + "tailscale.com/net/dnscache" "tailscale.com/net/netknob" "tailscale.com/types/netmap" "tailscale.com/wgengine/monitor" @@ -48,7 +50,8 @@ type Dialer struct { dns dnsMap tunName string // tun device name linkMon *monitor.Mon - exitDNSDoHBase string // non-empty if DoH-proxying exit node in use; base URL+path (without '?') + exitDNSDoHBase string // non-empty if DoH-proxying exit node in use; base URL+path (without '?') + dnsCache *dnscache.MessageCache // nil until first first non-empty SetExitDNSDoH } // SetTUNName sets the name of the tun device in use ("tailscale0", "utun6", @@ -76,7 +79,16 @@ func (d *Dialer) TUNName() string { func (d *Dialer) SetExitDNSDoH(doh string) { d.mu.Lock() defer d.mu.Unlock() + if d.exitDNSDoHBase == doh { + return + } d.exitDNSDoHBase = doh + if doh != "" && d.dnsCache == nil { + d.dnsCache = new(dnscache.MessageCache) + } + if d.dnsCache != nil { + d.dnsCache.Flush() + } } func (d *Dialer) SetLinkMonitor(mon *monitor.Mon) { @@ -149,12 +161,14 @@ func (d *Dialer) userDialResolve(ctx context.Context, network, addr string) (net } var r net.Resolver - if exitDNSDoH != "" { + if exitDNSDoH != "" && runtime.GOOS != "windows" { // Windows: https://github.com/golang/go/issues/33097 + r.PreferGo = true r.Dial = func(ctx context.Context, network, address string) (net.Conn, error) { return &dohConn{ - ctx: ctx, - baseURL: exitDNSDoH, - hc: d.PeerAPIHTTPClient(), + ctx: ctx, + baseURL: exitDNSDoH, + hc: d.PeerAPIHTTPClient(), + dnsCache: d.dnsCache, }, nil } }