From 37df61fb4e518abec842856bbb34b3b10ecba49d Mon Sep 17 00:00:00 2001 From: James Tucker Date: Thu, 20 Mar 2025 11:12:20 -0700 Subject: [PATCH] cmd/natc: add some initial unit test coverage These tests aren't perfect, nor is this complete coverage, but this is a set of coverage that is at least stable. Updates #15367 Signed-off-by: James Tucker --- cmd/natc/natc_test.go | 365 ++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 365 insertions(+) create mode 100644 cmd/natc/natc_test.go diff --git a/cmd/natc/natc_test.go b/cmd/natc/natc_test.go new file mode 100644 index 000000000..1b6d7af7c --- /dev/null +++ b/cmd/natc/natc_test.go @@ -0,0 +1,365 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package main + +import ( + "net/netip" + "slices" + "testing" + + "github.com/gaissmai/bart" + "github.com/google/go-cmp/cmp" + "golang.org/x/net/dns/dnsmessage" + "tailscale.com/tailcfg" +) + +func prefixEqual(a, b netip.Prefix) bool { + return a.Bits() == b.Bits() && a.Addr() == b.Addr() +} + +func TestULA(t *testing.T) { + tests := []struct { + name string + siteID uint16 + expected string + }{ + {"zero", 0, "fd7a:115c:a1e0:a99c:0000::/80"}, + {"one", 1, "fd7a:115c:a1e0:a99c:0001::/80"}, + {"max", 65535, "fd7a:115c:a1e0:a99c:ffff::/80"}, + {"random", 12345, "fd7a:115c:a1e0:a99c:3039::/80"}, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + got := ula(tc.siteID) + expected := netip.MustParsePrefix(tc.expected) + if !prefixEqual(got, expected) { + t.Errorf("ula(%d) = %s; want %s", tc.siteID, got, expected) + } + }) + } +} + +func TestRandV4(t *testing.T) { + pfx := netip.MustParsePrefix("100.64.1.0/24") + + for i := 0; i < 512; i++ { + ip := randV4(pfx) + if !pfx.Contains(ip) { + t.Errorf("randV4(%s) = %s; not contained in prefix", pfx, ip) + } + } +} + +func TestDNSResponse(t *testing.T) { + tests := []struct { + name string + questions []dnsmessage.Question + addrs []netip.Addr + wantEmpty bool + wantAnswers []struct { + name string + qType dnsmessage.Type + addr netip.Addr + } + }{ + { + name: "empty_request", + questions: []dnsmessage.Question{}, + addrs: []netip.Addr{}, + wantEmpty: false, + wantAnswers: nil, + }, + { + name: "a_record", + questions: []dnsmessage.Question{ + { + Name: dnsmessage.MustNewName("example.com."), + Type: dnsmessage.TypeA, + Class: dnsmessage.ClassINET, + }, + }, + addrs: []netip.Addr{netip.MustParseAddr("100.64.1.5")}, + wantAnswers: []struct { + name string + qType dnsmessage.Type + addr netip.Addr + }{ + { + name: "example.com.", + qType: dnsmessage.TypeA, + addr: netip.MustParseAddr("100.64.1.5"), + }, + }, + }, + { + name: "aaaa_record", + questions: []dnsmessage.Question{ + { + Name: dnsmessage.MustNewName("example.com."), + Type: dnsmessage.TypeAAAA, + Class: dnsmessage.ClassINET, + }, + }, + addrs: []netip.Addr{netip.MustParseAddr("fd7a:115c:a1e0:a99c:0001:0505:0505:0505")}, + wantAnswers: []struct { + name string + qType dnsmessage.Type + addr netip.Addr + }{ + { + name: "example.com.", + qType: dnsmessage.TypeAAAA, + addr: netip.MustParseAddr("fd7a:115c:a1e0:a99c:0001:0505:0505:0505"), + }, + }, + }, + { + name: "soa_record", + questions: []dnsmessage.Question{ + { + Name: dnsmessage.MustNewName("example.com."), + Type: dnsmessage.TypeSOA, + Class: dnsmessage.ClassINET, + }, + }, + addrs: []netip.Addr{}, + wantAnswers: nil, + }, + { + name: "ns_record", + questions: []dnsmessage.Question{ + { + Name: dnsmessage.MustNewName("example.com."), + Type: dnsmessage.TypeNS, + Class: dnsmessage.ClassINET, + }, + }, + addrs: []netip.Addr{}, + wantAnswers: nil, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + req := &dnsmessage.Message{ + Header: dnsmessage.Header{ + ID: 1234, + }, + Questions: tc.questions, + } + + resp, err := dnsResponse(req, tc.addrs) + if err != nil { + t.Fatalf("dnsResponse() error = %v", err) + } + + if tc.wantEmpty && len(resp) != 0 { + t.Errorf("dnsResponse() returned non-empty response when expected empty") + } + + if !tc.wantEmpty && len(resp) == 0 { + t.Errorf("dnsResponse() returned empty response when expected non-empty") + } + + if len(resp) > 0 { + var msg dnsmessage.Message + err = msg.Unpack(resp) + if err != nil { + t.Fatalf("Failed to unpack response: %v", err) + } + + if !msg.Header.Response { + t.Errorf("Response header is not set") + } + + if msg.Header.ID != req.Header.ID { + t.Errorf("Response ID = %d, want %d", msg.Header.ID, req.Header.ID) + } + + if len(tc.wantAnswers) > 0 { + if len(msg.Answers) != len(tc.wantAnswers) { + t.Errorf("got %d answers, want %d", len(msg.Answers), len(tc.wantAnswers)) + } else { + for i, want := range tc.wantAnswers { + ans := msg.Answers[i] + + gotName := ans.Header.Name.String() + if gotName != want.name { + t.Errorf("answer[%d] name = %s, want %s", i, gotName, want.name) + } + + if ans.Header.Type != want.qType { + t.Errorf("answer[%d] type = %v, want %v", i, ans.Header.Type, want.qType) + } + + var gotIP netip.Addr + switch want.qType { + case dnsmessage.TypeA: + if ans.Body.(*dnsmessage.AResource) == nil { + t.Errorf("answer[%d] not an A record", i) + continue + } + resource := ans.Body.(*dnsmessage.AResource) + gotIP = netip.AddrFrom4([4]byte(resource.A)) + case dnsmessage.TypeAAAA: + if ans.Body.(*dnsmessage.AAAAResource) == nil { + t.Errorf("answer[%d] not an AAAA record", i) + continue + } + resource := ans.Body.(*dnsmessage.AAAAResource) + gotIP = netip.AddrFrom16([16]byte(resource.AAAA)) + } + + if gotIP != want.addr { + t.Errorf("answer[%d] IP = %s, want %s", i, gotIP, want.addr) + } + } + } + } + } + }) + } +} + +func TestPerPeerState(t *testing.T) { + c := &connector{ + v4Ranges: []netip.Prefix{netip.MustParsePrefix("100.64.1.0/24")}, + v6ULA: netip.MustParsePrefix("fd7a:115c:a1e0:a99c:0001::/80"), + dnsAddr: netip.MustParseAddr("100.64.1.1"), + } + + ps := &perPeerState{c: c} + + addrs, err := ps.ipForDomain("example.com") + if err != nil { + t.Fatalf("ipForDomain() error = %v", err) + } + + if len(addrs) != 2 { + t.Fatalf("ipForDomain() returned %d addresses, want 2", len(addrs)) + } + + v4 := addrs[0] + v6 := addrs[1] + + if !v4.Is4() { + t.Errorf("First address is not IPv4: %s", v4) + } + + if !v6.Is6() { + t.Errorf("Second address is not IPv6: %s", v6) + } + + if !c.v4Ranges[0].Contains(v4) { + t.Errorf("IPv4 address %s not in range %s", v4, c.v4Ranges[0]) + } + + domain, ok := ps.domainForIP(v4) + if !ok { + t.Errorf("domainForIP(%s) not found", v4) + } else if domain != "example.com" { + t.Errorf("domainForIP(%s) = %s, want %s", v4, domain, "example.com") + } + + domain, ok = ps.domainForIP(v6) + if !ok { + t.Errorf("domainForIP(%s) not found", v6) + } else if domain != "example.com" { + t.Errorf("domainForIP(%s) = %s, want %s", v6, domain, "example.com") + } + + addrs2, err := ps.ipForDomain("example.com") + if err != nil { + t.Fatalf("ipForDomain() second call error = %v", err) + } + + if !slices.Equal(addrs, addrs2) { + t.Errorf("ipForDomain() second call = %v, want %v", addrs2, addrs) + } +} + +func TestIgnoreDestination(t *testing.T) { + ignoreDstTable := &bart.Table[bool]{} + ignoreDstTable.Insert(netip.MustParsePrefix("192.168.1.0/24"), true) + ignoreDstTable.Insert(netip.MustParsePrefix("10.0.0.0/8"), true) + + c := &connector{ + ignoreDsts: ignoreDstTable, + } + + tests := []struct { + name string + addrs []netip.Addr + expected bool + }{ + { + name: "no_match", + addrs: []netip.Addr{netip.MustParseAddr("8.8.8.8"), netip.MustParseAddr("1.1.1.1")}, + expected: false, + }, + { + name: "one_match", + addrs: []netip.Addr{netip.MustParseAddr("8.8.8.8"), netip.MustParseAddr("192.168.1.5")}, + expected: true, + }, + { + name: "all_match", + addrs: []netip.Addr{netip.MustParseAddr("10.0.0.1"), netip.MustParseAddr("192.168.1.5")}, + expected: true, + }, + { + name: "empty_addrs", + addrs: []netip.Addr{}, + expected: false, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + got := c.ignoreDestination(tc.addrs) + if got != tc.expected { + t.Errorf("ignoreDestination(%v) = %v, want %v", tc.addrs, got, tc.expected) + } + }) + } +} + +func TestConnectorGenerateDNSResponse(t *testing.T) { + c := &connector{ + v4Ranges: []netip.Prefix{netip.MustParsePrefix("100.64.1.0/24")}, + v6ULA: netip.MustParsePrefix("fd7a:115c:a1e0:a99c:0001::/80"), + dnsAddr: netip.MustParseAddr("100.64.1.1"), + } + + req := &dnsmessage.Message{ + Header: dnsmessage.Header{ID: 1234}, + Questions: []dnsmessage.Question{ + { + Name: dnsmessage.MustNewName("example.com."), + Type: dnsmessage.TypeA, + Class: dnsmessage.ClassINET, + }, + }, + } + + nodeID := tailcfg.NodeID(12345) + + resp1, err := c.generateDNSResponse(req, nodeID) + if err != nil { + t.Fatalf("generateDNSResponse() error = %v", err) + } + if len(resp1) == 0 { + t.Fatalf("generateDNSResponse() returned empty response") + } + + resp2, err := c.generateDNSResponse(req, nodeID) + if err != nil { + t.Fatalf("generateDNSResponse() second call error = %v", err) + } + + if !cmp.Equal(resp1, resp2) { + t.Errorf("generateDNSResponse() responses differ between calls") + } +}