cmd/natc: fix handling of upstream and downstream nxdomain

Ensure that the upstream is always queried, so that if upstream is going
to NXDOMAIN natc will also return NXDOMAIN rather than returning address
allocations.

At this time both IPv4 and IPv6 are still returned if upstream has a
result, regardless of upstream support - this is ~ok as we're proxying.

Rewrite the tests to be once again slightly closer to integration tests,
but they're still very rough and in need of a refactor.

Further refactors are probably needed implementation side too, as this
removed rather than added units.

Updates #15367

Signed-off-by: James Tucker <james@tailscale.com>
This commit is contained in:
James Tucker 2025-04-01 18:52:45 -07:00 committed by James Tucker
parent fb96137d79
commit 025fe72448
2 changed files with 379 additions and 252 deletions

View File

@ -26,14 +26,15 @@ import (
"go4.org/netipx"
"golang.org/x/net/dns/dnsmessage"
"tailscale.com/client/local"
"tailscale.com/client/tailscale/apitype"
"tailscale.com/cmd/natc/ippool"
"tailscale.com/envknob"
"tailscale.com/hostinfo"
"tailscale.com/ipn"
"tailscale.com/net/netutil"
"tailscale.com/tailcfg"
"tailscale.com/tsnet"
"tailscale.com/tsweb"
"tailscale.com/util/mak"
"tailscale.com/util/must"
"tailscale.com/wgengine/netstack"
)
@ -148,14 +149,15 @@ func main() {
v6ULA := ula(uint16(*siteID))
c := &connector{
ts: ts,
lc: lc,
whois: lc,
v6ULA: v6ULA,
ignoreDsts: ignoreDstTable,
ipPool: &ippool.IPPool{V6ULA: v6ULA, IPSet: addrPool},
routes: routes,
dnsAddr: dnsAddr,
resolver: net.DefaultResolver,
}
c.run(ctx)
c.run(ctx, lc)
}
func calculateAddresses(prefixes []netip.Prefix) (*netipx.IPSet, netip.Addr, *netipx.IPSet) {
@ -170,12 +172,20 @@ func calculateAddresses(prefixes []netip.Prefix) (*netipx.IPSet, netip.Addr, *ne
return routesToAdvertise, dnsAddr, addrPool
}
type lookupNetIPer interface {
LookupNetIP(ctx context.Context, net, host string) ([]netip.Addr, error)
}
type whoiser interface {
WhoIs(ctx context.Context, remoteAddr string) (*apitype.WhoIsResponse, error)
}
type connector struct {
// ts is the tsnet.Server used to host the connector.
ts *tsnet.Server
// lc is the local.Client used to interact with the tsnet.Server hosting this
// whois is the local.Client used to interact with the tsnet.Server hosting this
// connector.
lc *local.Client
whois whoiser
// dnsAddr is the IPv4 address to listen on for DNS requests. It is used to
// prevent the app connector from assigning it to a domain.
@ -197,7 +207,11 @@ type connector struct {
// natc behavior, which would return a dummy ip address pointing at natc).
ignoreDsts *bart.Table[bool]
// ipPool contains the per-peer IPv4 address assignments.
ipPool *ippool.IPPool
// resolver is used to lookup IP addresses for DNS queries.
resolver lookupNetIPer
}
// v6ULA is the ULA prefix used by the app connector to assign IPv6 addresses.
@ -217,8 +231,8 @@ func ula(siteID uint16) netip.Prefix {
//
// The passed in context is only used for the initial setup. The connector runs
// forever.
func (c *connector) run(ctx context.Context) {
if _, err := c.lc.EditPrefs(ctx, &ipn.MaskedPrefs{
func (c *connector) run(ctx context.Context, lc *local.Client) {
if _, err := lc.EditPrefs(ctx, &ipn.MaskedPrefs{
AdvertiseRoutesSet: true,
Prefs: ipn.Prefs{
AdvertiseRoutes: append(c.routes.Prefixes(), c.v6ULA),
@ -251,26 +265,6 @@ func (c *connector) serveDNS() {
}
}
func lookupDestinationIP(domain string) ([]netip.Addr, error) {
netIPs, err := net.LookupIP(domain)
if err != nil {
var dnsError *net.DNSError
if errors.As(err, &dnsError) && dnsError.IsNotFound {
return nil, nil
} else {
return nil, err
}
}
var addrs []netip.Addr
for _, ip := range netIPs {
a, ok := netip.AddrFromSlice(ip)
if ok {
addrs = append(addrs, a)
}
}
return addrs, nil
}
// handleDNS handles a DNS request to the app connector.
// It generates a response based on the request and the node that sent it.
//
@ -285,7 +279,7 @@ func lookupDestinationIP(domain string) ([]netip.Addr, error) {
func (c *connector) handleDNS(pc net.PacketConn, buf []byte, remoteAddr *net.UDPAddr) {
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
who, err := c.lc.WhoIs(ctx, remoteAddr.String())
who, err := c.whois.WhoIs(ctx, remoteAddr.String())
if err != nil {
log.Printf("HandleDNS(remote=%s): WhoIs failed: %v\n", remoteAddr.String(), err)
return
@ -298,49 +292,122 @@ func (c *connector) handleDNS(pc net.PacketConn, buf []byte, remoteAddr *net.UDP
return
}
// If there are destination ips that we don't want to route, we
// have to do a dns lookup here to find the destination ip.
if c.ignoreDsts != nil {
if len(msg.Questions) > 0 {
q := msg.Questions[0]
switch q.Type {
case dnsmessage.TypeAAAA, dnsmessage.TypeA:
dstAddrs, err := lookupDestinationIP(q.Name.String())
var resolves map[string][]netip.Addr
var addrQCount int
for _, q := range msg.Questions {
if q.Type != dnsmessage.TypeA && q.Type != dnsmessage.TypeAAAA {
continue
}
addrQCount++
if _, ok := resolves[q.Name.String()]; !ok {
addrs, err := c.resolver.LookupNetIP(ctx, "ip", q.Name.String())
var dnsErr *net.DNSError
if errors.As(err, &dnsErr) && dnsErr.IsNotFound {
continue
}
if err != nil {
log.Printf("HandleDNS(remote=%s): lookup destination failed: %v\n", remoteAddr.String(), err)
return
}
if c.ignoreDestination(dstAddrs) {
bs, err := dnsResponse(&msg, dstAddrs)
// TODO (fran): treat as SERVFAIL
// Note: If _any_ destination is ignored, pass through all of the resolved
// addresses as-is.
//
// This could result in some odd split-routing if there was a mix of
// ignored and non-ignored addresses, but it's currently the user
// preferred behavior.
if !c.ignoreDestination(addrs) {
addrs, err = c.ipPool.IPForDomain(who.Node.ID, q.Name.String())
if err != nil {
log.Printf("HandleDNS(remote=%s): generate ignore response failed: %v\n", remoteAddr.String(), err)
return
}
_, err = pc.WriteTo(bs, remoteAddr)
if err != nil {
log.Printf("HandleDNS(remote=%s): write failed: %v\n", remoteAddr.String(), err)
}
log.Printf("HandleDNS(remote=%s): lookup destination failed: %v\n", remoteAddr.String(), err)
return
}
}
mak.Set(&resolves, q.Name.String(), addrs)
}
}
// None of the destination IP addresses match an ignore destination prefix, do
// the natc thing.
resp, err := c.generateDNSResponse(&msg, who.Node.ID)
// TODO (fran): treat as SERVFAIL
rcode := dnsmessage.RCodeSuccess
if addrQCount > 0 && len(resolves) == 0 {
rcode = dnsmessage.RCodeNameError
}
b := dnsmessage.NewBuilder(nil,
dnsmessage.Header{
ID: msg.Header.ID,
Response: true,
Authoritative: true,
RCode: rcode,
})
b.EnableCompression()
if err := b.StartQuestions(); err != nil {
log.Printf("HandleDNS(remote=%s): dnsmessage start questions failed: %v\n", remoteAddr.String(), err)
return
}
for _, q := range msg.Questions {
b.Question(q)
}
if err := b.StartAnswers(); err != nil {
log.Printf("HandleDNS(remote=%s): dnsmessage start answers failed: %v\n", remoteAddr.String(), err)
return
}
for _, q := range msg.Questions {
switch q.Type {
case dnsmessage.TypeSOA:
if err := b.SOAResource(
dnsmessage.ResourceHeader{Name: q.Name, Class: q.Class, TTL: 120},
dnsmessage.SOAResource{NS: q.Name, MBox: tsMBox, Serial: 2023030600,
Refresh: 120, Retry: 120, Expire: 120, MinTTL: 60},
); err != nil {
log.Printf("HandleDNS(remote=%s): dnsmessage SOA resource failed: %v\n", remoteAddr.String(), err)
return
}
case dnsmessage.TypeNS:
if err := b.NSResource(
dnsmessage.ResourceHeader{Name: q.Name, Class: q.Class, TTL: 120},
dnsmessage.NSResource{NS: tsMBox},
); err != nil {
log.Printf("HandleDNS(remote=%s): dnsmessage NS resource failed: %v\n", remoteAddr.String(), err)
return
}
case dnsmessage.TypeAAAA:
for _, addr := range resolves[q.Name.String()] {
if !addr.Is6() {
continue
}
if err := b.AAAAResource(
dnsmessage.ResourceHeader{Name: q.Name, Class: q.Class, TTL: 120},
dnsmessage.AAAAResource{AAAA: addr.As16()},
); err != nil {
log.Printf("HandleDNS(remote=%s): dnsmessage AAAA resource failed: %v\n", remoteAddr.String(), err)
return
}
}
case dnsmessage.TypeA:
for _, addr := range resolves[q.Name.String()] {
if !addr.Is4() {
continue
}
if err := b.AResource(
dnsmessage.ResourceHeader{Name: q.Name, Class: q.Class, TTL: 120},
dnsmessage.AResource{A: addr.As4()},
); err != nil {
log.Printf("HandleDNS(remote=%s): dnsmessage A resource failed: %v\n", remoteAddr.String(), err)
return
}
}
}
}
out, err := b.Finish()
if err != nil {
log.Printf("HandleDNS(remote=%s): connector handling failed: %v\n", remoteAddr.String(), err)
log.Printf("HandleDNS(remote=%s): dnsmessage finish failed: %v\n", remoteAddr.String(), err)
return
}
// TODO (fran): treat as NXDOMAIN
if len(resp) == 0 {
return
}
// This connector handled the DNS request
_, err = pc.WriteTo(resp, remoteAddr)
_, err = pc.WriteTo(out, remoteAddr)
if err != nil {
log.Printf("HandleDNS(remote=%s): write failed: %v\n", remoteAddr.String(), err)
}
@ -352,89 +419,6 @@ func (c *connector) handleDNS(pc net.PacketConn, buf []byte, remoteAddr *net.UDP
// to indicate that it is a fully qualified domain name.
var tsMBox = dnsmessage.MustNewName("support.tailscale.com.")
// generateDNSResponse generates a DNS response for the given request. The from
// argument is the NodeID of the node that sent the request.
func (c *connector) generateDNSResponse(req *dnsmessage.Message, from tailcfg.NodeID) ([]byte, error) {
var addrs []netip.Addr
if len(req.Questions) > 0 {
switch req.Questions[0].Type {
case dnsmessage.TypeAAAA, dnsmessage.TypeA:
var err error
addrs, err = c.ipPool.IPForDomain(from, req.Questions[0].Name.String())
if err != nil {
return nil, err
}
}
}
return dnsResponse(req, addrs)
}
// dnsResponse makes a DNS response for the natc. If the dnsmessage is requesting TypeAAAA
// or TypeA the provided addrs of the requested type will be used.
func dnsResponse(req *dnsmessage.Message, addrs []netip.Addr) ([]byte, error) {
b := dnsmessage.NewBuilder(nil,
dnsmessage.Header{
ID: req.Header.ID,
Response: true,
Authoritative: true,
})
b.EnableCompression()
if len(req.Questions) == 0 {
return b.Finish()
}
q := req.Questions[0]
if err := b.StartQuestions(); err != nil {
return nil, err
}
if err := b.Question(q); err != nil {
return nil, err
}
if err := b.StartAnswers(); err != nil {
return nil, err
}
switch q.Type {
case dnsmessage.TypeAAAA, dnsmessage.TypeA:
want6 := q.Type == dnsmessage.TypeAAAA
for _, ip := range addrs {
if want6 != ip.Is6() {
continue
}
if want6 {
if err := b.AAAAResource(
dnsmessage.ResourceHeader{Name: q.Name, Class: q.Class, TTL: 5},
dnsmessage.AAAAResource{AAAA: ip.As16()},
); err != nil {
return nil, err
}
} else {
if err := b.AResource(
dnsmessage.ResourceHeader{Name: q.Name, Class: q.Class, TTL: 5},
dnsmessage.AResource{A: ip.As4()},
); err != nil {
return nil, err
}
}
}
case dnsmessage.TypeSOA:
if err := b.SOAResource(
dnsmessage.ResourceHeader{Name: q.Name, Class: q.Class, TTL: 120},
dnsmessage.SOAResource{NS: q.Name, MBox: tsMBox, Serial: 2023030600,
Refresh: 120, Retry: 120, Expire: 120, MinTTL: 60},
); err != nil {
return nil, err
}
case dnsmessage.TypeNS:
if err := b.NSResource(
dnsmessage.ResourceHeader{Name: q.Name, Class: q.Class, TTL: 120},
dnsmessage.NSResource{NS: tsMBox},
); err != nil {
return nil, err
}
}
return b.Finish()
}
// handleTCPFlow handles a TCP flow from the given source to the given
// destination. It uses the source address to determine the node that sent the
// request and the destination address to determine the domain that the request
@ -443,7 +427,7 @@ func dnsResponse(req *dnsmessage.Message, addrs []netip.Addr) ([]byte, error) {
func (c *connector) handleTCPFlow(src, dst netip.AddrPort) (handler func(net.Conn), intercept bool) {
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
who, err := c.lc.WhoIs(ctx, src.Addr().String())
who, err := c.whois.WhoIs(ctx, src.Addr().String())
cancel()
if err != nil {
log.Printf("HandleTCPFlow: WhoIs failed: %v\n", err)
@ -461,6 +445,9 @@ func (c *connector) handleTCPFlow(src, dst netip.AddrPort) (handler func(net.Con
// ignoreDestination reports whether any of the provided dstAddrs match the prefixes configured
// in --ignore-destinations
func (c *connector) ignoreDestination(dstAddrs []netip.Addr) bool {
if c.ignoreDsts == nil {
return false
}
for _, a := range dstAddrs {
if _, ok := c.ignoreDsts.Lookup(a); ok {
return true
@ -488,6 +475,8 @@ func proxyTCPConn(c net.Conn, dest string) {
return netutil.NewOneConnListener(c, nil), nil
},
}
// XXX(raggi): if the connection here resolves to an ignored destination,
// the connection should be closed/failed.
p.AddRoute(addrPortStr, &tcpproxy.DialProxy{
Addr: fmt.Sprintf("%s:%s", dest, port),
})

View File

@ -4,14 +4,20 @@
package main
import (
"context"
"fmt"
"io"
"net"
"net/netip"
"testing"
"time"
"github.com/gaissmai/bart"
"github.com/google/go-cmp/cmp"
"golang.org/x/net/dns/dnsmessage"
"tailscale.com/client/tailscale/apitype"
"tailscale.com/cmd/natc/ippool"
"tailscale.com/tailcfg"
"tailscale.com/util/must"
)
func prefixEqual(a, b netip.Prefix) bool {
@ -41,22 +47,86 @@ func TestULA(t *testing.T) {
}
}
type recordingPacketConn struct {
writes [][]byte
}
func (w *recordingPacketConn) WriteTo(b []byte, addr net.Addr) (int, error) {
w.writes = append(w.writes, b)
return len(b), nil
}
func (w *recordingPacketConn) ReadFrom(b []byte) (int, net.Addr, error) {
return 0, nil, io.EOF
}
func (w *recordingPacketConn) Close() error {
return nil
}
func (w *recordingPacketConn) LocalAddr() net.Addr {
return nil
}
func (w *recordingPacketConn) RemoteAddr() net.Addr {
return nil
}
func (w *recordingPacketConn) SetDeadline(t time.Time) error {
return nil
}
func (w *recordingPacketConn) SetReadDeadline(t time.Time) error {
return nil
}
func (w *recordingPacketConn) SetWriteDeadline(t time.Time) error {
return nil
}
type resolver struct {
resolves map[string][]netip.Addr
fails map[string]bool
}
func (r *resolver) LookupNetIP(ctx context.Context, _net, host string) ([]netip.Addr, error) {
if addrs, ok := r.resolves[host]; ok {
return addrs, nil
}
if _, ok := r.fails[host]; ok {
return nil, &net.DNSError{IsTimeout: false, IsNotFound: false, Name: host, IsTemporary: true}
}
return nil, &net.DNSError{IsNotFound: true, Name: host}
}
type whois struct {
peers map[string]*apitype.WhoIsResponse
}
func (w *whois) WhoIs(ctx context.Context, remoteAddr string) (*apitype.WhoIsResponse, error) {
addr := netip.MustParseAddrPort(remoteAddr).Addr().String()
if peer, ok := w.peers[addr]; ok {
return peer, nil
}
return nil, fmt.Errorf("peer not found")
}
func TestDNSResponse(t *testing.T) {
tests := []struct {
name string
questions []dnsmessage.Question
addrs []netip.Addr
wantEmpty bool
wantAnswers []struct {
name string
qType dnsmessage.Type
addr netip.Addr
}
wantNXDOMAIN bool
wantIgnored bool
}{
{
name: "empty_request",
questions: []dnsmessage.Question{},
addrs: []netip.Addr{},
wantEmpty: false,
wantAnswers: nil,
},
@ -69,7 +139,6 @@ func TestDNSResponse(t *testing.T) {
Class: dnsmessage.ClassINET,
},
},
addrs: []netip.Addr{netip.MustParseAddr("100.64.1.5")},
wantAnswers: []struct {
name string
qType dnsmessage.Type
@ -78,7 +147,7 @@ func TestDNSResponse(t *testing.T) {
{
name: "example.com.",
qType: dnsmessage.TypeA,
addr: netip.MustParseAddr("100.64.1.5"),
addr: netip.MustParseAddr("100.64.0.0"),
},
},
},
@ -91,7 +160,6 @@ func TestDNSResponse(t *testing.T) {
Class: dnsmessage.ClassINET,
},
},
addrs: []netip.Addr{netip.MustParseAddr("fd7a:115c:a1e0:a99c:0001:0505:0505:0505")},
wantAnswers: []struct {
name string
qType dnsmessage.Type
@ -100,7 +168,7 @@ func TestDNSResponse(t *testing.T) {
{
name: "example.com.",
qType: dnsmessage.TypeAAAA,
addr: netip.MustParseAddr("fd7a:115c:a1e0:a99c:0001:0505:0505:0505"),
addr: netip.MustParseAddr("fd7a:115c:a1e0::"),
},
},
},
@ -113,7 +181,6 @@ func TestDNSResponse(t *testing.T) {
Class: dnsmessage.ClassINET,
},
},
addrs: []netip.Addr{},
wantAnswers: nil,
},
{
@ -125,36 +192,120 @@ func TestDNSResponse(t *testing.T) {
Class: dnsmessage.ClassINET,
},
},
addrs: []netip.Addr{},
wantAnswers: nil,
},
{
name: "nxdomain",
questions: []dnsmessage.Question{
{
Name: dnsmessage.MustNewName("noexist.example.com."),
Type: dnsmessage.TypeA,
Class: dnsmessage.ClassINET,
},
},
wantNXDOMAIN: true,
},
{
name: "servfail",
questions: []dnsmessage.Question{
{
Name: dnsmessage.MustNewName("fail.example.com."),
Type: dnsmessage.TypeA,
Class: dnsmessage.ClassINET,
},
},
wantEmpty: true, // TODO: pass through instead?
},
{
name: "ignored",
questions: []dnsmessage.Question{
{
Name: dnsmessage.MustNewName("ignore.example.com."),
Type: dnsmessage.TypeA,
Class: dnsmessage.ClassINET,
},
},
wantAnswers: []struct {
name string
qType dnsmessage.Type
addr netip.Addr
}{
{
name: "ignore.example.com.",
qType: dnsmessage.TypeA,
addr: netip.MustParseAddr("8.8.4.4"),
},
},
wantIgnored: true,
},
}
var rpc recordingPacketConn
remoteAddr := must.Get(net.ResolveUDPAddr("udp", "100.64.254.1:12345"))
routes, dnsAddr, addrPool := calculateAddresses([]netip.Prefix{netip.MustParsePrefix("10.64.0.0/24")})
v6ULA := ula(1)
c := connector{
resolver: &resolver{
resolves: map[string][]netip.Addr{
"example.com.": {
netip.MustParseAddr("8.8.8.8"),
netip.MustParseAddr("2001:4860:4860::8888"),
},
"ignore.example.com.": {
netip.MustParseAddr("8.8.4.4"),
},
},
fails: map[string]bool{
"fail.example.com.": true,
},
},
whois: &whois{
peers: map[string]*apitype.WhoIsResponse{
"100.64.254.1": {
Node: &tailcfg.Node{ID: 123},
},
},
},
ignoreDsts: &bart.Table[bool]{},
routes: routes,
v6ULA: v6ULA,
ipPool: &ippool.IPPool{V6ULA: v6ULA, IPSet: addrPool},
dnsAddr: dnsAddr,
}
c.ignoreDsts.Insert(netip.MustParsePrefix("8.8.4.4/32"), true)
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
req := &dnsmessage.Message{
Header: dnsmessage.Header{
rb := dnsmessage.NewBuilder(nil,
dnsmessage.Header{
ID: 1234,
},
Questions: tc.questions,
)
must.Do(rb.StartQuestions())
for _, q := range tc.questions {
rb.Question(q)
}
resp, err := dnsResponse(req, tc.addrs)
if err != nil {
t.Fatalf("dnsResponse() error = %v", err)
c.handleDNS(&rpc, must.Get(rb.Finish()), remoteAddr)
writes := rpc.writes
rpc.writes = rpc.writes[:0]
if tc.wantEmpty {
if len(writes) != 0 {
t.Errorf("handleDNS() returned non-empty response when expected empty")
}
return
}
if tc.wantEmpty && len(resp) != 0 {
t.Errorf("dnsResponse() returned non-empty response when expected empty")
if !tc.wantEmpty && len(writes) != 1 {
t.Fatalf("handleDNS() returned an unexpected number of responses: %d, want 1", len(writes))
}
if !tc.wantEmpty && len(resp) == 0 {
t.Errorf("dnsResponse() returned empty response when expected non-empty")
}
if len(resp) > 0 {
resp := writes[0]
var msg dnsmessage.Message
err = msg.Unpack(resp)
err := msg.Unpack(resp)
if err != nil {
t.Fatalf("Failed to unpack response: %v", err)
}
@ -163,13 +314,13 @@ func TestDNSResponse(t *testing.T) {
t.Errorf("Response header is not set")
}
if msg.Header.ID != req.Header.ID {
t.Errorf("Response ID = %d, want %d", msg.Header.ID, req.Header.ID)
if msg.Header.ID != 1234 {
t.Errorf("Response ID = %d, want %d", msg.Header.ID, 1234)
}
if len(tc.wantAnswers) > 0 {
if len(msg.Answers) != len(tc.wantAnswers) {
t.Errorf("got %d answers, want %d", len(msg.Answers), len(tc.wantAnswers))
t.Errorf("got %d answers, want %d:\n%s", len(msg.Answers), len(tc.wantAnswers), msg.GoString())
} else {
for i, want := range tc.wantAnswers {
ans := msg.Answers[i]
@ -183,7 +334,6 @@ func TestDNSResponse(t *testing.T) {
t.Errorf("answer[%d] type = %v, want %v", i, ans.Header.Type, want.qType)
}
var gotIP netip.Addr
switch want.qType {
case dnsmessage.TypeA:
if ans.Body.(*dnsmessage.AResource) == nil {
@ -191,21 +341,59 @@ func TestDNSResponse(t *testing.T) {
continue
}
resource := ans.Body.(*dnsmessage.AResource)
gotIP = netip.AddrFrom4([4]byte(resource.A))
gotIP := netip.AddrFrom4([4]byte(resource.A))
var ips []netip.Addr
if tc.wantIgnored {
ips = must.Get(c.resolver.LookupNetIP(t.Context(), "ip4", want.name))
} else {
ips = must.Get(c.ipPool.IPForDomain(tailcfg.NodeID(123), want.name))
}
var wantIP netip.Addr
for _, ip := range ips {
if ip.Is4() {
wantIP = ip
break
}
}
if gotIP != wantIP {
t.Errorf("answer[%d] IP = %s, want %s", i, gotIP, wantIP)
}
case dnsmessage.TypeAAAA:
if ans.Body.(*dnsmessage.AAAAResource) == nil {
t.Errorf("answer[%d] not an AAAA record", i)
continue
}
resource := ans.Body.(*dnsmessage.AAAAResource)
gotIP = netip.AddrFrom16([16]byte(resource.AAAA))
gotIP := netip.AddrFrom16([16]byte(resource.AAAA))
var ips []netip.Addr
if tc.wantIgnored {
ips = must.Get(c.resolver.LookupNetIP(t.Context(), "ip6", want.name))
} else {
ips = must.Get(c.ipPool.IPForDomain(tailcfg.NodeID(123), want.name))
}
var wantIP netip.Addr
for _, ip := range ips {
if ip.Is6() {
wantIP = ip
break
}
}
if gotIP != wantIP {
t.Errorf("answer[%d] IP = %s, want %s", i, gotIP, wantIP)
}
}
}
}
}
if gotIP != want.addr {
t.Errorf("answer[%d] IP = %s, want %s", i, gotIP, want.addr)
}
}
if tc.wantNXDOMAIN {
if msg.RCode != dnsmessage.RCodeNameError {
t.Errorf("expected NXDOMAIN, got %v", msg.RCode)
}
if len(msg.Answers) != 0 {
t.Errorf("expected no answers, got %d", len(msg.Answers))
}
}
})
@ -257,53 +445,3 @@ func TestIgnoreDestination(t *testing.T) {
})
}
}
func TestConnectorGenerateDNSResponse(t *testing.T) {
v6ULA := netip.MustParsePrefix("fd7a:115c:a1e0:a99c:0001::/80")
routes, dnsAddr, addrPool := calculateAddresses([]netip.Prefix{netip.MustParsePrefix("100.64.1.0/24")})
c := &connector{
v6ULA: v6ULA,
ipPool: &ippool.IPPool{V6ULA: v6ULA, IPSet: addrPool},
routes: routes,
dnsAddr: dnsAddr,
}
req := &dnsmessage.Message{
Header: dnsmessage.Header{ID: 1234},
Questions: []dnsmessage.Question{
{
Name: dnsmessage.MustNewName("example.com."),
Type: dnsmessage.TypeA,
Class: dnsmessage.ClassINET,
},
},
}
nodeID := tailcfg.NodeID(12345)
resp1, err := c.generateDNSResponse(req, nodeID)
if err != nil {
t.Fatalf("generateDNSResponse() error = %v", err)
}
if len(resp1) == 0 {
t.Fatalf("generateDNSResponse() returned empty response")
}
resp2, err := c.generateDNSResponse(req, nodeID)
if err != nil {
t.Fatalf("generateDNSResponse() second call error = %v", err)
}
if !cmp.Equal(resp1, resp2) {
t.Errorf("generateDNSResponse() responses differ between calls")
}
var msg dnsmessage.Message
err = msg.Unpack(resp1)
if err != nil {
t.Fatalf("dnsmessage Unpack error = %v", err)
}
if len(msg.Answers) != 1 {
t.Fatalf("expected 1 answer, got: %d", len(msg.Answers))
}
}