net/socks5: fix UDP relay in userspace-networking mode

This commit addresses an issue with the SOCKS5 UDP relay functionality
when using the --tun=userspace-networking option. Previously, UDP packets
were not being correctly routed into the Tailscale network in this mode.

Key changes:
- Replace single UDP connection with a map of connections per target
- Use c.srv.dial for creating connections to ensure proper routing

Updates #7581

Change-Id: Iaaa66f9de6a3713218014cf3f498003a7cac9832
Signed-off-by: VimT <me@vimt.me>
This commit is contained in:
VimT 2024-09-20 23:52:45 +08:00 committed by Brad Fitzpatrick
parent 634cc2ba4a
commit b0626ff84c

View File

@ -22,6 +22,7 @@
"log" "log"
"net" "net"
"strconv" "strconv"
"tailscale.com/syncs"
"time" "time"
"tailscale.com/types/logger" "tailscale.com/types/logger"
@ -81,6 +82,12 @@
addrTypeNotSupported replyCode = 8 addrTypeNotSupported replyCode = 8
) )
// UDP conn default buffer size and read timeout.
const (
bufferSize = 8 * 1024
readTimeout = 5 * time.Second
)
// Server is a SOCKS5 proxy server. // Server is a SOCKS5 proxy server.
type Server struct { type Server struct {
// Logf optionally specifies the logger to use. // Logf optionally specifies the logger to use.
@ -144,6 +151,7 @@ type Conn struct {
request *request request *request
udpClientAddr net.Addr udpClientAddr net.Addr
udpTargetConns syncs.Map[string, net.Conn]
} }
// Run starts the new connection. // Run starts the new connection.
@ -276,15 +284,6 @@ func (c *Conn) handleUDP() error {
} }
defer clientUDPConn.Close() defer clientUDPConn.Close()
serverUDPConn, err := net.ListenPacket("udp", "[::]:0")
if err != nil {
res := errorResponse(generalFailure)
buf, _ := res.marshal()
c.clientConn.Write(buf)
return err
}
defer serverUDPConn.Close()
bindAddr, bindPort, err := splitHostPort(clientUDPConn.LocalAddr().String()) bindAddr, bindPort, err := splitHostPort(clientUDPConn.LocalAddr().String())
if err != nil { if err != nil {
return err return err
@ -305,14 +304,20 @@ func (c *Conn) handleUDP() error {
} }
c.clientConn.Write(buf) c.clientConn.Write(buf)
return c.transferUDP(c.clientConn, clientUDPConn, serverUDPConn) return c.transferUDP(c.clientConn, clientUDPConn)
} }
func (c *Conn) transferUDP(associatedTCP net.Conn, clientConn net.PacketConn, targetConn net.PacketConn) error { func (c *Conn) transferUDP(associatedTCP net.Conn, clientConn net.PacketConn) error {
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
defer cancel() defer cancel()
const bufferSize = 8 * 1024
const readTimeout = 5 * time.Second // close all target udp connections when the client connection is closed
defer func() {
c.udpTargetConns.Range(func(_ string, conn net.Conn) bool {
_ = conn.Close()
return true
})
}()
// client -> target // client -> target
go func() { go func() {
@ -323,7 +328,7 @@ func (c *Conn) transferUDP(associatedTCP net.Conn, clientConn net.PacketConn, ta
case <-ctx.Done(): case <-ctx.Done():
return return
default: default:
err := c.handleUDPRequest(clientConn, targetConn, buf, readTimeout) err := c.handleUDPRequest(ctx, clientConn, buf)
if err != nil { if err != nil {
if isTimeout(err) { if isTimeout(err) {
continue continue
@ -337,29 +342,6 @@ func (c *Conn) transferUDP(associatedTCP net.Conn, clientConn net.PacketConn, ta
} }
}() }()
// target -> client
go func() {
defer cancel()
buf := make([]byte, bufferSize)
for {
select {
case <-ctx.Done():
return
default:
err := c.handleUDPResponse(targetConn, clientConn, buf, readTimeout)
if err != nil {
if isTimeout(err) {
continue
}
if errors.Is(err, net.ErrClosed) {
return
}
c.logf("udp transfer: handle udp response fail: %v", err)
}
}
}
}()
// A UDP association terminates when the TCP connection that the UDP // A UDP association terminates when the TCP connection that the UDP
// ASSOCIATE request arrived on terminates. RFC1928 // ASSOCIATE request arrived on terminates. RFC1928
_, err := io.Copy(io.Discard, associatedTCP) _, err := io.Copy(io.Discard, associatedTCP)
@ -369,11 +351,56 @@ func (c *Conn) transferUDP(associatedTCP net.Conn, clientConn net.PacketConn, ta
return err return err
} }
func (c *Conn) handleUDPRequest( func (c *Conn) getOrDialTargetConn(
ctx context.Context,
clientConn net.PacketConn,
targetAddr string,
) (net.Conn, error) {
host, port, err := splitHostPort(targetAddr)
if err != nil {
return nil, err
}
conn, loaded := c.udpTargetConns.Load(targetAddr)
if loaded {
return conn, nil
}
conn, err = c.srv.dial(ctx, "udp", targetAddr)
if err != nil {
return nil, err
}
c.udpTargetConns.Store(targetAddr, conn)
// target -> client
go func() {
buf := make([]byte, bufferSize)
addr := socksAddr{addrType: getAddrType(host), addr: host, port: port}
for {
select {
case <-ctx.Done():
return
default:
err := c.handleUDPResponse(clientConn, addr, conn, buf)
if err != nil {
if isTimeout(err) {
continue
}
if errors.Is(err, net.ErrClosed) || errors.Is(err, io.EOF) {
return
}
c.logf("udp transfer: handle udp response fail: %v", err)
}
}
}
}()
return conn, nil
}
func (c *Conn) handleUDPRequest(
ctx context.Context,
clientConn net.PacketConn, clientConn net.PacketConn,
targetConn net.PacketConn,
buf []byte, buf []byte,
readTimeout time.Duration,
) error { ) error {
// add a deadline for the read to avoid blocking forever // add a deadline for the read to avoid blocking forever
_ = clientConn.SetReadDeadline(time.Now().Add(readTimeout)) _ = clientConn.SetReadDeadline(time.Now().Add(readTimeout))
@ -386,12 +413,14 @@ func (c *Conn) handleUDPRequest(
if err != nil { if err != nil {
return fmt.Errorf("parse udp request: %w", err) return fmt.Errorf("parse udp request: %w", err)
} }
targetAddr, err := net.ResolveUDPAddr("udp", req.addr.hostPort())
targetAddr := req.addr.hostPort()
targetConn, err := c.getOrDialTargetConn(ctx, clientConn, targetAddr)
if err != nil { if err != nil {
c.logf("resolve target addr fail: %v", err) return fmt.Errorf("dial target %s fail: %w", targetAddr, err)
} }
nn, err := targetConn.WriteTo(data, targetAddr) nn, err := targetConn.Write(data)
if err != nil { if err != nil {
return fmt.Errorf("write to target %s fail: %w", targetAddr, err) return fmt.Errorf("write to target %s fail: %w", targetAddr, err)
} }
@ -402,22 +431,18 @@ func (c *Conn) handleUDPRequest(
} }
func (c *Conn) handleUDPResponse( func (c *Conn) handleUDPResponse(
targetConn net.PacketConn,
clientConn net.PacketConn, clientConn net.PacketConn,
targetAddr socksAddr,
targetConn net.Conn,
buf []byte, buf []byte,
readTimeout time.Duration,
) error { ) error {
// add a deadline for the read to avoid blocking forever // add a deadline for the read to avoid blocking forever
_ = targetConn.SetReadDeadline(time.Now().Add(readTimeout)) _ = targetConn.SetReadDeadline(time.Now().Add(readTimeout))
n, addr, err := targetConn.ReadFrom(buf) n, err := targetConn.Read(buf)
if err != nil { if err != nil {
return fmt.Errorf("read from target: %w", err) return fmt.Errorf("read from target: %w", err)
} }
host, port, err := splitHostPort(addr.String()) hdr := udpRequest{addr: targetAddr}
if err != nil {
return fmt.Errorf("split host port: %w", err)
}
hdr := udpRequest{addr: socksAddr{addrType: getAddrType(host), addr: host, port: port}}
pkt, err := hdr.marshal() pkt, err := hdr.marshal()
if err != nil { if err != nil {
return fmt.Errorf("marshal udp request: %w", err) return fmt.Errorf("marshal udp request: %w", err)