diff --git a/cmd/natc/natc.go b/cmd/natc/natc.go index a80e4a42a..585a0bb45 100644 --- a/cmd/natc/natc.go +++ b/cmd/natc/natc.go @@ -13,6 +13,7 @@ import ( "flag" "fmt" "log" + "math/rand/v2" "net" "net/http" "net/netip" @@ -438,7 +439,7 @@ func (c *connector) handleTCPFlow(src, dst netip.AddrPort) (handler func(net.Con return nil, false } return func(conn net.Conn) { - proxyTCPConn(conn, domain) + proxyTCPConn(conn, domain, c) }, true } @@ -456,16 +457,34 @@ func (c *connector) ignoreDestination(dstAddrs []netip.Addr) bool { return false } -func proxyTCPConn(c net.Conn, dest string) { +func proxyTCPConn(c net.Conn, dest string, ctor *connector) { if c.RemoteAddr() == nil { log.Printf("proxyTCPConn: nil RemoteAddr") c.Close() return } - addrPortStr := c.LocalAddr().String() - _, port, err := net.SplitHostPort(addrPortStr) + laddr, err := netip.ParseAddrPort(c.LocalAddr().String()) if err != nil { - log.Printf("tcpRoundRobinHandler.Handle: bogus addrPort %q", addrPortStr) + log.Printf("proxyTCPConn: ParseAddrPort failed: %v", err) + c.Close() + return + } + + daddrs, err := ctor.resolver.LookupNetIP(context.TODO(), "ip", dest) + if err != nil { + log.Printf("proxyTCPConn: LookupNetIP failed: %v", err) + c.Close() + return + } + + if len(daddrs) == 0 { + log.Printf("proxyTCPConn: no IP addresses found for %s", dest) + c.Close() + return + } + + if ctor.ignoreDestination(daddrs) { + log.Printf("proxyTCPConn: closing connection to ignored destination %s (%v)", dest, daddrs) c.Close() return } @@ -475,10 +494,37 @@ func proxyTCPConn(c net.Conn, dest string) { return netutil.NewOneConnListener(c, nil), nil }, } - // XXX(raggi): if the connection here resolves to an ignored destination, - // the connection should be closed/failed. - p.AddRoute(addrPortStr, &tcpproxy.DialProxy{ - Addr: fmt.Sprintf("%s:%s", dest, port), + + // TODO(raggi): more code could avoid this shuffle, but avoiding allocations + // for now most of the time daddrs will be short. + rand.Shuffle(len(daddrs), func(i, j int) { + daddrs[i], daddrs[j] = daddrs[j], daddrs[i] }) + daddr := daddrs[0] + + // Try to match the upstream and downstream protocols (v4/v6) + if laddr.Addr().Is6() { + for _, addr := range daddrs { + if addr.Is6() { + daddr = addr + break + } + } + } else { + for _, addr := range daddrs { + if addr.Is4() { + daddr = addr + break + } + } + } + + // TODO(raggi): drop this library, it ends up being allocation and + // indirection heavy and really doesn't help us here. + dsockaddrs := netip.AddrPortFrom(daddr, laddr.Port()).String() + p.AddRoute(dsockaddrs, &tcpproxy.DialProxy{ + Addr: dsockaddrs, + }) + p.Start() }