cmd/natc: add a flag to use specific DNS servers

If natc is running on a host with tailscale using `--accept-dns=true`
then a DNS loop can occur. Provide a flag for some specific DNS
upstreams for natc to use instead, to overcome such situations.

Updates #14667

Signed-off-by: James Tucker <james@tailscale.com>
This commit is contained in:
James Tucker 2025-06-16 10:27:00 -07:00 committed by James Tucker
parent 735f15cb49
commit 86985228bc
2 changed files with 225 additions and 2 deletions

View File

@ -54,6 +54,7 @@ func main() {
hostname = fs.String("hostname", "", "Hostname to register the service under") hostname = fs.String("hostname", "", "Hostname to register the service under")
siteID = fs.Uint("site-id", 1, "an integer site ID to use for the ULA prefix which allows for multiple proxies to act in a HA configuration") siteID = fs.Uint("site-id", 1, "an integer site ID to use for the ULA prefix which allows for multiple proxies to act in a HA configuration")
v4PfxStr = fs.String("v4-pfx", "100.64.1.0/24", "comma-separated list of IPv4 prefixes to advertise") v4PfxStr = fs.String("v4-pfx", "100.64.1.0/24", "comma-separated list of IPv4 prefixes to advertise")
dnsServers = fs.String("dns-servers", "", "comma separated list of upstream DNS to use, including host and port (use system if empty)")
verboseTSNet = fs.Bool("verbose-tsnet", false, "enable verbose logging in tsnet") verboseTSNet = fs.Bool("verbose-tsnet", false, "enable verbose logging in tsnet")
printULA = fs.Bool("print-ula", false, "print the ULA prefix and exit") printULA = fs.Bool("print-ula", false, "print the ULA prefix and exit")
ignoreDstPfxStr = fs.String("ignore-destinations", "", "comma-separated list of prefixes to ignore") ignoreDstPfxStr = fs.String("ignore-destinations", "", "comma-separated list of prefixes to ignore")
@ -78,7 +79,7 @@ func main() {
} }
var ignoreDstTable *bart.Table[bool] var ignoreDstTable *bart.Table[bool]
for _, s := range strings.Split(*ignoreDstPfxStr, ",") { for s := range strings.SplitSeq(*ignoreDstPfxStr, ",") {
s := strings.TrimSpace(s) s := strings.TrimSpace(s)
if s == "" { if s == "" {
continue continue
@ -185,11 +186,37 @@ func main() {
ipPool: ipp, ipPool: ipp,
routes: routes, routes: routes,
dnsAddr: dnsAddr, dnsAddr: dnsAddr,
resolver: net.DefaultResolver, resolver: getResolver(*dnsServers),
} }
c.run(ctx, lc) c.run(ctx, lc)
} }
// getResolver parses serverFlag and returns either the default resolver, or a
// resolver that uses the provided comma-separated DNS server AddrPort's, or
// panics.
func getResolver(serverFlag string) lookupNetIPer {
if serverFlag == "" {
return net.DefaultResolver
}
var addrs []string
for s := range strings.SplitSeq(serverFlag, ",") {
s = strings.TrimSpace(s)
addr, err := netip.ParseAddrPort(s)
if err != nil {
log.Fatalf("dns server provided: %q does not parse: %v", s, err)
}
addrs = append(addrs, addr.String())
}
return &net.Resolver{
PreferGo: true,
Dial: func(ctx context.Context, network string, address string) (net.Conn, error) {
var dialer net.Dialer
// TODO(raggi): perhaps something other than random?
return dialer.DialContext(ctx, network, addrs[rand.N(len(addrs))])
},
}
}
func calculateAddresses(prefixes []netip.Prefix) (*netipx.IPSet, netip.Addr, *netipx.IPSet) { func calculateAddresses(prefixes []netip.Prefix) (*netipx.IPSet, netip.Addr, *netipx.IPSet) {
var ipsb netipx.IPSetBuilder var ipsb netipx.IPSetBuilder
for _, p := range prefixes { for _, p := range prefixes {

View File

@ -9,6 +9,7 @@ import (
"io" "io"
"net" "net"
"net/netip" "net/netip"
"sync"
"testing" "testing"
"time" "time"
@ -480,3 +481,198 @@ func TestV6V4(t *testing.T) {
} }
} }
} }
// echoServer is a simple server that just echos back data set to it.
type echoServer struct {
listener net.Listener
addr string
wg sync.WaitGroup
done chan struct{}
}
// newEchoServer creates a new test DNS server on the specified network and address
func newEchoServer(t *testing.T, network, addr string) *echoServer {
listener, err := net.Listen(network, addr)
if err != nil {
t.Fatalf("Failed to create test DNS server: %v", err)
}
server := &echoServer{
listener: listener,
addr: listener.Addr().String(),
done: make(chan struct{}),
}
server.wg.Add(1)
go server.serve()
return server
}
func (s *echoServer) serve() {
defer s.wg.Done()
for {
select {
case <-s.done:
return
default:
conn, err := s.listener.Accept()
if err != nil {
select {
case <-s.done:
return
default:
continue
}
}
go s.handleConnection(conn)
}
}
}
func (s *echoServer) handleConnection(conn net.Conn) {
defer conn.Close()
// Simple response - just echo back some data to confirm connectivity
buf := make([]byte, 1024)
n, err := conn.Read(buf)
if err != nil {
return
}
conn.Write(buf[:n])
}
func (s *echoServer) close() {
close(s.done)
s.listener.Close()
s.wg.Wait()
}
func TestGetResolver(t *testing.T) {
tests := []struct {
name string
network string
addr string
}{
{
name: "ipv4_loopback",
network: "tcp4",
addr: "127.0.0.1:0",
},
{
name: "ipv6_loopback",
network: "tcp6",
addr: "[::1]:0",
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
server := newEchoServer(t, tc.network, tc.addr)
defer server.close()
serverAddr := server.addr
resolver := getResolver(serverAddr)
if resolver == nil {
t.Fatal("getResolver returned nil")
}
netResolver, ok := resolver.(*net.Resolver)
if !ok {
t.Fatal("getResolver did not return a *net.Resolver")
}
if netResolver.Dial == nil {
t.Fatal("resolver.Dial is nil")
}
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
conn, err := netResolver.Dial(ctx, "tcp", "dummy.address:53")
if err != nil {
t.Fatalf("Failed to dial test DNS server: %v", err)
}
defer conn.Close()
testData := []byte("test")
_, err = conn.Write(testData)
if err != nil {
t.Fatalf("Failed to write to connection: %v", err)
}
response := make([]byte, len(testData))
_, err = conn.Read(response)
if err != nil {
t.Fatalf("Failed to read from connection: %v", err)
}
if string(response) != string(testData) {
t.Fatalf("Expected echo response %q, got %q", testData, response)
}
})
}
}
func TestGetResolverMultipleServers(t *testing.T) {
server1 := newEchoServer(t, "tcp4", "127.0.0.1:0")
defer server1.close()
server2 := newEchoServer(t, "tcp4", "127.0.0.1:0")
defer server2.close()
serverFlag := server1.addr + ", " + server2.addr
resolver := getResolver(serverFlag)
netResolver, ok := resolver.(*net.Resolver)
if !ok {
t.Fatal("getResolver did not return a *net.Resolver")
}
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
servers := map[string]bool{
server1.addr: false,
server2.addr: false,
}
// Try up to 1000 times to hit all servers, this should be very quick, and
// if this fails randomness has regressed beyond reason.
for range 1000 {
conn, err := netResolver.Dial(ctx, "tcp", "dummy.address:53")
if err != nil {
t.Fatalf("Failed to dial test DNS server: %v", err)
}
remoteAddr := conn.RemoteAddr().String()
conn.Close()
servers[remoteAddr] = true
var allDone = true
for _, done := range servers {
if !done {
allDone = false
break
}
}
if allDone {
break
}
}
var allDone = true
for _, done := range servers {
if !done {
allDone = false
break
}
}
if !allDone {
t.Errorf("after 1000 queries, not all servers were hit, significant lack of randomness: %#v", servers)
}
}
func TestGetResolverEmpty(t *testing.T) {
resolver := getResolver("")
if resolver != net.DefaultResolver {
t.Fatal(`getResolver("") should return net.DefaultResolver`)
}
}