tailscale/net/socks5/socks5_test.go
VimT e3f047618b net/socks5: support UDP
Updates #7581

Signed-off-by: VimT <me@vimt.me>
2024-08-05 09:25:24 -07:00

268 lines
5.7 KiB
Go

// Copyright (c) Tailscale Inc & AUTHORS
// SPDX-License-Identifier: BSD-3-Clause
package socks5
import (
"bytes"
"errors"
"fmt"
"io"
"net"
"testing"
"golang.org/x/net/proxy"
)
func socks5Server(listener net.Listener) {
var server Server
err := server.Serve(listener)
if err != nil {
panic(err)
}
listener.Close()
}
func backendServer(listener net.Listener) {
conn, err := listener.Accept()
if err != nil {
panic(err)
}
conn.Write([]byte("Test"))
conn.Close()
listener.Close()
}
func udpEchoServer(conn net.PacketConn) {
var buf [1024]byte
n, addr, err := conn.ReadFrom(buf[:])
if err != nil {
panic(err)
}
_, err = conn.WriteTo(buf[:n], addr)
if err != nil {
panic(err)
}
conn.Close()
}
func TestRead(t *testing.T) {
// backend server which we'll use SOCKS5 to connect to
listener, err := net.Listen("tcp", ":0")
if err != nil {
t.Fatal(err)
}
backendServerPort := listener.Addr().(*net.TCPAddr).Port
go backendServer(listener)
// SOCKS5 server
socks5, err := net.Listen("tcp", ":0")
if err != nil {
t.Fatal(err)
}
socks5Port := socks5.Addr().(*net.TCPAddr).Port
go socks5Server(socks5)
addr := fmt.Sprintf("localhost:%d", socks5Port)
socksDialer, err := proxy.SOCKS5("tcp", addr, nil, proxy.Direct)
if err != nil {
t.Fatal(err)
}
addr = fmt.Sprintf("localhost:%d", backendServerPort)
conn, err := socksDialer.Dial("tcp", addr)
if err != nil {
t.Fatal(err)
}
buf := make([]byte, 4)
_, err = io.ReadFull(conn, buf)
if err != nil {
t.Fatal(err)
}
if string(buf) != "Test" {
t.Fatalf("got: %q want: Test", buf)
}
err = conn.Close()
if err != nil {
t.Fatal(err)
}
}
func TestReadPassword(t *testing.T) {
// backend server which we'll use SOCKS5 to connect to
ln, err := net.Listen("tcp", ":0")
if err != nil {
t.Fatal(err)
}
backendServerPort := ln.Addr().(*net.TCPAddr).Port
go backendServer(ln)
socks5ln, err := net.Listen("tcp", ":0")
if err != nil {
t.Fatal(err)
}
t.Cleanup(func() {
socks5ln.Close()
})
auth := &proxy.Auth{User: "foo", Password: "bar"}
go func() {
s := Server{Username: auth.User, Password: auth.Password}
err := s.Serve(socks5ln)
if err != nil && !errors.Is(err, net.ErrClosed) {
panic(err)
}
}()
addr := fmt.Sprintf("localhost:%d", socks5ln.Addr().(*net.TCPAddr).Port)
if d, err := proxy.SOCKS5("tcp", addr, nil, proxy.Direct); err != nil {
t.Fatal(err)
} else {
if _, err := d.Dial("tcp", addr); err == nil {
t.Fatal("expected no-auth dial error")
}
}
badPwd := &proxy.Auth{User: "foo", Password: "not right"}
if d, err := proxy.SOCKS5("tcp", addr, badPwd, proxy.Direct); err != nil {
t.Fatal(err)
} else {
if _, err := d.Dial("tcp", addr); err == nil {
t.Fatal("expected bad password dial error")
}
}
badUsr := &proxy.Auth{User: "not right", Password: "bar"}
if d, err := proxy.SOCKS5("tcp", addr, badUsr, proxy.Direct); err != nil {
t.Fatal(err)
} else {
if _, err := d.Dial("tcp", addr); err == nil {
t.Fatal("expected bad username dial error")
}
}
socksDialer, err := proxy.SOCKS5("tcp", addr, auth, proxy.Direct)
if err != nil {
t.Fatal(err)
}
addr = fmt.Sprintf("localhost:%d", backendServerPort)
conn, err := socksDialer.Dial("tcp", addr)
if err != nil {
t.Fatal(err)
}
buf := make([]byte, 4)
if _, err := io.ReadFull(conn, buf); err != nil {
t.Fatal(err)
}
if string(buf) != "Test" {
t.Fatalf("got: %q want: Test", buf)
}
if err := conn.Close(); err != nil {
t.Fatal(err)
}
}
func TestUDP(t *testing.T) {
// backend UDP server which we'll use SOCKS5 to connect to
listener, err := net.ListenPacket("udp", ":0")
if err != nil {
t.Fatal(err)
}
backendServerPort := listener.LocalAddr().(*net.UDPAddr).Port
go udpEchoServer(listener)
// SOCKS5 server
socks5, err := net.Listen("tcp", ":0")
if err != nil {
t.Fatal(err)
}
socks5Port := socks5.Addr().(*net.TCPAddr).Port
go socks5Server(socks5)
// net/proxy don't support UDP, so we need to manually send the SOCKS5 UDP request
conn, err := net.Dial("tcp", fmt.Sprintf("localhost:%d", socks5Port))
if err != nil {
t.Fatal(err)
}
_, err = conn.Write([]byte{0x05, 0x01, 0x00}) // client hello with no auth
if err != nil {
t.Fatal(err)
}
buf := make([]byte, 1024)
n, err := conn.Read(buf) // server hello
if err != nil {
t.Fatal(err)
}
if n != 2 || buf[0] != 0x05 || buf[1] != 0x00 {
t.Fatalf("got: %q want: 0x05 0x00", buf[:n])
}
targetAddr := socksAddr{
addrType: domainName,
addr: "localhost",
port: uint16(backendServerPort),
}
targetAddrPkt, err := targetAddr.marshal()
if err != nil {
t.Fatal(err)
}
_, err = conn.Write(append([]byte{0x05, 0x03, 0x00}, targetAddrPkt...)) // client reqeust
if err != nil {
t.Fatal(err)
}
n, err = conn.Read(buf) // server response
if err != nil {
t.Fatal(err)
}
if n < 3 || !bytes.Equal(buf[:3], []byte{0x05, 0x00, 0x00}) {
t.Fatalf("got: %q want: 0x05 0x00 0x00", buf[:n])
}
udpProxySocksAddr, err := parseSocksAddr(bytes.NewReader(buf[3:n]))
if err != nil {
t.Fatal(err)
}
udpProxyAddr, err := net.ResolveUDPAddr("udp", udpProxySocksAddr.hostPort())
if err != nil {
t.Fatal(err)
}
udpConn, err := net.DialUDP("udp", nil, udpProxyAddr)
if err != nil {
t.Fatal(err)
}
udpPayload, err := (&udpRequest{addr: targetAddr}).marshal()
if err != nil {
t.Fatal(err)
}
udpPayload = append(udpPayload, []byte("Test")...)
_, err = udpConn.Write(udpPayload) // send udp package
if err != nil {
t.Fatal(err)
}
n, _, err = udpConn.ReadFrom(buf)
if err != nil {
t.Fatal(err)
}
_, responseBody, err := parseUDPRequest(buf[:n]) // read udp response
if err != nil {
t.Fatal(err)
}
if string(responseBody) != "Test" {
t.Fatalf("got: %q want: Test", responseBody)
}
err = udpConn.Close()
if err != nil {
t.Fatal(err)
}
err = conn.Close()
if err != nil {
t.Fatal(err)
}
}