cmd/natc: fix handling of upstream and downstream nxdomain

Ensure that the upstream is always queried, so that if upstream is going
to NXDOMAIN natc will also return NXDOMAIN rather than returning address
allocations.

At this time both IPv4 and IPv6 are still returned if upstream has a
result, regardless of upstream support - this is ~ok as we're proxying.

Rewrite the tests to be once again slightly closer to integration tests,
but they're still very rough and in need of a refactor.

Further refactors are probably needed implementation side too, as this
removed rather than added units.

Updates #15367

Signed-off-by: James Tucker <james@tailscale.com>
This commit is contained in:
James Tucker
2025-04-01 18:52:45 -07:00
committed by James Tucker
parent fb96137d79
commit 025fe72448
2 changed files with 379 additions and 252 deletions

View File

@@ -4,14 +4,20 @@
package main
import (
"context"
"fmt"
"io"
"net"
"net/netip"
"testing"
"time"
"github.com/gaissmai/bart"
"github.com/google/go-cmp/cmp"
"golang.org/x/net/dns/dnsmessage"
"tailscale.com/client/tailscale/apitype"
"tailscale.com/cmd/natc/ippool"
"tailscale.com/tailcfg"
"tailscale.com/util/must"
)
func prefixEqual(a, b netip.Prefix) bool {
@@ -41,22 +47,86 @@ func TestULA(t *testing.T) {
}
}
type recordingPacketConn struct {
writes [][]byte
}
func (w *recordingPacketConn) WriteTo(b []byte, addr net.Addr) (int, error) {
w.writes = append(w.writes, b)
return len(b), nil
}
func (w *recordingPacketConn) ReadFrom(b []byte) (int, net.Addr, error) {
return 0, nil, io.EOF
}
func (w *recordingPacketConn) Close() error {
return nil
}
func (w *recordingPacketConn) LocalAddr() net.Addr {
return nil
}
func (w *recordingPacketConn) RemoteAddr() net.Addr {
return nil
}
func (w *recordingPacketConn) SetDeadline(t time.Time) error {
return nil
}
func (w *recordingPacketConn) SetReadDeadline(t time.Time) error {
return nil
}
func (w *recordingPacketConn) SetWriteDeadline(t time.Time) error {
return nil
}
type resolver struct {
resolves map[string][]netip.Addr
fails map[string]bool
}
func (r *resolver) LookupNetIP(ctx context.Context, _net, host string) ([]netip.Addr, error) {
if addrs, ok := r.resolves[host]; ok {
return addrs, nil
}
if _, ok := r.fails[host]; ok {
return nil, &net.DNSError{IsTimeout: false, IsNotFound: false, Name: host, IsTemporary: true}
}
return nil, &net.DNSError{IsNotFound: true, Name: host}
}
type whois struct {
peers map[string]*apitype.WhoIsResponse
}
func (w *whois) WhoIs(ctx context.Context, remoteAddr string) (*apitype.WhoIsResponse, error) {
addr := netip.MustParseAddrPort(remoteAddr).Addr().String()
if peer, ok := w.peers[addr]; ok {
return peer, nil
}
return nil, fmt.Errorf("peer not found")
}
func TestDNSResponse(t *testing.T) {
tests := []struct {
name string
questions []dnsmessage.Question
addrs []netip.Addr
wantEmpty bool
wantAnswers []struct {
name string
qType dnsmessage.Type
addr netip.Addr
}
wantNXDOMAIN bool
wantIgnored bool
}{
{
name: "empty_request",
questions: []dnsmessage.Question{},
addrs: []netip.Addr{},
wantEmpty: false,
wantAnswers: nil,
},
@@ -69,7 +139,6 @@ func TestDNSResponse(t *testing.T) {
Class: dnsmessage.ClassINET,
},
},
addrs: []netip.Addr{netip.MustParseAddr("100.64.1.5")},
wantAnswers: []struct {
name string
qType dnsmessage.Type
@@ -78,7 +147,7 @@ func TestDNSResponse(t *testing.T) {
{
name: "example.com.",
qType: dnsmessage.TypeA,
addr: netip.MustParseAddr("100.64.1.5"),
addr: netip.MustParseAddr("100.64.0.0"),
},
},
},
@@ -91,7 +160,6 @@ func TestDNSResponse(t *testing.T) {
Class: dnsmessage.ClassINET,
},
},
addrs: []netip.Addr{netip.MustParseAddr("fd7a:115c:a1e0:a99c:0001:0505:0505:0505")},
wantAnswers: []struct {
name string
qType dnsmessage.Type
@@ -100,7 +168,7 @@ func TestDNSResponse(t *testing.T) {
{
name: "example.com.",
qType: dnsmessage.TypeAAAA,
addr: netip.MustParseAddr("fd7a:115c:a1e0:a99c:0001:0505:0505:0505"),
addr: netip.MustParseAddr("fd7a:115c:a1e0::"),
},
},
},
@@ -113,7 +181,6 @@ func TestDNSResponse(t *testing.T) {
Class: dnsmessage.ClassINET,
},
},
addrs: []netip.Addr{},
wantAnswers: nil,
},
{
@@ -125,89 +192,210 @@ func TestDNSResponse(t *testing.T) {
Class: dnsmessage.ClassINET,
},
},
addrs: []netip.Addr{},
wantAnswers: nil,
},
{
name: "nxdomain",
questions: []dnsmessage.Question{
{
Name: dnsmessage.MustNewName("noexist.example.com."),
Type: dnsmessage.TypeA,
Class: dnsmessage.ClassINET,
},
},
wantNXDOMAIN: true,
},
{
name: "servfail",
questions: []dnsmessage.Question{
{
Name: dnsmessage.MustNewName("fail.example.com."),
Type: dnsmessage.TypeA,
Class: dnsmessage.ClassINET,
},
},
wantEmpty: true, // TODO: pass through instead?
},
{
name: "ignored",
questions: []dnsmessage.Question{
{
Name: dnsmessage.MustNewName("ignore.example.com."),
Type: dnsmessage.TypeA,
Class: dnsmessage.ClassINET,
},
},
wantAnswers: []struct {
name string
qType dnsmessage.Type
addr netip.Addr
}{
{
name: "ignore.example.com.",
qType: dnsmessage.TypeA,
addr: netip.MustParseAddr("8.8.4.4"),
},
},
wantIgnored: true,
},
}
var rpc recordingPacketConn
remoteAddr := must.Get(net.ResolveUDPAddr("udp", "100.64.254.1:12345"))
routes, dnsAddr, addrPool := calculateAddresses([]netip.Prefix{netip.MustParsePrefix("10.64.0.0/24")})
v6ULA := ula(1)
c := connector{
resolver: &resolver{
resolves: map[string][]netip.Addr{
"example.com.": {
netip.MustParseAddr("8.8.8.8"),
netip.MustParseAddr("2001:4860:4860::8888"),
},
"ignore.example.com.": {
netip.MustParseAddr("8.8.4.4"),
},
},
fails: map[string]bool{
"fail.example.com.": true,
},
},
whois: &whois{
peers: map[string]*apitype.WhoIsResponse{
"100.64.254.1": {
Node: &tailcfg.Node{ID: 123},
},
},
},
ignoreDsts: &bart.Table[bool]{},
routes: routes,
v6ULA: v6ULA,
ipPool: &ippool.IPPool{V6ULA: v6ULA, IPSet: addrPool},
dnsAddr: dnsAddr,
}
c.ignoreDsts.Insert(netip.MustParsePrefix("8.8.4.4/32"), true)
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
req := &dnsmessage.Message{
Header: dnsmessage.Header{
rb := dnsmessage.NewBuilder(nil,
dnsmessage.Header{
ID: 1234,
},
Questions: tc.questions,
)
must.Do(rb.StartQuestions())
for _, q := range tc.questions {
rb.Question(q)
}
resp, err := dnsResponse(req, tc.addrs)
c.handleDNS(&rpc, must.Get(rb.Finish()), remoteAddr)
writes := rpc.writes
rpc.writes = rpc.writes[:0]
if tc.wantEmpty {
if len(writes) != 0 {
t.Errorf("handleDNS() returned non-empty response when expected empty")
}
return
}
if !tc.wantEmpty && len(writes) != 1 {
t.Fatalf("handleDNS() returned an unexpected number of responses: %d, want 1", len(writes))
}
resp := writes[0]
var msg dnsmessage.Message
err := msg.Unpack(resp)
if err != nil {
t.Fatalf("dnsResponse() error = %v", err)
t.Fatalf("Failed to unpack response: %v", err)
}
if tc.wantEmpty && len(resp) != 0 {
t.Errorf("dnsResponse() returned non-empty response when expected empty")
if !msg.Header.Response {
t.Errorf("Response header is not set")
}
if !tc.wantEmpty && len(resp) == 0 {
t.Errorf("dnsResponse() returned empty response when expected non-empty")
if msg.Header.ID != 1234 {
t.Errorf("Response ID = %d, want %d", msg.Header.ID, 1234)
}
if len(resp) > 0 {
var msg dnsmessage.Message
err = msg.Unpack(resp)
if err != nil {
t.Fatalf("Failed to unpack response: %v", err)
}
if len(tc.wantAnswers) > 0 {
if len(msg.Answers) != len(tc.wantAnswers) {
t.Errorf("got %d answers, want %d:\n%s", len(msg.Answers), len(tc.wantAnswers), msg.GoString())
} else {
for i, want := range tc.wantAnswers {
ans := msg.Answers[i]
if !msg.Header.Response {
t.Errorf("Response header is not set")
}
gotName := ans.Header.Name.String()
if gotName != want.name {
t.Errorf("answer[%d] name = %s, want %s", i, gotName, want.name)
}
if msg.Header.ID != req.Header.ID {
t.Errorf("Response ID = %d, want %d", msg.Header.ID, req.Header.ID)
}
if ans.Header.Type != want.qType {
t.Errorf("answer[%d] type = %v, want %v", i, ans.Header.Type, want.qType)
}
if len(tc.wantAnswers) > 0 {
if len(msg.Answers) != len(tc.wantAnswers) {
t.Errorf("got %d answers, want %d", len(msg.Answers), len(tc.wantAnswers))
} else {
for i, want := range tc.wantAnswers {
ans := msg.Answers[i]
gotName := ans.Header.Name.String()
if gotName != want.name {
t.Errorf("answer[%d] name = %s, want %s", i, gotName, want.name)
switch want.qType {
case dnsmessage.TypeA:
if ans.Body.(*dnsmessage.AResource) == nil {
t.Errorf("answer[%d] not an A record", i)
continue
}
resource := ans.Body.(*dnsmessage.AResource)
gotIP := netip.AddrFrom4([4]byte(resource.A))
if ans.Header.Type != want.qType {
t.Errorf("answer[%d] type = %v, want %v", i, ans.Header.Type, want.qType)
var ips []netip.Addr
if tc.wantIgnored {
ips = must.Get(c.resolver.LookupNetIP(t.Context(), "ip4", want.name))
} else {
ips = must.Get(c.ipPool.IPForDomain(tailcfg.NodeID(123), want.name))
}
var gotIP netip.Addr
switch want.qType {
case dnsmessage.TypeA:
if ans.Body.(*dnsmessage.AResource) == nil {
t.Errorf("answer[%d] not an A record", i)
continue
var wantIP netip.Addr
for _, ip := range ips {
if ip.Is4() {
wantIP = ip
break
}
resource := ans.Body.(*dnsmessage.AResource)
gotIP = netip.AddrFrom4([4]byte(resource.A))
case dnsmessage.TypeAAAA:
if ans.Body.(*dnsmessage.AAAAResource) == nil {
t.Errorf("answer[%d] not an AAAA record", i)
continue
}
resource := ans.Body.(*dnsmessage.AAAAResource)
gotIP = netip.AddrFrom16([16]byte(resource.AAAA))
}
if gotIP != wantIP {
t.Errorf("answer[%d] IP = %s, want %s", i, gotIP, wantIP)
}
case dnsmessage.TypeAAAA:
if ans.Body.(*dnsmessage.AAAAResource) == nil {
t.Errorf("answer[%d] not an AAAA record", i)
continue
}
resource := ans.Body.(*dnsmessage.AAAAResource)
gotIP := netip.AddrFrom16([16]byte(resource.AAAA))
if gotIP != want.addr {
t.Errorf("answer[%d] IP = %s, want %s", i, gotIP, want.addr)
var ips []netip.Addr
if tc.wantIgnored {
ips = must.Get(c.resolver.LookupNetIP(t.Context(), "ip6", want.name))
} else {
ips = must.Get(c.ipPool.IPForDomain(tailcfg.NodeID(123), want.name))
}
var wantIP netip.Addr
for _, ip := range ips {
if ip.Is6() {
wantIP = ip
break
}
}
if gotIP != wantIP {
t.Errorf("answer[%d] IP = %s, want %s", i, gotIP, wantIP)
}
}
}
}
}
if tc.wantNXDOMAIN {
if msg.RCode != dnsmessage.RCodeNameError {
t.Errorf("expected NXDOMAIN, got %v", msg.RCode)
}
if len(msg.Answers) != 0 {
t.Errorf("expected no answers, got %d", len(msg.Answers))
}
}
})
}
}
@@ -257,53 +445,3 @@ func TestIgnoreDestination(t *testing.T) {
})
}
}
func TestConnectorGenerateDNSResponse(t *testing.T) {
v6ULA := netip.MustParsePrefix("fd7a:115c:a1e0:a99c:0001::/80")
routes, dnsAddr, addrPool := calculateAddresses([]netip.Prefix{netip.MustParsePrefix("100.64.1.0/24")})
c := &connector{
v6ULA: v6ULA,
ipPool: &ippool.IPPool{V6ULA: v6ULA, IPSet: addrPool},
routes: routes,
dnsAddr: dnsAddr,
}
req := &dnsmessage.Message{
Header: dnsmessage.Header{ID: 1234},
Questions: []dnsmessage.Question{
{
Name: dnsmessage.MustNewName("example.com."),
Type: dnsmessage.TypeA,
Class: dnsmessage.ClassINET,
},
},
}
nodeID := tailcfg.NodeID(12345)
resp1, err := c.generateDNSResponse(req, nodeID)
if err != nil {
t.Fatalf("generateDNSResponse() error = %v", err)
}
if len(resp1) == 0 {
t.Fatalf("generateDNSResponse() returned empty response")
}
resp2, err := c.generateDNSResponse(req, nodeID)
if err != nil {
t.Fatalf("generateDNSResponse() second call error = %v", err)
}
if !cmp.Equal(resp1, resp2) {
t.Errorf("generateDNSResponse() responses differ between calls")
}
var msg dnsmessage.Message
err = msg.Unpack(resp1)
if err != nil {
t.Fatalf("dnsmessage Unpack error = %v", err)
}
if len(msg.Answers) != 1 {
t.Fatalf("expected 1 answer, got: %d", len(msg.Answers))
}
}