cmd/natc: fix handling of upstream and downstream nxdomain

Ensure that the upstream is always queried, so that if upstream is going
to NXDOMAIN natc will also return NXDOMAIN rather than returning address
allocations.

At this time both IPv4 and IPv6 are still returned if upstream has a
result, regardless of upstream support - this is ~ok as we're proxying.

Rewrite the tests to be once again slightly closer to integration tests,
but they're still very rough and in need of a refactor.

Further refactors are probably needed implementation side too, as this
removed rather than added units.

Updates #15367

Signed-off-by: James Tucker <james@tailscale.com>
This commit is contained in:
James Tucker 2025-04-01 18:52:45 -07:00 committed by James Tucker
parent fb96137d79
commit 025fe72448
2 changed files with 379 additions and 252 deletions

View File

@ -26,14 +26,15 @@ import (
"go4.org/netipx" "go4.org/netipx"
"golang.org/x/net/dns/dnsmessage" "golang.org/x/net/dns/dnsmessage"
"tailscale.com/client/local" "tailscale.com/client/local"
"tailscale.com/client/tailscale/apitype"
"tailscale.com/cmd/natc/ippool" "tailscale.com/cmd/natc/ippool"
"tailscale.com/envknob" "tailscale.com/envknob"
"tailscale.com/hostinfo" "tailscale.com/hostinfo"
"tailscale.com/ipn" "tailscale.com/ipn"
"tailscale.com/net/netutil" "tailscale.com/net/netutil"
"tailscale.com/tailcfg"
"tailscale.com/tsnet" "tailscale.com/tsnet"
"tailscale.com/tsweb" "tailscale.com/tsweb"
"tailscale.com/util/mak"
"tailscale.com/util/must" "tailscale.com/util/must"
"tailscale.com/wgengine/netstack" "tailscale.com/wgengine/netstack"
) )
@ -148,14 +149,15 @@ func main() {
v6ULA := ula(uint16(*siteID)) v6ULA := ula(uint16(*siteID))
c := &connector{ c := &connector{
ts: ts, ts: ts,
lc: lc, whois: lc,
v6ULA: v6ULA, v6ULA: v6ULA,
ignoreDsts: ignoreDstTable, ignoreDsts: ignoreDstTable,
ipPool: &ippool.IPPool{V6ULA: v6ULA, IPSet: addrPool}, ipPool: &ippool.IPPool{V6ULA: v6ULA, IPSet: addrPool},
routes: routes, routes: routes,
dnsAddr: dnsAddr, dnsAddr: dnsAddr,
resolver: net.DefaultResolver,
} }
c.run(ctx) c.run(ctx, lc)
} }
func calculateAddresses(prefixes []netip.Prefix) (*netipx.IPSet, netip.Addr, *netipx.IPSet) { func calculateAddresses(prefixes []netip.Prefix) (*netipx.IPSet, netip.Addr, *netipx.IPSet) {
@ -170,12 +172,20 @@ func calculateAddresses(prefixes []netip.Prefix) (*netipx.IPSet, netip.Addr, *ne
return routesToAdvertise, dnsAddr, addrPool return routesToAdvertise, dnsAddr, addrPool
} }
type lookupNetIPer interface {
LookupNetIP(ctx context.Context, net, host string) ([]netip.Addr, error)
}
type whoiser interface {
WhoIs(ctx context.Context, remoteAddr string) (*apitype.WhoIsResponse, error)
}
type connector struct { type connector struct {
// ts is the tsnet.Server used to host the connector. // ts is the tsnet.Server used to host the connector.
ts *tsnet.Server ts *tsnet.Server
// lc is the local.Client used to interact with the tsnet.Server hosting this // whois is the local.Client used to interact with the tsnet.Server hosting this
// connector. // connector.
lc *local.Client whois whoiser
// dnsAddr is the IPv4 address to listen on for DNS requests. It is used to // dnsAddr is the IPv4 address to listen on for DNS requests. It is used to
// prevent the app connector from assigning it to a domain. // prevent the app connector from assigning it to a domain.
@ -197,7 +207,11 @@ type connector struct {
// natc behavior, which would return a dummy ip address pointing at natc). // natc behavior, which would return a dummy ip address pointing at natc).
ignoreDsts *bart.Table[bool] ignoreDsts *bart.Table[bool]
// ipPool contains the per-peer IPv4 address assignments.
ipPool *ippool.IPPool ipPool *ippool.IPPool
// resolver is used to lookup IP addresses for DNS queries.
resolver lookupNetIPer
} }
// v6ULA is the ULA prefix used by the app connector to assign IPv6 addresses. // v6ULA is the ULA prefix used by the app connector to assign IPv6 addresses.
@ -217,8 +231,8 @@ func ula(siteID uint16) netip.Prefix {
// //
// The passed in context is only used for the initial setup. The connector runs // The passed in context is only used for the initial setup. The connector runs
// forever. // forever.
func (c *connector) run(ctx context.Context) { func (c *connector) run(ctx context.Context, lc *local.Client) {
if _, err := c.lc.EditPrefs(ctx, &ipn.MaskedPrefs{ if _, err := lc.EditPrefs(ctx, &ipn.MaskedPrefs{
AdvertiseRoutesSet: true, AdvertiseRoutesSet: true,
Prefs: ipn.Prefs{ Prefs: ipn.Prefs{
AdvertiseRoutes: append(c.routes.Prefixes(), c.v6ULA), AdvertiseRoutes: append(c.routes.Prefixes(), c.v6ULA),
@ -251,26 +265,6 @@ func (c *connector) serveDNS() {
} }
} }
func lookupDestinationIP(domain string) ([]netip.Addr, error) {
netIPs, err := net.LookupIP(domain)
if err != nil {
var dnsError *net.DNSError
if errors.As(err, &dnsError) && dnsError.IsNotFound {
return nil, nil
} else {
return nil, err
}
}
var addrs []netip.Addr
for _, ip := range netIPs {
a, ok := netip.AddrFromSlice(ip)
if ok {
addrs = append(addrs, a)
}
}
return addrs, nil
}
// handleDNS handles a DNS request to the app connector. // handleDNS handles a DNS request to the app connector.
// It generates a response based on the request and the node that sent it. // It generates a response based on the request and the node that sent it.
// //
@ -285,7 +279,7 @@ func lookupDestinationIP(domain string) ([]netip.Addr, error) {
func (c *connector) handleDNS(pc net.PacketConn, buf []byte, remoteAddr *net.UDPAddr) { func (c *connector) handleDNS(pc net.PacketConn, buf []byte, remoteAddr *net.UDPAddr) {
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel() defer cancel()
who, err := c.lc.WhoIs(ctx, remoteAddr.String()) who, err := c.whois.WhoIs(ctx, remoteAddr.String())
if err != nil { if err != nil {
log.Printf("HandleDNS(remote=%s): WhoIs failed: %v\n", remoteAddr.String(), err) log.Printf("HandleDNS(remote=%s): WhoIs failed: %v\n", remoteAddr.String(), err)
return return
@ -298,49 +292,122 @@ func (c *connector) handleDNS(pc net.PacketConn, buf []byte, remoteAddr *net.UDP
return return
} }
// If there are destination ips that we don't want to route, we var resolves map[string][]netip.Addr
// have to do a dns lookup here to find the destination ip. var addrQCount int
if c.ignoreDsts != nil { for _, q := range msg.Questions {
if len(msg.Questions) > 0 { if q.Type != dnsmessage.TypeA && q.Type != dnsmessage.TypeAAAA {
q := msg.Questions[0] continue
switch q.Type { }
case dnsmessage.TypeAAAA, dnsmessage.TypeA: addrQCount++
dstAddrs, err := lookupDestinationIP(q.Name.String()) if _, ok := resolves[q.Name.String()]; !ok {
addrs, err := c.resolver.LookupNetIP(ctx, "ip", q.Name.String())
var dnsErr *net.DNSError
if errors.As(err, &dnsErr) && dnsErr.IsNotFound {
continue
}
if err != nil {
log.Printf("HandleDNS(remote=%s): lookup destination failed: %v\n", remoteAddr.String(), err)
return
}
// Note: If _any_ destination is ignored, pass through all of the resolved
// addresses as-is.
//
// This could result in some odd split-routing if there was a mix of
// ignored and non-ignored addresses, but it's currently the user
// preferred behavior.
if !c.ignoreDestination(addrs) {
addrs, err = c.ipPool.IPForDomain(who.Node.ID, q.Name.String())
if err != nil { if err != nil {
log.Printf("HandleDNS(remote=%s): lookup destination failed: %v\n", remoteAddr.String(), err) log.Printf("HandleDNS(remote=%s): lookup destination failed: %v\n", remoteAddr.String(), err)
return return
} }
if c.ignoreDestination(dstAddrs) { }
bs, err := dnsResponse(&msg, dstAddrs) mak.Set(&resolves, q.Name.String(), addrs)
// TODO (fran): treat as SERVFAIL }
if err != nil { }
log.Printf("HandleDNS(remote=%s): generate ignore response failed: %v\n", remoteAddr.String(), err)
return rcode := dnsmessage.RCodeSuccess
} if addrQCount > 0 && len(resolves) == 0 {
_, err = pc.WriteTo(bs, remoteAddr) rcode = dnsmessage.RCodeNameError
if err != nil { }
log.Printf("HandleDNS(remote=%s): write failed: %v\n", remoteAddr.String(), err)
} b := dnsmessage.NewBuilder(nil,
dnsmessage.Header{
ID: msg.Header.ID,
Response: true,
Authoritative: true,
RCode: rcode,
})
b.EnableCompression()
if err := b.StartQuestions(); err != nil {
log.Printf("HandleDNS(remote=%s): dnsmessage start questions failed: %v\n", remoteAddr.String(), err)
return
}
for _, q := range msg.Questions {
b.Question(q)
}
if err := b.StartAnswers(); err != nil {
log.Printf("HandleDNS(remote=%s): dnsmessage start answers failed: %v\n", remoteAddr.String(), err)
return
}
for _, q := range msg.Questions {
switch q.Type {
case dnsmessage.TypeSOA:
if err := b.SOAResource(
dnsmessage.ResourceHeader{Name: q.Name, Class: q.Class, TTL: 120},
dnsmessage.SOAResource{NS: q.Name, MBox: tsMBox, Serial: 2023030600,
Refresh: 120, Retry: 120, Expire: 120, MinTTL: 60},
); err != nil {
log.Printf("HandleDNS(remote=%s): dnsmessage SOA resource failed: %v\n", remoteAddr.String(), err)
return
}
case dnsmessage.TypeNS:
if err := b.NSResource(
dnsmessage.ResourceHeader{Name: q.Name, Class: q.Class, TTL: 120},
dnsmessage.NSResource{NS: tsMBox},
); err != nil {
log.Printf("HandleDNS(remote=%s): dnsmessage NS resource failed: %v\n", remoteAddr.String(), err)
return
}
case dnsmessage.TypeAAAA:
for _, addr := range resolves[q.Name.String()] {
if !addr.Is6() {
continue
}
if err := b.AAAAResource(
dnsmessage.ResourceHeader{Name: q.Name, Class: q.Class, TTL: 120},
dnsmessage.AAAAResource{AAAA: addr.As16()},
); err != nil {
log.Printf("HandleDNS(remote=%s): dnsmessage AAAA resource failed: %v\n", remoteAddr.String(), err)
return
}
}
case dnsmessage.TypeA:
for _, addr := range resolves[q.Name.String()] {
if !addr.Is4() {
continue
}
if err := b.AResource(
dnsmessage.ResourceHeader{Name: q.Name, Class: q.Class, TTL: 120},
dnsmessage.AResource{A: addr.As4()},
); err != nil {
log.Printf("HandleDNS(remote=%s): dnsmessage A resource failed: %v\n", remoteAddr.String(), err)
return return
} }
} }
} }
} }
// None of the destination IP addresses match an ignore destination prefix, do
// the natc thing.
resp, err := c.generateDNSResponse(&msg, who.Node.ID) out, err := b.Finish()
// TODO (fran): treat as SERVFAIL
if err != nil { if err != nil {
log.Printf("HandleDNS(remote=%s): connector handling failed: %v\n", remoteAddr.String(), err) log.Printf("HandleDNS(remote=%s): dnsmessage finish failed: %v\n", remoteAddr.String(), err)
return return
} }
// TODO (fran): treat as NXDOMAIN _, err = pc.WriteTo(out, remoteAddr)
if len(resp) == 0 {
return
}
// This connector handled the DNS request
_, err = pc.WriteTo(resp, remoteAddr)
if err != nil { if err != nil {
log.Printf("HandleDNS(remote=%s): write failed: %v\n", remoteAddr.String(), err) log.Printf("HandleDNS(remote=%s): write failed: %v\n", remoteAddr.String(), err)
} }
@ -352,89 +419,6 @@ func (c *connector) handleDNS(pc net.PacketConn, buf []byte, remoteAddr *net.UDP
// to indicate that it is a fully qualified domain name. // to indicate that it is a fully qualified domain name.
var tsMBox = dnsmessage.MustNewName("support.tailscale.com.") var tsMBox = dnsmessage.MustNewName("support.tailscale.com.")
// generateDNSResponse generates a DNS response for the given request. The from
// argument is the NodeID of the node that sent the request.
func (c *connector) generateDNSResponse(req *dnsmessage.Message, from tailcfg.NodeID) ([]byte, error) {
var addrs []netip.Addr
if len(req.Questions) > 0 {
switch req.Questions[0].Type {
case dnsmessage.TypeAAAA, dnsmessage.TypeA:
var err error
addrs, err = c.ipPool.IPForDomain(from, req.Questions[0].Name.String())
if err != nil {
return nil, err
}
}
}
return dnsResponse(req, addrs)
}
// dnsResponse makes a DNS response for the natc. If the dnsmessage is requesting TypeAAAA
// or TypeA the provided addrs of the requested type will be used.
func dnsResponse(req *dnsmessage.Message, addrs []netip.Addr) ([]byte, error) {
b := dnsmessage.NewBuilder(nil,
dnsmessage.Header{
ID: req.Header.ID,
Response: true,
Authoritative: true,
})
b.EnableCompression()
if len(req.Questions) == 0 {
return b.Finish()
}
q := req.Questions[0]
if err := b.StartQuestions(); err != nil {
return nil, err
}
if err := b.Question(q); err != nil {
return nil, err
}
if err := b.StartAnswers(); err != nil {
return nil, err
}
switch q.Type {
case dnsmessage.TypeAAAA, dnsmessage.TypeA:
want6 := q.Type == dnsmessage.TypeAAAA
for _, ip := range addrs {
if want6 != ip.Is6() {
continue
}
if want6 {
if err := b.AAAAResource(
dnsmessage.ResourceHeader{Name: q.Name, Class: q.Class, TTL: 5},
dnsmessage.AAAAResource{AAAA: ip.As16()},
); err != nil {
return nil, err
}
} else {
if err := b.AResource(
dnsmessage.ResourceHeader{Name: q.Name, Class: q.Class, TTL: 5},
dnsmessage.AResource{A: ip.As4()},
); err != nil {
return nil, err
}
}
}
case dnsmessage.TypeSOA:
if err := b.SOAResource(
dnsmessage.ResourceHeader{Name: q.Name, Class: q.Class, TTL: 120},
dnsmessage.SOAResource{NS: q.Name, MBox: tsMBox, Serial: 2023030600,
Refresh: 120, Retry: 120, Expire: 120, MinTTL: 60},
); err != nil {
return nil, err
}
case dnsmessage.TypeNS:
if err := b.NSResource(
dnsmessage.ResourceHeader{Name: q.Name, Class: q.Class, TTL: 120},
dnsmessage.NSResource{NS: tsMBox},
); err != nil {
return nil, err
}
}
return b.Finish()
}
// handleTCPFlow handles a TCP flow from the given source to the given // handleTCPFlow handles a TCP flow from the given source to the given
// destination. It uses the source address to determine the node that sent the // destination. It uses the source address to determine the node that sent the
// request and the destination address to determine the domain that the request // request and the destination address to determine the domain that the request
@ -443,7 +427,7 @@ func dnsResponse(req *dnsmessage.Message, addrs []netip.Addr) ([]byte, error) {
func (c *connector) handleTCPFlow(src, dst netip.AddrPort) (handler func(net.Conn), intercept bool) { func (c *connector) handleTCPFlow(src, dst netip.AddrPort) (handler func(net.Conn), intercept bool) {
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel() defer cancel()
who, err := c.lc.WhoIs(ctx, src.Addr().String()) who, err := c.whois.WhoIs(ctx, src.Addr().String())
cancel() cancel()
if err != nil { if err != nil {
log.Printf("HandleTCPFlow: WhoIs failed: %v\n", err) log.Printf("HandleTCPFlow: WhoIs failed: %v\n", err)
@ -461,6 +445,9 @@ func (c *connector) handleTCPFlow(src, dst netip.AddrPort) (handler func(net.Con
// ignoreDestination reports whether any of the provided dstAddrs match the prefixes configured // ignoreDestination reports whether any of the provided dstAddrs match the prefixes configured
// in --ignore-destinations // in --ignore-destinations
func (c *connector) ignoreDestination(dstAddrs []netip.Addr) bool { func (c *connector) ignoreDestination(dstAddrs []netip.Addr) bool {
if c.ignoreDsts == nil {
return false
}
for _, a := range dstAddrs { for _, a := range dstAddrs {
if _, ok := c.ignoreDsts.Lookup(a); ok { if _, ok := c.ignoreDsts.Lookup(a); ok {
return true return true
@ -488,6 +475,8 @@ func proxyTCPConn(c net.Conn, dest string) {
return netutil.NewOneConnListener(c, nil), nil return netutil.NewOneConnListener(c, nil), nil
}, },
} }
// XXX(raggi): if the connection here resolves to an ignored destination,
// the connection should be closed/failed.
p.AddRoute(addrPortStr, &tcpproxy.DialProxy{ p.AddRoute(addrPortStr, &tcpproxy.DialProxy{
Addr: fmt.Sprintf("%s:%s", dest, port), Addr: fmt.Sprintf("%s:%s", dest, port),
}) })

View File

@ -4,14 +4,20 @@
package main package main
import ( import (
"context"
"fmt"
"io"
"net"
"net/netip" "net/netip"
"testing" "testing"
"time"
"github.com/gaissmai/bart" "github.com/gaissmai/bart"
"github.com/google/go-cmp/cmp"
"golang.org/x/net/dns/dnsmessage" "golang.org/x/net/dns/dnsmessage"
"tailscale.com/client/tailscale/apitype"
"tailscale.com/cmd/natc/ippool" "tailscale.com/cmd/natc/ippool"
"tailscale.com/tailcfg" "tailscale.com/tailcfg"
"tailscale.com/util/must"
) )
func prefixEqual(a, b netip.Prefix) bool { func prefixEqual(a, b netip.Prefix) bool {
@ -41,22 +47,86 @@ func TestULA(t *testing.T) {
} }
} }
type recordingPacketConn struct {
writes [][]byte
}
func (w *recordingPacketConn) WriteTo(b []byte, addr net.Addr) (int, error) {
w.writes = append(w.writes, b)
return len(b), nil
}
func (w *recordingPacketConn) ReadFrom(b []byte) (int, net.Addr, error) {
return 0, nil, io.EOF
}
func (w *recordingPacketConn) Close() error {
return nil
}
func (w *recordingPacketConn) LocalAddr() net.Addr {
return nil
}
func (w *recordingPacketConn) RemoteAddr() net.Addr {
return nil
}
func (w *recordingPacketConn) SetDeadline(t time.Time) error {
return nil
}
func (w *recordingPacketConn) SetReadDeadline(t time.Time) error {
return nil
}
func (w *recordingPacketConn) SetWriteDeadline(t time.Time) error {
return nil
}
type resolver struct {
resolves map[string][]netip.Addr
fails map[string]bool
}
func (r *resolver) LookupNetIP(ctx context.Context, _net, host string) ([]netip.Addr, error) {
if addrs, ok := r.resolves[host]; ok {
return addrs, nil
}
if _, ok := r.fails[host]; ok {
return nil, &net.DNSError{IsTimeout: false, IsNotFound: false, Name: host, IsTemporary: true}
}
return nil, &net.DNSError{IsNotFound: true, Name: host}
}
type whois struct {
peers map[string]*apitype.WhoIsResponse
}
func (w *whois) WhoIs(ctx context.Context, remoteAddr string) (*apitype.WhoIsResponse, error) {
addr := netip.MustParseAddrPort(remoteAddr).Addr().String()
if peer, ok := w.peers[addr]; ok {
return peer, nil
}
return nil, fmt.Errorf("peer not found")
}
func TestDNSResponse(t *testing.T) { func TestDNSResponse(t *testing.T) {
tests := []struct { tests := []struct {
name string name string
questions []dnsmessage.Question questions []dnsmessage.Question
addrs []netip.Addr
wantEmpty bool wantEmpty bool
wantAnswers []struct { wantAnswers []struct {
name string name string
qType dnsmessage.Type qType dnsmessage.Type
addr netip.Addr addr netip.Addr
} }
wantNXDOMAIN bool
wantIgnored bool
}{ }{
{ {
name: "empty_request", name: "empty_request",
questions: []dnsmessage.Question{}, questions: []dnsmessage.Question{},
addrs: []netip.Addr{},
wantEmpty: false, wantEmpty: false,
wantAnswers: nil, wantAnswers: nil,
}, },
@ -69,7 +139,6 @@ func TestDNSResponse(t *testing.T) {
Class: dnsmessage.ClassINET, Class: dnsmessage.ClassINET,
}, },
}, },
addrs: []netip.Addr{netip.MustParseAddr("100.64.1.5")},
wantAnswers: []struct { wantAnswers: []struct {
name string name string
qType dnsmessage.Type qType dnsmessage.Type
@ -78,7 +147,7 @@ func TestDNSResponse(t *testing.T) {
{ {
name: "example.com.", name: "example.com.",
qType: dnsmessage.TypeA, qType: dnsmessage.TypeA,
addr: netip.MustParseAddr("100.64.1.5"), addr: netip.MustParseAddr("100.64.0.0"),
}, },
}, },
}, },
@ -91,7 +160,6 @@ func TestDNSResponse(t *testing.T) {
Class: dnsmessage.ClassINET, Class: dnsmessage.ClassINET,
}, },
}, },
addrs: []netip.Addr{netip.MustParseAddr("fd7a:115c:a1e0:a99c:0001:0505:0505:0505")},
wantAnswers: []struct { wantAnswers: []struct {
name string name string
qType dnsmessage.Type qType dnsmessage.Type
@ -100,7 +168,7 @@ func TestDNSResponse(t *testing.T) {
{ {
name: "example.com.", name: "example.com.",
qType: dnsmessage.TypeAAAA, qType: dnsmessage.TypeAAAA,
addr: netip.MustParseAddr("fd7a:115c:a1e0:a99c:0001:0505:0505:0505"), addr: netip.MustParseAddr("fd7a:115c:a1e0::"),
}, },
}, },
}, },
@ -113,7 +181,6 @@ func TestDNSResponse(t *testing.T) {
Class: dnsmessage.ClassINET, Class: dnsmessage.ClassINET,
}, },
}, },
addrs: []netip.Addr{},
wantAnswers: nil, wantAnswers: nil,
}, },
{ {
@ -125,89 +192,210 @@ func TestDNSResponse(t *testing.T) {
Class: dnsmessage.ClassINET, Class: dnsmessage.ClassINET,
}, },
}, },
addrs: []netip.Addr{},
wantAnswers: nil, wantAnswers: nil,
}, },
{
name: "nxdomain",
questions: []dnsmessage.Question{
{
Name: dnsmessage.MustNewName("noexist.example.com."),
Type: dnsmessage.TypeA,
Class: dnsmessage.ClassINET,
},
},
wantNXDOMAIN: true,
},
{
name: "servfail",
questions: []dnsmessage.Question{
{
Name: dnsmessage.MustNewName("fail.example.com."),
Type: dnsmessage.TypeA,
Class: dnsmessage.ClassINET,
},
},
wantEmpty: true, // TODO: pass through instead?
},
{
name: "ignored",
questions: []dnsmessage.Question{
{
Name: dnsmessage.MustNewName("ignore.example.com."),
Type: dnsmessage.TypeA,
Class: dnsmessage.ClassINET,
},
},
wantAnswers: []struct {
name string
qType dnsmessage.Type
addr netip.Addr
}{
{
name: "ignore.example.com.",
qType: dnsmessage.TypeA,
addr: netip.MustParseAddr("8.8.4.4"),
},
},
wantIgnored: true,
},
} }
var rpc recordingPacketConn
remoteAddr := must.Get(net.ResolveUDPAddr("udp", "100.64.254.1:12345"))
routes, dnsAddr, addrPool := calculateAddresses([]netip.Prefix{netip.MustParsePrefix("10.64.0.0/24")})
v6ULA := ula(1)
c := connector{
resolver: &resolver{
resolves: map[string][]netip.Addr{
"example.com.": {
netip.MustParseAddr("8.8.8.8"),
netip.MustParseAddr("2001:4860:4860::8888"),
},
"ignore.example.com.": {
netip.MustParseAddr("8.8.4.4"),
},
},
fails: map[string]bool{
"fail.example.com.": true,
},
},
whois: &whois{
peers: map[string]*apitype.WhoIsResponse{
"100.64.254.1": {
Node: &tailcfg.Node{ID: 123},
},
},
},
ignoreDsts: &bart.Table[bool]{},
routes: routes,
v6ULA: v6ULA,
ipPool: &ippool.IPPool{V6ULA: v6ULA, IPSet: addrPool},
dnsAddr: dnsAddr,
}
c.ignoreDsts.Insert(netip.MustParsePrefix("8.8.4.4/32"), true)
for _, tc := range tests { for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) { t.Run(tc.name, func(t *testing.T) {
req := &dnsmessage.Message{ rb := dnsmessage.NewBuilder(nil,
Header: dnsmessage.Header{ dnsmessage.Header{
ID: 1234, ID: 1234,
}, },
Questions: tc.questions, )
must.Do(rb.StartQuestions())
for _, q := range tc.questions {
rb.Question(q)
} }
resp, err := dnsResponse(req, tc.addrs) c.handleDNS(&rpc, must.Get(rb.Finish()), remoteAddr)
writes := rpc.writes
rpc.writes = rpc.writes[:0]
if tc.wantEmpty {
if len(writes) != 0 {
t.Errorf("handleDNS() returned non-empty response when expected empty")
}
return
}
if !tc.wantEmpty && len(writes) != 1 {
t.Fatalf("handleDNS() returned an unexpected number of responses: %d, want 1", len(writes))
}
resp := writes[0]
var msg dnsmessage.Message
err := msg.Unpack(resp)
if err != nil { if err != nil {
t.Fatalf("dnsResponse() error = %v", err) t.Fatalf("Failed to unpack response: %v", err)
} }
if tc.wantEmpty && len(resp) != 0 { if !msg.Header.Response {
t.Errorf("dnsResponse() returned non-empty response when expected empty") t.Errorf("Response header is not set")
} }
if !tc.wantEmpty && len(resp) == 0 { if msg.Header.ID != 1234 {
t.Errorf("dnsResponse() returned empty response when expected non-empty") t.Errorf("Response ID = %d, want %d", msg.Header.ID, 1234)
} }
if len(resp) > 0 { if len(tc.wantAnswers) > 0 {
var msg dnsmessage.Message if len(msg.Answers) != len(tc.wantAnswers) {
err = msg.Unpack(resp) t.Errorf("got %d answers, want %d:\n%s", len(msg.Answers), len(tc.wantAnswers), msg.GoString())
if err != nil { } else {
t.Fatalf("Failed to unpack response: %v", err) for i, want := range tc.wantAnswers {
} ans := msg.Answers[i]
if !msg.Header.Response { gotName := ans.Header.Name.String()
t.Errorf("Response header is not set") if gotName != want.name {
} t.Errorf("answer[%d] name = %s, want %s", i, gotName, want.name)
}
if msg.Header.ID != req.Header.ID { if ans.Header.Type != want.qType {
t.Errorf("Response ID = %d, want %d", msg.Header.ID, req.Header.ID) t.Errorf("answer[%d] type = %v, want %v", i, ans.Header.Type, want.qType)
} }
if len(tc.wantAnswers) > 0 { switch want.qType {
if len(msg.Answers) != len(tc.wantAnswers) { case dnsmessage.TypeA:
t.Errorf("got %d answers, want %d", len(msg.Answers), len(tc.wantAnswers)) if ans.Body.(*dnsmessage.AResource) == nil {
} else { t.Errorf("answer[%d] not an A record", i)
for i, want := range tc.wantAnswers { continue
ans := msg.Answers[i]
gotName := ans.Header.Name.String()
if gotName != want.name {
t.Errorf("answer[%d] name = %s, want %s", i, gotName, want.name)
} }
resource := ans.Body.(*dnsmessage.AResource)
gotIP := netip.AddrFrom4([4]byte(resource.A))
if ans.Header.Type != want.qType { var ips []netip.Addr
t.Errorf("answer[%d] type = %v, want %v", i, ans.Header.Type, want.qType) if tc.wantIgnored {
ips = must.Get(c.resolver.LookupNetIP(t.Context(), "ip4", want.name))
} else {
ips = must.Get(c.ipPool.IPForDomain(tailcfg.NodeID(123), want.name))
} }
var wantIP netip.Addr
var gotIP netip.Addr for _, ip := range ips {
switch want.qType { if ip.Is4() {
case dnsmessage.TypeA: wantIP = ip
if ans.Body.(*dnsmessage.AResource) == nil { break
t.Errorf("answer[%d] not an A record", i)
continue
} }
resource := ans.Body.(*dnsmessage.AResource)
gotIP = netip.AddrFrom4([4]byte(resource.A))
case dnsmessage.TypeAAAA:
if ans.Body.(*dnsmessage.AAAAResource) == nil {
t.Errorf("answer[%d] not an AAAA record", i)
continue
}
resource := ans.Body.(*dnsmessage.AAAAResource)
gotIP = netip.AddrFrom16([16]byte(resource.AAAA))
} }
if gotIP != wantIP {
t.Errorf("answer[%d] IP = %s, want %s", i, gotIP, wantIP)
}
case dnsmessage.TypeAAAA:
if ans.Body.(*dnsmessage.AAAAResource) == nil {
t.Errorf("answer[%d] not an AAAA record", i)
continue
}
resource := ans.Body.(*dnsmessage.AAAAResource)
gotIP := netip.AddrFrom16([16]byte(resource.AAAA))
if gotIP != want.addr { var ips []netip.Addr
t.Errorf("answer[%d] IP = %s, want %s", i, gotIP, want.addr) if tc.wantIgnored {
ips = must.Get(c.resolver.LookupNetIP(t.Context(), "ip6", want.name))
} else {
ips = must.Get(c.ipPool.IPForDomain(tailcfg.NodeID(123), want.name))
}
var wantIP netip.Addr
for _, ip := range ips {
if ip.Is6() {
wantIP = ip
break
}
}
if gotIP != wantIP {
t.Errorf("answer[%d] IP = %s, want %s", i, gotIP, wantIP)
} }
} }
} }
} }
} }
if tc.wantNXDOMAIN {
if msg.RCode != dnsmessage.RCodeNameError {
t.Errorf("expected NXDOMAIN, got %v", msg.RCode)
}
if len(msg.Answers) != 0 {
t.Errorf("expected no answers, got %d", len(msg.Answers))
}
}
}) })
} }
} }
@ -257,53 +445,3 @@ func TestIgnoreDestination(t *testing.T) {
}) })
} }
} }
func TestConnectorGenerateDNSResponse(t *testing.T) {
v6ULA := netip.MustParsePrefix("fd7a:115c:a1e0:a99c:0001::/80")
routes, dnsAddr, addrPool := calculateAddresses([]netip.Prefix{netip.MustParsePrefix("100.64.1.0/24")})
c := &connector{
v6ULA: v6ULA,
ipPool: &ippool.IPPool{V6ULA: v6ULA, IPSet: addrPool},
routes: routes,
dnsAddr: dnsAddr,
}
req := &dnsmessage.Message{
Header: dnsmessage.Header{ID: 1234},
Questions: []dnsmessage.Question{
{
Name: dnsmessage.MustNewName("example.com."),
Type: dnsmessage.TypeA,
Class: dnsmessage.ClassINET,
},
},
}
nodeID := tailcfg.NodeID(12345)
resp1, err := c.generateDNSResponse(req, nodeID)
if err != nil {
t.Fatalf("generateDNSResponse() error = %v", err)
}
if len(resp1) == 0 {
t.Fatalf("generateDNSResponse() returned empty response")
}
resp2, err := c.generateDNSResponse(req, nodeID)
if err != nil {
t.Fatalf("generateDNSResponse() second call error = %v", err)
}
if !cmp.Equal(resp1, resp2) {
t.Errorf("generateDNSResponse() responses differ between calls")
}
var msg dnsmessage.Message
err = msg.Unpack(resp1)
if err != nil {
t.Fatalf("dnsmessage Unpack error = %v", err)
}
if len(msg.Answers) != 1 {
t.Fatalf("expected 1 answer, got: %d", len(msg.Answers))
}
}