net/dns, wgengine: implement DNS over TCP (#4598)

* net/dns, wgengine: implement DNS over TCP

Signed-off-by: Tom DNetto <tom@tailscale.com>

* wgengine/netstack: intercept only relevant port/protocols to quad-100

Signed-off-by: Tom DNetto <tom@tailscale.com>
This commit is contained in:
Tom 2022-05-05 16:42:45 -07:00 committed by GitHub
parent c4f06ef7be
commit d1d6ab068e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 273 additions and 0 deletions

View File

@ -7,7 +7,9 @@
import (
"bufio"
"context"
"encoding/binary"
"errors"
"io"
"net"
"runtime"
"sync/atomic"
@ -346,6 +348,123 @@ func (m *Manager) Query(ctx context.Context, bs []byte, from netaddr.IPPort) ([]
return m.resolver.Query(ctx, bs, from)
}
const (
// RFC 7766 6.2 recommends connection reuse & request pipelining
// be undertaken, and the connection be closed by the server
// using an idle timeout on the order of seconds.
idleTimeoutTCP = 45 * time.Second
// The RFCs don't specify the max size of a TCP-based DNS query,
// but we want to keep this reasonable. Given payloads are typically
// much larger and all known client send a single query, I've arbitrarily
// chosen 2k.
maxReqSizeTCP = 2048
)
// dnsTCPSession services DNS requests sent over TCP.
type dnsTCPSession struct {
m *Manager
conn net.Conn
srcAddr netaddr.IPPort
readClosing chan struct{}
responses chan []byte // DNS replies pending writing
ctx context.Context
closeCtx context.CancelFunc
}
func (s *dnsTCPSession) handleWrites() {
defer s.conn.Close()
defer close(s.responses)
defer s.closeCtx()
for {
select {
case <-s.readClosing:
return // connection closed or timeout, teardown time
case resp := <-s.responses:
s.conn.SetWriteDeadline(time.Now().Add(idleTimeoutTCP))
if err := binary.Write(s.conn, binary.BigEndian, uint16(len(resp))); err != nil {
s.m.logf("tcp write (len): %v", err)
return
}
if _, err := s.conn.Write(resp); err != nil {
s.m.logf("tcp write (response): %v", err)
return
}
}
}
}
func (s *dnsTCPSession) handleQuery(q []byte) {
resp, err := s.m.Query(s.ctx, q, s.srcAddr)
if err != nil {
s.m.logf("tcp query: %v", err)
return
}
select {
case <-s.ctx.Done():
case s.responses <- resp:
}
}
func (s *dnsTCPSession) handleReads() {
defer close(s.readClosing)
for {
select {
case <-s.ctx.Done():
return
default:
s.conn.SetReadDeadline(time.Now().Add(idleTimeoutTCP))
var reqLen uint16
if err := binary.Read(s.conn, binary.BigEndian, &reqLen); err != nil {
if err == io.EOF || err == io.ErrClosedPipe {
return // connection closed nominally, we gucci
}
s.m.logf("tcp read (len): %v", err)
return
}
if int(reqLen) > maxReqSizeTCP {
s.m.logf("tcp request too large (%d > %d)", reqLen, maxReqSizeTCP)
return
}
buf := make([]byte, int(reqLen))
if _, err := io.ReadFull(s.conn, buf); err != nil {
s.m.logf("tcp read (payload): %v", err)
return
}
select {
case <-s.ctx.Done():
return
default:
go s.handleQuery(buf)
}
}
}
}
// HandleTCPConn implements magicDNS over TCP, taking a connection and
// servicing DNS requests sent down it.
func (m *Manager) HandleTCPConn(conn net.Conn, srcAddr netaddr.IPPort) {
s := dnsTCPSession{
m: m,
conn: conn,
srcAddr: srcAddr,
responses: make(chan []byte),
readClosing: make(chan struct{}),
}
s.ctx, s.closeCtx = context.WithCancel(context.Background())
go s.handleReads()
s.handleWrites()
}
func (m *Manager) Down() error {
m.ctxCancel()
if err := m.os.Close(); err != nil {

136
net/dns/manager_tcp_test.go Normal file
View File

@ -0,0 +1,136 @@
// Copyright (c) 2022 Tailscale Inc & AUTHORS All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package dns
import (
"encoding/binary"
"io"
"net"
"testing"
"github.com/google/go-cmp/cmp"
dns "golang.org/x/net/dns/dnsmessage"
"inet.af/netaddr"
"tailscale.com/net/tsdial"
"tailscale.com/util/dnsname"
)
func mkDNSRequest(domain dnsname.FQDN, tp dns.Type) []byte {
var dnsHeader dns.Header
question := dns.Question{
Name: dns.MustNewName(domain.WithTrailingDot()),
Type: tp,
Class: dns.ClassINET,
}
builder := dns.NewBuilder(nil, dnsHeader)
if err := builder.StartQuestions(); err != nil {
panic(err)
}
if err := builder.Question(question); err != nil {
panic(err)
}
if err := builder.StartAdditionals(); err != nil {
panic(err)
}
ednsHeader := dns.ResourceHeader{
Name: dns.MustNewName("."),
Type: dns.TypeOPT,
Class: dns.Class(4095),
}
if err := builder.OPTResource(ednsHeader, dns.OPTResource{}); err != nil {
panic(err)
}
payload, _ := builder.Finish()
return payload
}
func TestDNSOverTCP(t *testing.T) {
f := fakeOSConfigurator{
SplitDNS: true,
BaseConfig: OSConfig{
Nameservers: mustIPs("8.8.8.8"),
SearchDomains: fqdns("coffee.shop"),
},
}
m := NewManager(t.Logf, &f, nil, new(tsdial.Dialer), nil)
m.resolver.TestOnlySetHook(f.SetResolver)
m.Set(Config{
Hosts: hosts(
"dave.ts.com.", "1.2.3.4",
"bradfitz.ts.com.", "2.3.4.5"),
Routes: upstreams("ts.com", ""),
SearchDomains: fqdns("tailscale.com", "universe.tf"),
})
defer m.Down()
c, s := net.Pipe()
defer s.Close()
go m.HandleTCPConn(s, netaddr.IPPort{})
defer c.Close()
wantResults := map[dnsname.FQDN]string{
"dave.ts.com.": "1.2.3.4",
"bradfitz.ts.com.": "2.3.4.5",
}
for domain, _ := range wantResults {
b := mkDNSRequest(domain, dns.TypeA)
binary.Write(c, binary.BigEndian, uint16(len(b)))
c.Write(b)
}
results := map[dnsname.FQDN]string{}
for i := 0; i < len(wantResults); i++ {
var respLength uint16
if err := binary.Read(c, binary.BigEndian, &respLength); err != nil {
t.Fatalf("reading len: %v", err)
}
resp := make([]byte, int(respLength))
if _, err := io.ReadFull(c, resp); err != nil {
t.Fatalf("reading data: %v", err)
}
var parser dns.Parser
if _, err := parser.Start(resp); err != nil {
t.Errorf("parser.Start() failed: %v", err)
continue
}
q, err := parser.Question()
if err != nil {
t.Errorf("parser.Question(): %v", err)
continue
}
if err := parser.SkipAllQuestions(); err != nil {
t.Errorf("parser.SkipAllQuestions(): %v", err)
continue
}
ah, err := parser.AnswerHeader()
if err != nil {
t.Errorf("parser.AnswerHeader(): %v", err)
continue
}
if ah.Type != dns.TypeA {
t.Errorf("unexpected answer type: got %v, want %v", ah.Type, dns.TypeA)
continue
}
res, err := parser.AResource()
if err != nil {
t.Errorf("parser.AResource(): %v", err)
continue
}
results[dnsname.FQDN(q.Name.String())] = net.IP(res.A[:]).String()
}
c.Close()
if diff := cmp.Diff(wantResults, results); diff != "" {
t.Errorf("wrong results (-got+want)\n%s", diff)
}
}

View File

@ -373,6 +373,19 @@ func (ns *Impl) handleLocalPackets(p *packet.Parsed, t *tstun.Wrapper) filter.Re
if dst := p.Dst.IP(); dst != magicDNSIP && dst != magicDNSIPv6 {
return filter.Accept
}
// Of traffic to the service IP, we only care about UDP 53, and TCP
// on port 80 & 53.
switch p.IPProto {
case ipproto.TCP:
if port := p.Dst.Port(); port != 53 && port != 80 {
return filter.Accept
}
case ipproto.UDP:
if port := p.Dst.Port(); port != 53 {
return filter.Accept
}
}
var pn tcpip.NetworkProtocolNumber
switch p.IPVersion {
@ -758,6 +771,11 @@ func (ns *Impl) acceptTCP(r *tcp.ForwarderRequest) {
// block until the TCP handshake is complete.
c := gonet.NewTCPConn(&wq, ep)
if reqDetails.LocalPort == 53 && (dialIP == magicDNSIP || dialIP == magicDNSIPv6) {
go ns.dns.HandleTCPConn(c, netaddr.IPPortFrom(clientRemoteIP, reqDetails.RemotePort))
return
}
if ns.lb != nil {
if reqDetails.LocalPort == 22 && ns.processSSH() && ns.isLocalIP(dialIP) {
if err := ns.lb.HandleSSHConn(c); err != nil {