From 86985228bcef855a8071f6989bbceeb5b21810c2 Mon Sep 17 00:00:00 2001 From: James Tucker Date: Mon, 16 Jun 2025 10:27:00 -0700 Subject: [PATCH] cmd/natc: add a flag to use specific DNS servers If natc is running on a host with tailscale using `--accept-dns=true` then a DNS loop can occur. Provide a flag for some specific DNS upstreams for natc to use instead, to overcome such situations. Updates #14667 Signed-off-by: James Tucker --- cmd/natc/natc.go | 31 ++++++- cmd/natc/natc_test.go | 196 ++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 225 insertions(+), 2 deletions(-) diff --git a/cmd/natc/natc.go b/cmd/natc/natc.go index 247bb2101..fdbce3da1 100644 --- a/cmd/natc/natc.go +++ b/cmd/natc/natc.go @@ -54,6 +54,7 @@ func main() { hostname = fs.String("hostname", "", "Hostname to register the service under") siteID = fs.Uint("site-id", 1, "an integer site ID to use for the ULA prefix which allows for multiple proxies to act in a HA configuration") v4PfxStr = fs.String("v4-pfx", "100.64.1.0/24", "comma-separated list of IPv4 prefixes to advertise") + dnsServers = fs.String("dns-servers", "", "comma separated list of upstream DNS to use, including host and port (use system if empty)") verboseTSNet = fs.Bool("verbose-tsnet", false, "enable verbose logging in tsnet") printULA = fs.Bool("print-ula", false, "print the ULA prefix and exit") ignoreDstPfxStr = fs.String("ignore-destinations", "", "comma-separated list of prefixes to ignore") @@ -78,7 +79,7 @@ func main() { } var ignoreDstTable *bart.Table[bool] - for _, s := range strings.Split(*ignoreDstPfxStr, ",") { + for s := range strings.SplitSeq(*ignoreDstPfxStr, ",") { s := strings.TrimSpace(s) if s == "" { continue @@ -185,11 +186,37 @@ func main() { ipPool: ipp, routes: routes, dnsAddr: dnsAddr, - resolver: net.DefaultResolver, + resolver: getResolver(*dnsServers), } c.run(ctx, lc) } +// getResolver parses serverFlag and returns either the default resolver, or a +// resolver that uses the provided comma-separated DNS server AddrPort's, or +// panics. +func getResolver(serverFlag string) lookupNetIPer { + if serverFlag == "" { + return net.DefaultResolver + } + var addrs []string + for s := range strings.SplitSeq(serverFlag, ",") { + s = strings.TrimSpace(s) + addr, err := netip.ParseAddrPort(s) + if err != nil { + log.Fatalf("dns server provided: %q does not parse: %v", s, err) + } + addrs = append(addrs, addr.String()) + } + return &net.Resolver{ + PreferGo: true, + Dial: func(ctx context.Context, network string, address string) (net.Conn, error) { + var dialer net.Dialer + // TODO(raggi): perhaps something other than random? + return dialer.DialContext(ctx, network, addrs[rand.N(len(addrs))]) + }, + } +} + func calculateAddresses(prefixes []netip.Prefix) (*netipx.IPSet, netip.Addr, *netipx.IPSet) { var ipsb netipx.IPSetBuilder for _, p := range prefixes { diff --git a/cmd/natc/natc_test.go b/cmd/natc/natc_test.go index 78dec86fd..c0a66deb8 100644 --- a/cmd/natc/natc_test.go +++ b/cmd/natc/natc_test.go @@ -9,6 +9,7 @@ import ( "io" "net" "net/netip" + "sync" "testing" "time" @@ -480,3 +481,198 @@ func TestV6V4(t *testing.T) { } } } + +// echoServer is a simple server that just echos back data set to it. +type echoServer struct { + listener net.Listener + addr string + wg sync.WaitGroup + done chan struct{} +} + +// newEchoServer creates a new test DNS server on the specified network and address +func newEchoServer(t *testing.T, network, addr string) *echoServer { + listener, err := net.Listen(network, addr) + if err != nil { + t.Fatalf("Failed to create test DNS server: %v", err) + } + + server := &echoServer{ + listener: listener, + addr: listener.Addr().String(), + done: make(chan struct{}), + } + + server.wg.Add(1) + go server.serve() + + return server +} + +func (s *echoServer) serve() { + defer s.wg.Done() + + for { + select { + case <-s.done: + return + default: + conn, err := s.listener.Accept() + if err != nil { + select { + case <-s.done: + return + default: + continue + } + } + go s.handleConnection(conn) + } + } +} + +func (s *echoServer) handleConnection(conn net.Conn) { + defer conn.Close() + // Simple response - just echo back some data to confirm connectivity + buf := make([]byte, 1024) + n, err := conn.Read(buf) + if err != nil { + return + } + conn.Write(buf[:n]) +} + +func (s *echoServer) close() { + close(s.done) + s.listener.Close() + s.wg.Wait() +} + +func TestGetResolver(t *testing.T) { + tests := []struct { + name string + network string + addr string + }{ + { + name: "ipv4_loopback", + network: "tcp4", + addr: "127.0.0.1:0", + }, + { + name: "ipv6_loopback", + network: "tcp6", + addr: "[::1]:0", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + server := newEchoServer(t, tc.network, tc.addr) + defer server.close() + serverAddr := server.addr + resolver := getResolver(serverAddr) + if resolver == nil { + t.Fatal("getResolver returned nil") + } + + netResolver, ok := resolver.(*net.Resolver) + if !ok { + t.Fatal("getResolver did not return a *net.Resolver") + } + if netResolver.Dial == nil { + t.Fatal("resolver.Dial is nil") + } + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + conn, err := netResolver.Dial(ctx, "tcp", "dummy.address:53") + if err != nil { + t.Fatalf("Failed to dial test DNS server: %v", err) + } + defer conn.Close() + + testData := []byte("test") + _, err = conn.Write(testData) + if err != nil { + t.Fatalf("Failed to write to connection: %v", err) + } + + response := make([]byte, len(testData)) + _, err = conn.Read(response) + if err != nil { + t.Fatalf("Failed to read from connection: %v", err) + } + + if string(response) != string(testData) { + t.Fatalf("Expected echo response %q, got %q", testData, response) + } + }) + } +} + +func TestGetResolverMultipleServers(t *testing.T) { + server1 := newEchoServer(t, "tcp4", "127.0.0.1:0") + defer server1.close() + server2 := newEchoServer(t, "tcp4", "127.0.0.1:0") + defer server2.close() + serverFlag := server1.addr + ", " + server2.addr + + resolver := getResolver(serverFlag) + netResolver, ok := resolver.(*net.Resolver) + if !ok { + t.Fatal("getResolver did not return a *net.Resolver") + } + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + servers := map[string]bool{ + server1.addr: false, + server2.addr: false, + } + + // Try up to 1000 times to hit all servers, this should be very quick, and + // if this fails randomness has regressed beyond reason. + for range 1000 { + conn, err := netResolver.Dial(ctx, "tcp", "dummy.address:53") + if err != nil { + t.Fatalf("Failed to dial test DNS server: %v", err) + } + + remoteAddr := conn.RemoteAddr().String() + + conn.Close() + + servers[remoteAddr] = true + + var allDone = true + for _, done := range servers { + if !done { + allDone = false + break + } + } + if allDone { + break + } + } + + var allDone = true + for _, done := range servers { + if !done { + allDone = false + break + } + } + if !allDone { + t.Errorf("after 1000 queries, not all servers were hit, significant lack of randomness: %#v", servers) + } +} + +func TestGetResolverEmpty(t *testing.T) { + resolver := getResolver("") + if resolver != net.DefaultResolver { + t.Fatal(`getResolver("") should return net.DefaultResolver`) + } +}