mirror of
https://github.com/tailscale/tailscale.git
synced 2025-08-13 22:47:30 +00:00
net/socks5: optimize UDP relay
Key changes: - No mutex for every udp package: replace syncs.Map with regular map for udpTargetConns - Use socksAddr as map key for better type safety - Add test for multi udp target Updates #7581 Change-Id: Ic3d384a9eab62dcbf267d7d6d268bf242cc8ed3c Signed-off-by: VimT <me@vimt.me>
This commit is contained in:
@@ -22,7 +22,6 @@ import (
|
||||
"log"
|
||||
"net"
|
||||
"strconv"
|
||||
"tailscale.com/syncs"
|
||||
"time"
|
||||
|
||||
"tailscale.com/types/logger"
|
||||
@@ -151,7 +150,7 @@ type Conn struct {
|
||||
request *request
|
||||
|
||||
udpClientAddr net.Addr
|
||||
udpTargetConns syncs.Map[string, net.Conn]
|
||||
udpTargetConns map[socksAddr]net.Conn
|
||||
}
|
||||
|
||||
// Run starts the new connection.
|
||||
@@ -311,17 +310,18 @@ func (c *Conn) transferUDP(associatedTCP net.Conn, clientConn net.PacketConn) er
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
// close all target udp connections when the client connection is closed
|
||||
defer func() {
|
||||
c.udpTargetConns.Range(func(_ string, conn net.Conn) bool {
|
||||
_ = conn.Close()
|
||||
return true
|
||||
})
|
||||
}()
|
||||
|
||||
// client -> target
|
||||
go func() {
|
||||
defer cancel()
|
||||
|
||||
c.udpTargetConns = make(map[socksAddr]net.Conn)
|
||||
// close all target udp connections when the client connection is closed
|
||||
defer func() {
|
||||
for _, conn := range c.udpTargetConns {
|
||||
_ = conn.Close()
|
||||
}
|
||||
}()
|
||||
|
||||
buf := make([]byte, bufferSize)
|
||||
for {
|
||||
select {
|
||||
@@ -354,33 +354,27 @@ func (c *Conn) transferUDP(associatedTCP net.Conn, clientConn net.PacketConn) er
|
||||
func (c *Conn) getOrDialTargetConn(
|
||||
ctx context.Context,
|
||||
clientConn net.PacketConn,
|
||||
targetAddr string,
|
||||
targetAddr socksAddr,
|
||||
) (net.Conn, error) {
|
||||
host, port, err := splitHostPort(targetAddr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
conn, loaded := c.udpTargetConns.Load(targetAddr)
|
||||
if loaded {
|
||||
conn, exist := c.udpTargetConns[targetAddr]
|
||||
if exist {
|
||||
return conn, nil
|
||||
}
|
||||
conn, err = c.srv.dial(ctx, "udp", targetAddr)
|
||||
conn, err := c.srv.dial(ctx, "udp", targetAddr.hostPort())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
c.udpTargetConns.Store(targetAddr, conn)
|
||||
c.udpTargetConns[targetAddr] = conn
|
||||
|
||||
// target -> client
|
||||
go func() {
|
||||
buf := make([]byte, bufferSize)
|
||||
addr := socksAddr{addrType: getAddrType(host), addr: host, port: port}
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
default:
|
||||
err := c.handleUDPResponse(clientConn, addr, conn, buf)
|
||||
err := c.handleUDPResponse(clientConn, targetAddr, conn, buf)
|
||||
if err != nil {
|
||||
if isTimeout(err) {
|
||||
continue
|
||||
@@ -414,18 +408,17 @@ func (c *Conn) handleUDPRequest(
|
||||
return fmt.Errorf("parse udp request: %w", err)
|
||||
}
|
||||
|
||||
targetAddr := req.addr.hostPort()
|
||||
targetConn, err := c.getOrDialTargetConn(ctx, clientConn, targetAddr)
|
||||
targetConn, err := c.getOrDialTargetConn(ctx, clientConn, req.addr)
|
||||
if err != nil {
|
||||
return fmt.Errorf("dial target %s fail: %w", targetAddr, err)
|
||||
return fmt.Errorf("dial target %s fail: %w", req.addr, err)
|
||||
}
|
||||
|
||||
nn, err := targetConn.Write(data)
|
||||
if err != nil {
|
||||
return fmt.Errorf("write to target %s fail: %w", targetAddr, err)
|
||||
return fmt.Errorf("write to target %s fail: %w", req.addr, err)
|
||||
}
|
||||
if nn != len(data) {
|
||||
return fmt.Errorf("write to target %s fail: %w", targetAddr, io.ErrShortWrite)
|
||||
return fmt.Errorf("write to target %s fail: %w", req.addr, io.ErrShortWrite)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -652,10 +645,15 @@ func (s socksAddr) marshal() ([]byte, error) {
|
||||
pkt = binary.BigEndian.AppendUint16(pkt, s.port)
|
||||
return pkt, nil
|
||||
}
|
||||
|
||||
func (s socksAddr) hostPort() string {
|
||||
return net.JoinHostPort(s.addr, strconv.Itoa(int(s.port)))
|
||||
}
|
||||
|
||||
func (s socksAddr) String() string {
|
||||
return s.hostPort()
|
||||
}
|
||||
|
||||
// response contains the contents of
|
||||
// a response packet sent from the proxy
|
||||
// to the client.
|
||||
|
Reference in New Issue
Block a user