mirror of
https://github.com/tailscale/tailscale.git
synced 2025-04-16 03:31:39 +00:00
net/socks5: optimize UDP relay
Key changes: - No mutex for every udp package: replace syncs.Map with regular map for udpTargetConns - Use socksAddr as map key for better type safety - Add test for multi udp target Updates #7581 Change-Id: Ic3d384a9eab62dcbf267d7d6d268bf242cc8ed3c Signed-off-by: VimT <me@vimt.me>
This commit is contained in:
parent
b0626ff84c
commit
43138c7a5c
@ -22,7 +22,6 @@ import (
|
|||||||
"log"
|
"log"
|
||||||
"net"
|
"net"
|
||||||
"strconv"
|
"strconv"
|
||||||
"tailscale.com/syncs"
|
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"tailscale.com/types/logger"
|
"tailscale.com/types/logger"
|
||||||
@ -151,7 +150,7 @@ type Conn struct {
|
|||||||
request *request
|
request *request
|
||||||
|
|
||||||
udpClientAddr net.Addr
|
udpClientAddr net.Addr
|
||||||
udpTargetConns syncs.Map[string, net.Conn]
|
udpTargetConns map[socksAddr]net.Conn
|
||||||
}
|
}
|
||||||
|
|
||||||
// Run starts the new connection.
|
// Run starts the new connection.
|
||||||
@ -311,17 +310,18 @@ func (c *Conn) transferUDP(associatedTCP net.Conn, clientConn net.PacketConn) er
|
|||||||
ctx, cancel := context.WithCancel(context.Background())
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
// close all target udp connections when the client connection is closed
|
|
||||||
defer func() {
|
|
||||||
c.udpTargetConns.Range(func(_ string, conn net.Conn) bool {
|
|
||||||
_ = conn.Close()
|
|
||||||
return true
|
|
||||||
})
|
|
||||||
}()
|
|
||||||
|
|
||||||
// client -> target
|
// client -> target
|
||||||
go func() {
|
go func() {
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
|
c.udpTargetConns = make(map[socksAddr]net.Conn)
|
||||||
|
// close all target udp connections when the client connection is closed
|
||||||
|
defer func() {
|
||||||
|
for _, conn := range c.udpTargetConns {
|
||||||
|
_ = conn.Close()
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
buf := make([]byte, bufferSize)
|
buf := make([]byte, bufferSize)
|
||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
@ -354,33 +354,27 @@ func (c *Conn) transferUDP(associatedTCP net.Conn, clientConn net.PacketConn) er
|
|||||||
func (c *Conn) getOrDialTargetConn(
|
func (c *Conn) getOrDialTargetConn(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
clientConn net.PacketConn,
|
clientConn net.PacketConn,
|
||||||
targetAddr string,
|
targetAddr socksAddr,
|
||||||
) (net.Conn, error) {
|
) (net.Conn, error) {
|
||||||
host, port, err := splitHostPort(targetAddr)
|
conn, exist := c.udpTargetConns[targetAddr]
|
||||||
if err != nil {
|
if exist {
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
conn, loaded := c.udpTargetConns.Load(targetAddr)
|
|
||||||
if loaded {
|
|
||||||
return conn, nil
|
return conn, nil
|
||||||
}
|
}
|
||||||
conn, err = c.srv.dial(ctx, "udp", targetAddr)
|
conn, err := c.srv.dial(ctx, "udp", targetAddr.hostPort())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
c.udpTargetConns.Store(targetAddr, conn)
|
c.udpTargetConns[targetAddr] = conn
|
||||||
|
|
||||||
// target -> client
|
// target -> client
|
||||||
go func() {
|
go func() {
|
||||||
buf := make([]byte, bufferSize)
|
buf := make([]byte, bufferSize)
|
||||||
addr := socksAddr{addrType: getAddrType(host), addr: host, port: port}
|
|
||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
return
|
return
|
||||||
default:
|
default:
|
||||||
err := c.handleUDPResponse(clientConn, addr, conn, buf)
|
err := c.handleUDPResponse(clientConn, targetAddr, conn, buf)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if isTimeout(err) {
|
if isTimeout(err) {
|
||||||
continue
|
continue
|
||||||
@ -414,18 +408,17 @@ func (c *Conn) handleUDPRequest(
|
|||||||
return fmt.Errorf("parse udp request: %w", err)
|
return fmt.Errorf("parse udp request: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
targetAddr := req.addr.hostPort()
|
targetConn, err := c.getOrDialTargetConn(ctx, clientConn, req.addr)
|
||||||
targetConn, err := c.getOrDialTargetConn(ctx, clientConn, targetAddr)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("dial target %s fail: %w", targetAddr, err)
|
return fmt.Errorf("dial target %s fail: %w", req.addr, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
nn, err := targetConn.Write(data)
|
nn, err := targetConn.Write(data)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("write to target %s fail: %w", targetAddr, err)
|
return fmt.Errorf("write to target %s fail: %w", req.addr, err)
|
||||||
}
|
}
|
||||||
if nn != len(data) {
|
if nn != len(data) {
|
||||||
return fmt.Errorf("write to target %s fail: %w", targetAddr, io.ErrShortWrite)
|
return fmt.Errorf("write to target %s fail: %w", req.addr, io.ErrShortWrite)
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@ -652,10 +645,15 @@ func (s socksAddr) marshal() ([]byte, error) {
|
|||||||
pkt = binary.BigEndian.AppendUint16(pkt, s.port)
|
pkt = binary.BigEndian.AppendUint16(pkt, s.port)
|
||||||
return pkt, nil
|
return pkt, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s socksAddr) hostPort() string {
|
func (s socksAddr) hostPort() string {
|
||||||
return net.JoinHostPort(s.addr, strconv.Itoa(int(s.port)))
|
return net.JoinHostPort(s.addr, strconv.Itoa(int(s.port)))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s socksAddr) String() string {
|
||||||
|
return s.hostPort()
|
||||||
|
}
|
||||||
|
|
||||||
// response contains the contents of
|
// response contains the contents of
|
||||||
// a response packet sent from the proxy
|
// a response packet sent from the proxy
|
||||||
// to the client.
|
// to the client.
|
||||||
|
@ -169,12 +169,25 @@ func TestReadPassword(t *testing.T) {
|
|||||||
|
|
||||||
func TestUDP(t *testing.T) {
|
func TestUDP(t *testing.T) {
|
||||||
// backend UDP server which we'll use SOCKS5 to connect to
|
// backend UDP server which we'll use SOCKS5 to connect to
|
||||||
listener, err := net.ListenPacket("udp", ":0")
|
newUDPEchoServer := func() net.PacketConn {
|
||||||
if err != nil {
|
listener, err := net.ListenPacket("udp", ":0")
|
||||||
t.Fatal(err)
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
go udpEchoServer(listener)
|
||||||
|
return listener
|
||||||
}
|
}
|
||||||
backendServerPort := listener.LocalAddr().(*net.UDPAddr).Port
|
|
||||||
go udpEchoServer(listener)
|
const echoServerNumber = 3
|
||||||
|
echoServerListener := make([]net.PacketConn, echoServerNumber)
|
||||||
|
for i := 0; i < echoServerNumber; i++ {
|
||||||
|
echoServerListener[i] = newUDPEchoServer()
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
for i := 0; i < echoServerNumber; i++ {
|
||||||
|
_ = echoServerListener[i].Close()
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
// SOCKS5 server
|
// SOCKS5 server
|
||||||
socks5, err := net.Listen("tcp", ":0")
|
socks5, err := net.Listen("tcp", ":0")
|
||||||
@ -184,84 +197,93 @@ func TestUDP(t *testing.T) {
|
|||||||
socks5Port := socks5.Addr().(*net.TCPAddr).Port
|
socks5Port := socks5.Addr().(*net.TCPAddr).Port
|
||||||
go socks5Server(socks5)
|
go socks5Server(socks5)
|
||||||
|
|
||||||
// net/proxy don't support UDP, so we need to manually send the SOCKS5 UDP request
|
// make a socks5 udpAssociate conn
|
||||||
conn, err := net.Dial("tcp", fmt.Sprintf("localhost:%d", socks5Port))
|
newUdpAssociateConn := func() (socks5Conn net.Conn, socks5UDPAddr socksAddr) {
|
||||||
if err != nil {
|
// net/proxy don't support UDP, so we need to manually send the SOCKS5 UDP request
|
||||||
t.Fatal(err)
|
conn, err := net.Dial("tcp", fmt.Sprintf("localhost:%d", socks5Port))
|
||||||
}
|
if err != nil {
|
||||||
_, err = conn.Write([]byte{0x05, 0x01, 0x00}) // client hello with no auth
|
t.Fatal(err)
|
||||||
if err != nil {
|
}
|
||||||
t.Fatal(err)
|
_, err = conn.Write([]byte{socks5Version, 0x01, noAuthRequired}) // client hello with no auth
|
||||||
}
|
if err != nil {
|
||||||
buf := make([]byte, 1024)
|
t.Fatal(err)
|
||||||
n, err := conn.Read(buf) // server hello
|
}
|
||||||
if err != nil {
|
buf := make([]byte, 1024)
|
||||||
t.Fatal(err)
|
n, err := conn.Read(buf) // server hello
|
||||||
}
|
if err != nil {
|
||||||
if n != 2 || buf[0] != 0x05 || buf[1] != 0x00 {
|
t.Fatal(err)
|
||||||
t.Fatalf("got: %q want: 0x05 0x00", buf[:n])
|
}
|
||||||
|
if n != 2 || buf[0] != socks5Version || buf[1] != noAuthRequired {
|
||||||
|
t.Fatalf("got: %q want: 0x05 0x00", buf[:n])
|
||||||
|
}
|
||||||
|
|
||||||
|
targetAddr := socksAddr{addrType: ipv4, addr: "0.0.0.0", port: 0}
|
||||||
|
targetAddrPkt, err := targetAddr.marshal()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
_, err = conn.Write(append([]byte{socks5Version, byte(udpAssociate), 0x00}, targetAddrPkt...)) // client reqeust
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
n, err = conn.Read(buf) // server response
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if n < 3 || !bytes.Equal(buf[:3], []byte{socks5Version, 0x00, 0x00}) {
|
||||||
|
t.Fatalf("got: %q want: 0x05 0x00 0x00", buf[:n])
|
||||||
|
}
|
||||||
|
udpProxySocksAddr, err := parseSocksAddr(bytes.NewReader(buf[3:n]))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return conn, udpProxySocksAddr
|
||||||
}
|
}
|
||||||
|
|
||||||
targetAddr := socksAddr{
|
conn, udpProxySocksAddr := newUdpAssociateConn()
|
||||||
addrType: domainName,
|
defer conn.Close()
|
||||||
addr: "localhost",
|
|
||||||
port: uint16(backendServerPort),
|
|
||||||
}
|
|
||||||
targetAddrPkt, err := targetAddr.marshal()
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
_, err = conn.Write(append([]byte{0x05, 0x03, 0x00}, targetAddrPkt...)) // client reqeust
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
n, err = conn.Read(buf) // server response
|
sendUDPAndWaitResponse := func(socks5UDPConn net.Conn, addr socksAddr, body []byte) (responseBody []byte) {
|
||||||
if err != nil {
|
udpPayload, err := (&udpRequest{addr: addr}).marshal()
|
||||||
t.Fatal(err)
|
if err != nil {
|
||||||
}
|
t.Fatal(err)
|
||||||
if n < 3 || !bytes.Equal(buf[:3], []byte{0x05, 0x00, 0x00}) {
|
}
|
||||||
t.Fatalf("got: %q want: 0x05 0x00 0x00", buf[:n])
|
udpPayload = append(udpPayload, body...)
|
||||||
}
|
_, err = socks5UDPConn.Write(udpPayload)
|
||||||
udpProxySocksAddr, err := parseSocksAddr(bytes.NewReader(buf[3:n]))
|
if err != nil {
|
||||||
if err != nil {
|
t.Fatal(err)
|
||||||
t.Fatal(err)
|
}
|
||||||
|
buf := make([]byte, 1024)
|
||||||
|
n, err := socks5UDPConn.Read(buf)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
_, responseBody, err = parseUDPRequest(buf[:n])
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
return responseBody
|
||||||
}
|
}
|
||||||
|
|
||||||
udpProxyAddr, err := net.ResolveUDPAddr("udp", udpProxySocksAddr.hostPort())
|
udpProxyAddr, err := net.ResolveUDPAddr("udp", udpProxySocksAddr.hostPort())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
udpConn, err := net.DialUDP("udp", nil, udpProxyAddr)
|
socks5UDPConn, err := net.DialUDP("udp", nil, udpProxyAddr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
udpPayload, err := (&udpRequest{addr: targetAddr}).marshal()
|
defer socks5UDPConn.Close()
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
for i := 0; i < echoServerNumber; i++ {
|
||||||
}
|
port := echoServerListener[i].LocalAddr().(*net.UDPAddr).Port
|
||||||
udpPayload = append(udpPayload, []byte("Test")...)
|
addr := socksAddr{addrType: ipv4, addr: "127.0.0.1", port: uint16(port)}
|
||||||
_, err = udpConn.Write(udpPayload) // send udp package
|
requestBody := []byte(fmt.Sprintf("Test %d", i))
|
||||||
if err != nil {
|
responseBody := sendUDPAndWaitResponse(socks5UDPConn, addr, requestBody)
|
||||||
t.Fatal(err)
|
if !bytes.Equal(requestBody, responseBody) {
|
||||||
}
|
t.Fatalf("got: %q want: %q", responseBody, requestBody)
|
||||||
n, _, err = udpConn.ReadFrom(buf)
|
}
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
_, responseBody, err := parseUDPRequest(buf[:n]) // read udp response
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
if string(responseBody) != "Test" {
|
|
||||||
t.Fatalf("got: %q want: Test", responseBody)
|
|
||||||
}
|
|
||||||
err = udpConn.Close()
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
err = conn.Close()
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
Loading…
x
Reference in New Issue
Block a user