diff --git a/cmd/natc/natc.go b/cmd/natc/natc.go index 270524879..a80e4a42a 100644 --- a/cmd/natc/natc.go +++ b/cmd/natc/natc.go @@ -26,14 +26,15 @@ import ( "go4.org/netipx" "golang.org/x/net/dns/dnsmessage" "tailscale.com/client/local" + "tailscale.com/client/tailscale/apitype" "tailscale.com/cmd/natc/ippool" "tailscale.com/envknob" "tailscale.com/hostinfo" "tailscale.com/ipn" "tailscale.com/net/netutil" - "tailscale.com/tailcfg" "tailscale.com/tsnet" "tailscale.com/tsweb" + "tailscale.com/util/mak" "tailscale.com/util/must" "tailscale.com/wgengine/netstack" ) @@ -148,14 +149,15 @@ func main() { v6ULA := ula(uint16(*siteID)) c := &connector{ ts: ts, - lc: lc, + whois: lc, v6ULA: v6ULA, ignoreDsts: ignoreDstTable, ipPool: &ippool.IPPool{V6ULA: v6ULA, IPSet: addrPool}, routes: routes, dnsAddr: dnsAddr, + resolver: net.DefaultResolver, } - c.run(ctx) + c.run(ctx, lc) } func calculateAddresses(prefixes []netip.Prefix) (*netipx.IPSet, netip.Addr, *netipx.IPSet) { @@ -170,12 +172,20 @@ func calculateAddresses(prefixes []netip.Prefix) (*netipx.IPSet, netip.Addr, *ne return routesToAdvertise, dnsAddr, addrPool } +type lookupNetIPer interface { + LookupNetIP(ctx context.Context, net, host string) ([]netip.Addr, error) +} + +type whoiser interface { + WhoIs(ctx context.Context, remoteAddr string) (*apitype.WhoIsResponse, error) +} + type connector struct { // ts is the tsnet.Server used to host the connector. ts *tsnet.Server - // lc is the local.Client used to interact with the tsnet.Server hosting this + // whois is the local.Client used to interact with the tsnet.Server hosting this // connector. - lc *local.Client + whois whoiser // dnsAddr is the IPv4 address to listen on for DNS requests. It is used to // prevent the app connector from assigning it to a domain. @@ -197,7 +207,11 @@ type connector struct { // natc behavior, which would return a dummy ip address pointing at natc). ignoreDsts *bart.Table[bool] + // ipPool contains the per-peer IPv4 address assignments. ipPool *ippool.IPPool + + // resolver is used to lookup IP addresses for DNS queries. + resolver lookupNetIPer } // v6ULA is the ULA prefix used by the app connector to assign IPv6 addresses. @@ -217,8 +231,8 @@ func ula(siteID uint16) netip.Prefix { // // The passed in context is only used for the initial setup. The connector runs // forever. -func (c *connector) run(ctx context.Context) { - if _, err := c.lc.EditPrefs(ctx, &ipn.MaskedPrefs{ +func (c *connector) run(ctx context.Context, lc *local.Client) { + if _, err := lc.EditPrefs(ctx, &ipn.MaskedPrefs{ AdvertiseRoutesSet: true, Prefs: ipn.Prefs{ AdvertiseRoutes: append(c.routes.Prefixes(), c.v6ULA), @@ -251,26 +265,6 @@ func (c *connector) serveDNS() { } } -func lookupDestinationIP(domain string) ([]netip.Addr, error) { - netIPs, err := net.LookupIP(domain) - if err != nil { - var dnsError *net.DNSError - if errors.As(err, &dnsError) && dnsError.IsNotFound { - return nil, nil - } else { - return nil, err - } - } - var addrs []netip.Addr - for _, ip := range netIPs { - a, ok := netip.AddrFromSlice(ip) - if ok { - addrs = append(addrs, a) - } - } - return addrs, nil -} - // handleDNS handles a DNS request to the app connector. // It generates a response based on the request and the node that sent it. // @@ -285,7 +279,7 @@ func lookupDestinationIP(domain string) ([]netip.Addr, error) { func (c *connector) handleDNS(pc net.PacketConn, buf []byte, remoteAddr *net.UDPAddr) { ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() - who, err := c.lc.WhoIs(ctx, remoteAddr.String()) + who, err := c.whois.WhoIs(ctx, remoteAddr.String()) if err != nil { log.Printf("HandleDNS(remote=%s): WhoIs failed: %v\n", remoteAddr.String(), err) return @@ -298,49 +292,122 @@ func (c *connector) handleDNS(pc net.PacketConn, buf []byte, remoteAddr *net.UDP return } - // If there are destination ips that we don't want to route, we - // have to do a dns lookup here to find the destination ip. - if c.ignoreDsts != nil { - if len(msg.Questions) > 0 { - q := msg.Questions[0] - switch q.Type { - case dnsmessage.TypeAAAA, dnsmessage.TypeA: - dstAddrs, err := lookupDestinationIP(q.Name.String()) + var resolves map[string][]netip.Addr + var addrQCount int + for _, q := range msg.Questions { + if q.Type != dnsmessage.TypeA && q.Type != dnsmessage.TypeAAAA { + continue + } + addrQCount++ + if _, ok := resolves[q.Name.String()]; !ok { + addrs, err := c.resolver.LookupNetIP(ctx, "ip", q.Name.String()) + var dnsErr *net.DNSError + if errors.As(err, &dnsErr) && dnsErr.IsNotFound { + continue + } + if err != nil { + log.Printf("HandleDNS(remote=%s): lookup destination failed: %v\n", remoteAddr.String(), err) + return + } + // Note: If _any_ destination is ignored, pass through all of the resolved + // addresses as-is. + // + // This could result in some odd split-routing if there was a mix of + // ignored and non-ignored addresses, but it's currently the user + // preferred behavior. + if !c.ignoreDestination(addrs) { + addrs, err = c.ipPool.IPForDomain(who.Node.ID, q.Name.String()) if err != nil { log.Printf("HandleDNS(remote=%s): lookup destination failed: %v\n", remoteAddr.String(), err) return } - if c.ignoreDestination(dstAddrs) { - bs, err := dnsResponse(&msg, dstAddrs) - // TODO (fran): treat as SERVFAIL - if err != nil { - log.Printf("HandleDNS(remote=%s): generate ignore response failed: %v\n", remoteAddr.String(), err) - return - } - _, err = pc.WriteTo(bs, remoteAddr) - if err != nil { - log.Printf("HandleDNS(remote=%s): write failed: %v\n", remoteAddr.String(), err) - } + } + mak.Set(&resolves, q.Name.String(), addrs) + } + } + + rcode := dnsmessage.RCodeSuccess + if addrQCount > 0 && len(resolves) == 0 { + rcode = dnsmessage.RCodeNameError + } + + b := dnsmessage.NewBuilder(nil, + dnsmessage.Header{ + ID: msg.Header.ID, + Response: true, + Authoritative: true, + RCode: rcode, + }) + b.EnableCompression() + + if err := b.StartQuestions(); err != nil { + log.Printf("HandleDNS(remote=%s): dnsmessage start questions failed: %v\n", remoteAddr.String(), err) + return + } + + for _, q := range msg.Questions { + b.Question(q) + } + + if err := b.StartAnswers(); err != nil { + log.Printf("HandleDNS(remote=%s): dnsmessage start answers failed: %v\n", remoteAddr.String(), err) + return + } + + for _, q := range msg.Questions { + switch q.Type { + case dnsmessage.TypeSOA: + if err := b.SOAResource( + dnsmessage.ResourceHeader{Name: q.Name, Class: q.Class, TTL: 120}, + dnsmessage.SOAResource{NS: q.Name, MBox: tsMBox, Serial: 2023030600, + Refresh: 120, Retry: 120, Expire: 120, MinTTL: 60}, + ); err != nil { + log.Printf("HandleDNS(remote=%s): dnsmessage SOA resource failed: %v\n", remoteAddr.String(), err) + return + } + case dnsmessage.TypeNS: + if err := b.NSResource( + dnsmessage.ResourceHeader{Name: q.Name, Class: q.Class, TTL: 120}, + dnsmessage.NSResource{NS: tsMBox}, + ); err != nil { + log.Printf("HandleDNS(remote=%s): dnsmessage NS resource failed: %v\n", remoteAddr.String(), err) + return + } + case dnsmessage.TypeAAAA: + for _, addr := range resolves[q.Name.String()] { + if !addr.Is6() { + continue + } + if err := b.AAAAResource( + dnsmessage.ResourceHeader{Name: q.Name, Class: q.Class, TTL: 120}, + dnsmessage.AAAAResource{AAAA: addr.As16()}, + ); err != nil { + log.Printf("HandleDNS(remote=%s): dnsmessage AAAA resource failed: %v\n", remoteAddr.String(), err) + return + } + } + case dnsmessage.TypeA: + for _, addr := range resolves[q.Name.String()] { + if !addr.Is4() { + continue + } + if err := b.AResource( + dnsmessage.ResourceHeader{Name: q.Name, Class: q.Class, TTL: 120}, + dnsmessage.AResource{A: addr.As4()}, + ); err != nil { + log.Printf("HandleDNS(remote=%s): dnsmessage A resource failed: %v\n", remoteAddr.String(), err) return } } } } - // None of the destination IP addresses match an ignore destination prefix, do - // the natc thing. - resp, err := c.generateDNSResponse(&msg, who.Node.ID) - // TODO (fran): treat as SERVFAIL + out, err := b.Finish() if err != nil { - log.Printf("HandleDNS(remote=%s): connector handling failed: %v\n", remoteAddr.String(), err) + log.Printf("HandleDNS(remote=%s): dnsmessage finish failed: %v\n", remoteAddr.String(), err) return } - // TODO (fran): treat as NXDOMAIN - if len(resp) == 0 { - return - } - // This connector handled the DNS request - _, err = pc.WriteTo(resp, remoteAddr) + _, err = pc.WriteTo(out, remoteAddr) if err != nil { log.Printf("HandleDNS(remote=%s): write failed: %v\n", remoteAddr.String(), err) } @@ -352,89 +419,6 @@ func (c *connector) handleDNS(pc net.PacketConn, buf []byte, remoteAddr *net.UDP // to indicate that it is a fully qualified domain name. var tsMBox = dnsmessage.MustNewName("support.tailscale.com.") -// generateDNSResponse generates a DNS response for the given request. The from -// argument is the NodeID of the node that sent the request. -func (c *connector) generateDNSResponse(req *dnsmessage.Message, from tailcfg.NodeID) ([]byte, error) { - var addrs []netip.Addr - if len(req.Questions) > 0 { - switch req.Questions[0].Type { - case dnsmessage.TypeAAAA, dnsmessage.TypeA: - var err error - addrs, err = c.ipPool.IPForDomain(from, req.Questions[0].Name.String()) - if err != nil { - return nil, err - } - } - } - return dnsResponse(req, addrs) -} - -// dnsResponse makes a DNS response for the natc. If the dnsmessage is requesting TypeAAAA -// or TypeA the provided addrs of the requested type will be used. -func dnsResponse(req *dnsmessage.Message, addrs []netip.Addr) ([]byte, error) { - b := dnsmessage.NewBuilder(nil, - dnsmessage.Header{ - ID: req.Header.ID, - Response: true, - Authoritative: true, - }) - b.EnableCompression() - - if len(req.Questions) == 0 { - return b.Finish() - } - q := req.Questions[0] - if err := b.StartQuestions(); err != nil { - return nil, err - } - if err := b.Question(q); err != nil { - return nil, err - } - if err := b.StartAnswers(); err != nil { - return nil, err - } - switch q.Type { - case dnsmessage.TypeAAAA, dnsmessage.TypeA: - want6 := q.Type == dnsmessage.TypeAAAA - for _, ip := range addrs { - if want6 != ip.Is6() { - continue - } - if want6 { - if err := b.AAAAResource( - dnsmessage.ResourceHeader{Name: q.Name, Class: q.Class, TTL: 5}, - dnsmessage.AAAAResource{AAAA: ip.As16()}, - ); err != nil { - return nil, err - } - } else { - if err := b.AResource( - dnsmessage.ResourceHeader{Name: q.Name, Class: q.Class, TTL: 5}, - dnsmessage.AResource{A: ip.As4()}, - ); err != nil { - return nil, err - } - } - } - case dnsmessage.TypeSOA: - if err := b.SOAResource( - dnsmessage.ResourceHeader{Name: q.Name, Class: q.Class, TTL: 120}, - dnsmessage.SOAResource{NS: q.Name, MBox: tsMBox, Serial: 2023030600, - Refresh: 120, Retry: 120, Expire: 120, MinTTL: 60}, - ); err != nil { - return nil, err - } - case dnsmessage.TypeNS: - if err := b.NSResource( - dnsmessage.ResourceHeader{Name: q.Name, Class: q.Class, TTL: 120}, - dnsmessage.NSResource{NS: tsMBox}, - ); err != nil { - return nil, err - } - } - return b.Finish() -} - // handleTCPFlow handles a TCP flow from the given source to the given // destination. It uses the source address to determine the node that sent the // request and the destination address to determine the domain that the request @@ -443,7 +427,7 @@ func dnsResponse(req *dnsmessage.Message, addrs []netip.Addr) ([]byte, error) { func (c *connector) handleTCPFlow(src, dst netip.AddrPort) (handler func(net.Conn), intercept bool) { ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() - who, err := c.lc.WhoIs(ctx, src.Addr().String()) + who, err := c.whois.WhoIs(ctx, src.Addr().String()) cancel() if err != nil { log.Printf("HandleTCPFlow: WhoIs failed: %v\n", err) @@ -461,6 +445,9 @@ func (c *connector) handleTCPFlow(src, dst netip.AddrPort) (handler func(net.Con // ignoreDestination reports whether any of the provided dstAddrs match the prefixes configured // in --ignore-destinations func (c *connector) ignoreDestination(dstAddrs []netip.Addr) bool { + if c.ignoreDsts == nil { + return false + } for _, a := range dstAddrs { if _, ok := c.ignoreDsts.Lookup(a); ok { return true @@ -488,6 +475,8 @@ func proxyTCPConn(c net.Conn, dest string) { return netutil.NewOneConnListener(c, nil), nil }, } + // XXX(raggi): if the connection here resolves to an ignored destination, + // the connection should be closed/failed. p.AddRoute(addrPortStr, &tcpproxy.DialProxy{ Addr: fmt.Sprintf("%s:%s", dest, port), }) diff --git a/cmd/natc/natc_test.go b/cmd/natc/natc_test.go index 09ade0a98..8fe38de1c 100644 --- a/cmd/natc/natc_test.go +++ b/cmd/natc/natc_test.go @@ -4,14 +4,20 @@ package main import ( + "context" + "fmt" + "io" + "net" "net/netip" "testing" + "time" "github.com/gaissmai/bart" - "github.com/google/go-cmp/cmp" "golang.org/x/net/dns/dnsmessage" + "tailscale.com/client/tailscale/apitype" "tailscale.com/cmd/natc/ippool" "tailscale.com/tailcfg" + "tailscale.com/util/must" ) func prefixEqual(a, b netip.Prefix) bool { @@ -41,22 +47,86 @@ func TestULA(t *testing.T) { } } +type recordingPacketConn struct { + writes [][]byte +} + +func (w *recordingPacketConn) WriteTo(b []byte, addr net.Addr) (int, error) { + w.writes = append(w.writes, b) + return len(b), nil +} + +func (w *recordingPacketConn) ReadFrom(b []byte) (int, net.Addr, error) { + return 0, nil, io.EOF +} + +func (w *recordingPacketConn) Close() error { + return nil +} + +func (w *recordingPacketConn) LocalAddr() net.Addr { + return nil +} + +func (w *recordingPacketConn) RemoteAddr() net.Addr { + return nil +} + +func (w *recordingPacketConn) SetDeadline(t time.Time) error { + return nil +} + +func (w *recordingPacketConn) SetReadDeadline(t time.Time) error { + return nil +} + +func (w *recordingPacketConn) SetWriteDeadline(t time.Time) error { + return nil +} + +type resolver struct { + resolves map[string][]netip.Addr + fails map[string]bool +} + +func (r *resolver) LookupNetIP(ctx context.Context, _net, host string) ([]netip.Addr, error) { + if addrs, ok := r.resolves[host]; ok { + return addrs, nil + } + if _, ok := r.fails[host]; ok { + return nil, &net.DNSError{IsTimeout: false, IsNotFound: false, Name: host, IsTemporary: true} + } + return nil, &net.DNSError{IsNotFound: true, Name: host} +} + +type whois struct { + peers map[string]*apitype.WhoIsResponse +} + +func (w *whois) WhoIs(ctx context.Context, remoteAddr string) (*apitype.WhoIsResponse, error) { + addr := netip.MustParseAddrPort(remoteAddr).Addr().String() + if peer, ok := w.peers[addr]; ok { + return peer, nil + } + return nil, fmt.Errorf("peer not found") +} + func TestDNSResponse(t *testing.T) { tests := []struct { name string questions []dnsmessage.Question - addrs []netip.Addr wantEmpty bool wantAnswers []struct { name string qType dnsmessage.Type addr netip.Addr } + wantNXDOMAIN bool + wantIgnored bool }{ { name: "empty_request", questions: []dnsmessage.Question{}, - addrs: []netip.Addr{}, wantEmpty: false, wantAnswers: nil, }, @@ -69,7 +139,6 @@ func TestDNSResponse(t *testing.T) { Class: dnsmessage.ClassINET, }, }, - addrs: []netip.Addr{netip.MustParseAddr("100.64.1.5")}, wantAnswers: []struct { name string qType dnsmessage.Type @@ -78,7 +147,7 @@ func TestDNSResponse(t *testing.T) { { name: "example.com.", qType: dnsmessage.TypeA, - addr: netip.MustParseAddr("100.64.1.5"), + addr: netip.MustParseAddr("100.64.0.0"), }, }, }, @@ -91,7 +160,6 @@ func TestDNSResponse(t *testing.T) { Class: dnsmessage.ClassINET, }, }, - addrs: []netip.Addr{netip.MustParseAddr("fd7a:115c:a1e0:a99c:0001:0505:0505:0505")}, wantAnswers: []struct { name string qType dnsmessage.Type @@ -100,7 +168,7 @@ func TestDNSResponse(t *testing.T) { { name: "example.com.", qType: dnsmessage.TypeAAAA, - addr: netip.MustParseAddr("fd7a:115c:a1e0:a99c:0001:0505:0505:0505"), + addr: netip.MustParseAddr("fd7a:115c:a1e0::"), }, }, }, @@ -113,7 +181,6 @@ func TestDNSResponse(t *testing.T) { Class: dnsmessage.ClassINET, }, }, - addrs: []netip.Addr{}, wantAnswers: nil, }, { @@ -125,89 +192,210 @@ func TestDNSResponse(t *testing.T) { Class: dnsmessage.ClassINET, }, }, - addrs: []netip.Addr{}, wantAnswers: nil, }, + { + name: "nxdomain", + questions: []dnsmessage.Question{ + { + Name: dnsmessage.MustNewName("noexist.example.com."), + Type: dnsmessage.TypeA, + Class: dnsmessage.ClassINET, + }, + }, + wantNXDOMAIN: true, + }, + { + name: "servfail", + questions: []dnsmessage.Question{ + { + Name: dnsmessage.MustNewName("fail.example.com."), + Type: dnsmessage.TypeA, + Class: dnsmessage.ClassINET, + }, + }, + wantEmpty: true, // TODO: pass through instead? + }, + { + name: "ignored", + questions: []dnsmessage.Question{ + { + Name: dnsmessage.MustNewName("ignore.example.com."), + Type: dnsmessage.TypeA, + Class: dnsmessage.ClassINET, + }, + }, + wantAnswers: []struct { + name string + qType dnsmessage.Type + addr netip.Addr + }{ + { + name: "ignore.example.com.", + qType: dnsmessage.TypeA, + addr: netip.MustParseAddr("8.8.4.4"), + }, + }, + wantIgnored: true, + }, } + var rpc recordingPacketConn + remoteAddr := must.Get(net.ResolveUDPAddr("udp", "100.64.254.1:12345")) + + routes, dnsAddr, addrPool := calculateAddresses([]netip.Prefix{netip.MustParsePrefix("10.64.0.0/24")}) + v6ULA := ula(1) + c := connector{ + resolver: &resolver{ + resolves: map[string][]netip.Addr{ + "example.com.": { + netip.MustParseAddr("8.8.8.8"), + netip.MustParseAddr("2001:4860:4860::8888"), + }, + "ignore.example.com.": { + netip.MustParseAddr("8.8.4.4"), + }, + }, + fails: map[string]bool{ + "fail.example.com.": true, + }, + }, + whois: &whois{ + peers: map[string]*apitype.WhoIsResponse{ + "100.64.254.1": { + Node: &tailcfg.Node{ID: 123}, + }, + }, + }, + ignoreDsts: &bart.Table[bool]{}, + routes: routes, + v6ULA: v6ULA, + ipPool: &ippool.IPPool{V6ULA: v6ULA, IPSet: addrPool}, + dnsAddr: dnsAddr, + } + c.ignoreDsts.Insert(netip.MustParsePrefix("8.8.4.4/32"), true) + for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { - req := &dnsmessage.Message{ - Header: dnsmessage.Header{ + rb := dnsmessage.NewBuilder(nil, + dnsmessage.Header{ ID: 1234, }, - Questions: tc.questions, + ) + must.Do(rb.StartQuestions()) + for _, q := range tc.questions { + rb.Question(q) } - resp, err := dnsResponse(req, tc.addrs) + c.handleDNS(&rpc, must.Get(rb.Finish()), remoteAddr) + + writes := rpc.writes + rpc.writes = rpc.writes[:0] + + if tc.wantEmpty { + if len(writes) != 0 { + t.Errorf("handleDNS() returned non-empty response when expected empty") + } + return + } + + if !tc.wantEmpty && len(writes) != 1 { + t.Fatalf("handleDNS() returned an unexpected number of responses: %d, want 1", len(writes)) + } + + resp := writes[0] + var msg dnsmessage.Message + err := msg.Unpack(resp) if err != nil { - t.Fatalf("dnsResponse() error = %v", err) + t.Fatalf("Failed to unpack response: %v", err) } - if tc.wantEmpty && len(resp) != 0 { - t.Errorf("dnsResponse() returned non-empty response when expected empty") + if !msg.Header.Response { + t.Errorf("Response header is not set") } - if !tc.wantEmpty && len(resp) == 0 { - t.Errorf("dnsResponse() returned empty response when expected non-empty") + if msg.Header.ID != 1234 { + t.Errorf("Response ID = %d, want %d", msg.Header.ID, 1234) } - if len(resp) > 0 { - var msg dnsmessage.Message - err = msg.Unpack(resp) - if err != nil { - t.Fatalf("Failed to unpack response: %v", err) - } + if len(tc.wantAnswers) > 0 { + if len(msg.Answers) != len(tc.wantAnswers) { + t.Errorf("got %d answers, want %d:\n%s", len(msg.Answers), len(tc.wantAnswers), msg.GoString()) + } else { + for i, want := range tc.wantAnswers { + ans := msg.Answers[i] - if !msg.Header.Response { - t.Errorf("Response header is not set") - } + gotName := ans.Header.Name.String() + if gotName != want.name { + t.Errorf("answer[%d] name = %s, want %s", i, gotName, want.name) + } - if msg.Header.ID != req.Header.ID { - t.Errorf("Response ID = %d, want %d", msg.Header.ID, req.Header.ID) - } + if ans.Header.Type != want.qType { + t.Errorf("answer[%d] type = %v, want %v", i, ans.Header.Type, want.qType) + } - if len(tc.wantAnswers) > 0 { - if len(msg.Answers) != len(tc.wantAnswers) { - t.Errorf("got %d answers, want %d", len(msg.Answers), len(tc.wantAnswers)) - } else { - for i, want := range tc.wantAnswers { - ans := msg.Answers[i] - - gotName := ans.Header.Name.String() - if gotName != want.name { - t.Errorf("answer[%d] name = %s, want %s", i, gotName, want.name) + switch want.qType { + case dnsmessage.TypeA: + if ans.Body.(*dnsmessage.AResource) == nil { + t.Errorf("answer[%d] not an A record", i) + continue } + resource := ans.Body.(*dnsmessage.AResource) + gotIP := netip.AddrFrom4([4]byte(resource.A)) - if ans.Header.Type != want.qType { - t.Errorf("answer[%d] type = %v, want %v", i, ans.Header.Type, want.qType) + var ips []netip.Addr + if tc.wantIgnored { + ips = must.Get(c.resolver.LookupNetIP(t.Context(), "ip4", want.name)) + } else { + ips = must.Get(c.ipPool.IPForDomain(tailcfg.NodeID(123), want.name)) } - - var gotIP netip.Addr - switch want.qType { - case dnsmessage.TypeA: - if ans.Body.(*dnsmessage.AResource) == nil { - t.Errorf("answer[%d] not an A record", i) - continue + var wantIP netip.Addr + for _, ip := range ips { + if ip.Is4() { + wantIP = ip + break } - resource := ans.Body.(*dnsmessage.AResource) - gotIP = netip.AddrFrom4([4]byte(resource.A)) - case dnsmessage.TypeAAAA: - if ans.Body.(*dnsmessage.AAAAResource) == nil { - t.Errorf("answer[%d] not an AAAA record", i) - continue - } - resource := ans.Body.(*dnsmessage.AAAAResource) - gotIP = netip.AddrFrom16([16]byte(resource.AAAA)) } + if gotIP != wantIP { + t.Errorf("answer[%d] IP = %s, want %s", i, gotIP, wantIP) + } + case dnsmessage.TypeAAAA: + if ans.Body.(*dnsmessage.AAAAResource) == nil { + t.Errorf("answer[%d] not an AAAA record", i) + continue + } + resource := ans.Body.(*dnsmessage.AAAAResource) + gotIP := netip.AddrFrom16([16]byte(resource.AAAA)) - if gotIP != want.addr { - t.Errorf("answer[%d] IP = %s, want %s", i, gotIP, want.addr) + var ips []netip.Addr + if tc.wantIgnored { + ips = must.Get(c.resolver.LookupNetIP(t.Context(), "ip6", want.name)) + } else { + ips = must.Get(c.ipPool.IPForDomain(tailcfg.NodeID(123), want.name)) + } + var wantIP netip.Addr + for _, ip := range ips { + if ip.Is6() { + wantIP = ip + break + } + } + if gotIP != wantIP { + t.Errorf("answer[%d] IP = %s, want %s", i, gotIP, wantIP) } } } } } + + if tc.wantNXDOMAIN { + if msg.RCode != dnsmessage.RCodeNameError { + t.Errorf("expected NXDOMAIN, got %v", msg.RCode) + } + if len(msg.Answers) != 0 { + t.Errorf("expected no answers, got %d", len(msg.Answers)) + } + } }) } } @@ -257,53 +445,3 @@ func TestIgnoreDestination(t *testing.T) { }) } } - -func TestConnectorGenerateDNSResponse(t *testing.T) { - v6ULA := netip.MustParsePrefix("fd7a:115c:a1e0:a99c:0001::/80") - routes, dnsAddr, addrPool := calculateAddresses([]netip.Prefix{netip.MustParsePrefix("100.64.1.0/24")}) - c := &connector{ - v6ULA: v6ULA, - ipPool: &ippool.IPPool{V6ULA: v6ULA, IPSet: addrPool}, - routes: routes, - dnsAddr: dnsAddr, - } - - req := &dnsmessage.Message{ - Header: dnsmessage.Header{ID: 1234}, - Questions: []dnsmessage.Question{ - { - Name: dnsmessage.MustNewName("example.com."), - Type: dnsmessage.TypeA, - Class: dnsmessage.ClassINET, - }, - }, - } - - nodeID := tailcfg.NodeID(12345) - - resp1, err := c.generateDNSResponse(req, nodeID) - if err != nil { - t.Fatalf("generateDNSResponse() error = %v", err) - } - if len(resp1) == 0 { - t.Fatalf("generateDNSResponse() returned empty response") - } - - resp2, err := c.generateDNSResponse(req, nodeID) - if err != nil { - t.Fatalf("generateDNSResponse() second call error = %v", err) - } - - if !cmp.Equal(resp1, resp2) { - t.Errorf("generateDNSResponse() responses differ between calls") - } - - var msg dnsmessage.Message - err = msg.Unpack(resp1) - if err != nil { - t.Fatalf("dnsmessage Unpack error = %v", err) - } - if len(msg.Answers) != 1 { - t.Fatalf("expected 1 answer, got: %d", len(msg.Answers)) - } -}