diff --git a/net/socks5/socks5.go b/net/socks5/socks5.go index b774ebe24..0d651537f 100644 --- a/net/socks5/socks5.go +++ b/net/socks5/socks5.go @@ -13,8 +13,10 @@ package socks5 import ( + "bytes" "context" "encoding/binary" + "errors" "fmt" "io" "log" @@ -121,7 +123,7 @@ func (s *Server) Serve(l net.Listener) error { } go func() { defer c.Close() - conn := &Conn{clientConn: c, srv: s} + conn := &Conn{logf: s.Logf, clientConn: c, srv: s} err := conn.Run() if err != nil { s.logf("client connection failed: %v", err) @@ -136,9 +138,12 @@ type Conn struct { // The struct is filled by each of the internal // methods in turn as the transaction progresses. + logf logger.Logf srv *Server clientConn net.Conn request *request + + udpClientAddr net.Addr } // Run starts the new connection. @@ -172,58 +177,59 @@ func (c *Conn) Run() error { func (c *Conn) handleRequest() error { req, err := parseClientRequest(c.clientConn) if err != nil { - res := &response{reply: generalFailure} + res := errorResponse(generalFailure) buf, _ := res.marshal() c.clientConn.Write(buf) return err } - if req.command != connect { - res := &response{reply: commandNotSupported} + + c.request = req + switch req.command { + case connect: + return c.handleTCP() + case udpAssociate: + return c.handleUDP() + default: + res := errorResponse(commandNotSupported) buf, _ := res.marshal() c.clientConn.Write(buf) return fmt.Errorf("unsupported command %v", req.command) } - c.request = req +} +func (c *Conn) handleTCP() error { ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() srv, err := c.srv.dial( ctx, "tcp", - net.JoinHostPort(c.request.destination, strconv.Itoa(int(c.request.port))), + c.request.destination.hostPort(), ) if err != nil { - res := &response{reply: generalFailure} + res := errorResponse(generalFailure) buf, _ := res.marshal() c.clientConn.Write(buf) return err } defer srv.Close() - serverAddr, serverPortStr, err := net.SplitHostPort(srv.LocalAddr().String()) + + localAddr := srv.LocalAddr().String() + serverAddr, serverPort, err := splitHostPort(localAddr) if err != nil { return err } - serverPort, _ := strconv.Atoi(serverPortStr) - var bindAddrType addrType - if ip := net.ParseIP(serverAddr); ip != nil { - if ip.To4() != nil { - bindAddrType = ipv4 - } else { - bindAddrType = ipv6 - } - } else { - bindAddrType = domainName - } res := &response{ - reply: success, - bindAddrType: bindAddrType, - bindAddr: serverAddr, - bindPort: uint16(serverPort), + reply: success, + bindAddr: socksAddr{ + addrType: getAddrType(serverAddr), + addr: serverAddr, + port: serverPort, + }, } buf, err := res.marshal() if err != nil { - res = &response{reply: generalFailure} + res = errorResponse(generalFailure) buf, _ = res.marshal() } c.clientConn.Write(buf) @@ -246,6 +252,208 @@ func (c *Conn) handleRequest() error { return <-errc } +func (c *Conn) handleUDP() error { + // The DST.ADDR and DST.PORT fields contain the address and port that + // the client expects to use to send UDP datagrams on for the + // association. The server MAY use this information to limit access + // to the association. + // @see Page 6, https://datatracker.ietf.org/doc/html/rfc1928. + // + // We do NOT limit the access from the client currently in this implementation. + _ = c.request.destination + + addr := c.clientConn.LocalAddr() + host, _, err := net.SplitHostPort(addr.String()) + if err != nil { + return err + } + clientUDPConn, err := net.ListenPacket("udp", net.JoinHostPort(host, "0")) + if err != nil { + res := errorResponse(generalFailure) + buf, _ := res.marshal() + c.clientConn.Write(buf) + return err + } + defer clientUDPConn.Close() + + serverUDPConn, err := net.ListenPacket("udp", "[::]:0") + if err != nil { + res := errorResponse(generalFailure) + buf, _ := res.marshal() + c.clientConn.Write(buf) + return err + } + defer serverUDPConn.Close() + + bindAddr, bindPort, err := splitHostPort(clientUDPConn.LocalAddr().String()) + if err != nil { + return err + } + + res := &response{ + reply: success, + bindAddr: socksAddr{ + addrType: getAddrType(bindAddr), + addr: bindAddr, + port: bindPort, + }, + } + buf, err := res.marshal() + if err != nil { + res = errorResponse(generalFailure) + buf, _ = res.marshal() + } + c.clientConn.Write(buf) + + return c.transferUDP(c.clientConn, clientUDPConn, serverUDPConn) +} + +func (c *Conn) transferUDP(associatedTCP net.Conn, clientConn net.PacketConn, targetConn net.PacketConn) error { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + const bufferSize = 8 * 1024 + const readTimeout = 5 * time.Second + + // client -> target + go func() { + defer cancel() + buf := make([]byte, bufferSize) + for { + select { + case <-ctx.Done(): + return + default: + err := c.handleUDPRequest(clientConn, targetConn, buf, readTimeout) + if err != nil { + if isTimeout(err) { + continue + } + if errors.Is(err, net.ErrClosed) { + return + } + c.logf("udp transfer: handle udp request fail: %v", err) + } + } + } + }() + + // target -> client + go func() { + defer cancel() + buf := make([]byte, bufferSize) + for { + select { + case <-ctx.Done(): + return + default: + err := c.handleUDPResponse(targetConn, clientConn, buf, readTimeout) + if err != nil { + if isTimeout(err) { + continue + } + if errors.Is(err, net.ErrClosed) { + return + } + c.logf("udp transfer: handle udp response fail: %v", err) + } + } + } + }() + + // A UDP association terminates when the TCP connection that the UDP + // ASSOCIATE request arrived on terminates. RFC1928 + _, err := io.Copy(io.Discard, associatedTCP) + if err != nil { + err = fmt.Errorf("udp associated tcp conn: %w", err) + } + return err +} + +func (c *Conn) handleUDPRequest( + clientConn net.PacketConn, + targetConn net.PacketConn, + buf []byte, + readTimeout time.Duration, +) error { + // add a deadline for the read to avoid blocking forever + _ = clientConn.SetReadDeadline(time.Now().Add(readTimeout)) + n, addr, err := clientConn.ReadFrom(buf) + if err != nil { + return fmt.Errorf("read from client: %w", err) + } + c.udpClientAddr = addr + req, data, err := parseUDPRequest(buf[:n]) + if err != nil { + return fmt.Errorf("parse udp request: %w", err) + } + targetAddr, err := net.ResolveUDPAddr("udp", req.addr.hostPort()) + if err != nil { + c.logf("resolve target addr fail: %v", err) + } + + nn, err := targetConn.WriteTo(data, targetAddr) + if err != nil { + return fmt.Errorf("write to target %s fail: %w", targetAddr, err) + } + if nn != len(data) { + return fmt.Errorf("write to target %s fail: %w", targetAddr, io.ErrShortWrite) + } + return nil +} + +func (c *Conn) handleUDPResponse( + targetConn net.PacketConn, + clientConn net.PacketConn, + buf []byte, + readTimeout time.Duration, +) error { + // add a deadline for the read to avoid blocking forever + _ = targetConn.SetReadDeadline(time.Now().Add(readTimeout)) + n, addr, err := targetConn.ReadFrom(buf) + if err != nil { + return fmt.Errorf("read from target: %w", err) + } + host, port, err := splitHostPort(addr.String()) + if err != nil { + return fmt.Errorf("split host port: %w", err) + } + hdr := udpRequest{addr: socksAddr{addrType: getAddrType(host), addr: host, port: port}} + pkt, err := hdr.marshal() + if err != nil { + return fmt.Errorf("marshal udp request: %w", err) + } + data := append(pkt, buf[:n]...) + // use addr from client to send back + nn, err := clientConn.WriteTo(data, c.udpClientAddr) + if err != nil { + return fmt.Errorf("write to client: %w", err) + } + if nn != len(data) { + return fmt.Errorf("write to client: %w", io.ErrShortWrite) + } + return nil +} + +func isTimeout(err error) bool { + terr, ok := errors.Unwrap(err).(interface{ Timeout() bool }) + return ok && terr.Timeout() +} + +func splitHostPort(hostport string) (host string, port uint16, err error) { + host, portStr, err := net.SplitHostPort(hostport) + if err != nil { + return "", 0, err + } + portInt, err := strconv.Atoi(portStr) + if err != nil { + return "", 0, err + } + if portInt < 0 || portInt > 65535 { + return "", 0, fmt.Errorf("invalid port number %d", portInt) + } + return host, uint16(portInt), nil +} + // parseClientGreeting parses a request initiation packet. func parseClientGreeting(r io.Reader, authMethod byte) error { var hdr [2]byte @@ -295,114 +503,118 @@ func parseClientAuth(r io.Reader) (usr, pwd string, err error) { return string(usrBytes), string(pwdBytes), nil } +func getAddrType(addr string) addrType { + if ip := net.ParseIP(addr); ip != nil { + if ip.To4() != nil { + return ipv4 + } + return ipv6 + } + return domainName +} + // request represents data contained within a SOCKS5 // connection request packet. type request struct { - command commandType - destination string - port uint16 - destAddrType addrType + command commandType + destination socksAddr } // parseClientRequest converts raw packet bytes into a // SOCKS5Request struct. func parseClientRequest(r io.Reader) (*request, error) { - var hdr [4]byte + var hdr [3]byte _, err := io.ReadFull(r, hdr[:]) if err != nil { return nil, fmt.Errorf("could not read packet header") } cmd := hdr[1] - destAddrType := addrType(hdr[3]) + destination, err := parseSocksAddr(r) + return &request{ + command: commandType(cmd), + destination: destination, + }, err +} + +type socksAddr struct { + addrType addrType + addr string + port uint16 +} + +var zeroSocksAddr = socksAddr{addrType: ipv4, addr: "0.0.0.0", port: 0} + +func parseSocksAddr(r io.Reader) (addr socksAddr, err error) { + var addrTypeData [1]byte + _, err = io.ReadFull(r, addrTypeData[:]) + if err != nil { + return socksAddr{}, fmt.Errorf("could not read address type") + } + + dstAddrType := addrType(addrTypeData[0]) var destination string - var port uint16 - - if destAddrType == ipv4 { + switch dstAddrType { + case ipv4: var ip [4]byte _, err = io.ReadFull(r, ip[:]) if err != nil { - return nil, fmt.Errorf("could not read IPv4 address") + return socksAddr{}, fmt.Errorf("could not read IPv4 address") } destination = net.IP(ip[:]).String() - } else if destAddrType == domainName { + case domainName: var dstSizeByte [1]byte _, err = io.ReadFull(r, dstSizeByte[:]) if err != nil { - return nil, fmt.Errorf("could not read domain name size") + return socksAddr{}, fmt.Errorf("could not read domain name size") } dstSize := int(dstSizeByte[0]) domainName := make([]byte, dstSize) _, err = io.ReadFull(r, domainName) if err != nil { - return nil, fmt.Errorf("could not read domain name") + return socksAddr{}, fmt.Errorf("could not read domain name") } destination = string(domainName) - } else if destAddrType == ipv6 { + case ipv6: var ip [16]byte _, err = io.ReadFull(r, ip[:]) if err != nil { - return nil, fmt.Errorf("could not read IPv6 address") + return socksAddr{}, fmt.Errorf("could not read IPv6 address") } destination = net.IP(ip[:]).String() - } else { - return nil, fmt.Errorf("unsupported address type") + default: + return socksAddr{}, fmt.Errorf("unsupported address type") } var portBytes [2]byte _, err = io.ReadFull(r, portBytes[:]) if err != nil { - return nil, fmt.Errorf("could not read port") + return socksAddr{}, fmt.Errorf("could not read port") } - port = binary.BigEndian.Uint16(portBytes[:]) - - return &request{ - command: commandType(cmd), - destination: destination, - port: port, - destAddrType: destAddrType, + port := binary.BigEndian.Uint16(portBytes[:]) + return socksAddr{ + addrType: dstAddrType, + addr: destination, + port: port, }, nil } -// response contains the contents of -// a response packet sent from the proxy -// to the client. -type response struct { - reply replyCode - bindAddrType addrType - bindAddr string - bindPort uint16 -} - -// marshal converts a SOCKS5Response struct into -// a packet. If res.reply == Success, it may throw an error on -// receiving an invalid bind address. Otherwise, it will not throw. -func (res *response) marshal() ([]byte, error) { - pkt := make([]byte, 4) - pkt[0] = socks5Version - pkt[1] = byte(res.reply) - pkt[2] = 0 // null reserved byte - pkt[3] = byte(res.bindAddrType) - - if res.reply != success { - return pkt, nil - } - +func (s socksAddr) marshal() ([]byte, error) { var addr []byte - switch res.bindAddrType { + switch s.addrType { case ipv4: - addr = net.ParseIP(res.bindAddr).To4() + addr = net.ParseIP(s.addr).To4() if addr == nil { return nil, fmt.Errorf("invalid IPv4 address for binding") } case domainName: - if len(res.bindAddr) > 255 { + if len(s.addr) > 255 { return nil, fmt.Errorf("invalid domain name for binding") } - addr = make([]byte, 0, len(res.bindAddr)+1) - addr = append(addr, byte(len(res.bindAddr))) - addr = append(addr, []byte(res.bindAddr)...) + addr = make([]byte, 0, len(s.addr)+1) + addr = append(addr, byte(len(s.addr))) + addr = append(addr, []byte(s.addr)...) case ipv6: - addr = net.ParseIP(res.bindAddr).To16() + addr = net.ParseIP(s.addr).To16() if addr == nil { return nil, fmt.Errorf("invalid IPv6 address for binding") } @@ -410,8 +622,86 @@ func (res *response) marshal() ([]byte, error) { return nil, fmt.Errorf("unsupported address type") } + pkt := []byte{byte(s.addrType)} pkt = append(pkt, addr...) - pkt = binary.BigEndian.AppendUint16(pkt, uint16(res.bindPort)) - + pkt = binary.BigEndian.AppendUint16(pkt, s.port) return pkt, nil } +func (s socksAddr) hostPort() string { + return net.JoinHostPort(s.addr, strconv.Itoa(int(s.port))) +} + +// response contains the contents of +// a response packet sent from the proxy +// to the client. +type response struct { + reply replyCode + bindAddr socksAddr +} + +func errorResponse(code replyCode) *response { + return &response{reply: code, bindAddr: zeroSocksAddr} +} + +// marshal converts a SOCKS5Response struct into +// a packet. If res.reply == Success, it may throw an error on +// receiving an invalid bind address. Otherwise, it will not throw. +func (res *response) marshal() ([]byte, error) { + pkt := make([]byte, 3) + pkt[0] = socks5Version + pkt[1] = byte(res.reply) + pkt[2] = 0 // null reserved byte + + addrPkt, err := res.bindAddr.marshal() + if err != nil { + return nil, err + } + + return append(pkt, addrPkt...), nil +} + +type udpRequest struct { + frag byte + addr socksAddr +} + +// +----+------+------+----------+----------+----------+ +// |RSV | FRAG | ATYP | DST.ADDR | DST.PORT | DATA | +// +----+------+------+----------+----------+----------+ +// | 2 | 1 | 1 | Variable | 2 | Variable | +// +----+------+------+----------+----------+----------+ +func parseUDPRequest(data []byte) (*udpRequest, []byte, error) { + if len(data) < 4 { + return nil, nil, fmt.Errorf("invalid packet length") + } + + // reserved bytes + if !(data[0] == 0 && data[1] == 0) { + return nil, nil, fmt.Errorf("invalid udp request header") + } + + frag := data[2] + + reader := bytes.NewReader(data[3:]) + addr, err := parseSocksAddr(reader) + bodyLen := reader.Len() // (*bytes.Reader).Len() return unread data length + body := data[len(data)-bodyLen:] + return &udpRequest{ + frag: frag, + addr: addr, + }, body, err +} + +func (u *udpRequest) marshal() ([]byte, error) { + pkt := make([]byte, 3) + pkt[0] = 0 + pkt[1] = 0 + pkt[2] = u.frag + + addrPkt, err := u.addr.marshal() + if err != nil { + return nil, err + } + + return append(pkt, addrPkt...), nil +} diff --git a/net/socks5/socks5_test.go b/net/socks5/socks5_test.go index 201a66575..11ea59d4b 100644 --- a/net/socks5/socks5_test.go +++ b/net/socks5/socks5_test.go @@ -4,6 +4,7 @@ package socks5 import ( + "bytes" "errors" "fmt" "io" @@ -32,6 +33,19 @@ func backendServer(listener net.Listener) { listener.Close() } +func udpEchoServer(conn net.PacketConn) { + var buf [1024]byte + n, addr, err := conn.ReadFrom(buf[:]) + if err != nil { + panic(err) + } + _, err = conn.WriteTo(buf[:n], addr) + if err != nil { + panic(err) + } + conn.Close() +} + func TestRead(t *testing.T) { // backend server which we'll use SOCKS5 to connect to listener, err := net.Listen("tcp", ":0") @@ -152,3 +166,102 @@ func TestReadPassword(t *testing.T) { t.Fatal(err) } } + +func TestUDP(t *testing.T) { + // backend UDP server which we'll use SOCKS5 to connect to + listener, err := net.ListenPacket("udp", ":0") + if err != nil { + t.Fatal(err) + } + backendServerPort := listener.LocalAddr().(*net.UDPAddr).Port + go udpEchoServer(listener) + + // SOCKS5 server + socks5, err := net.Listen("tcp", ":0") + if err != nil { + t.Fatal(err) + } + socks5Port := socks5.Addr().(*net.TCPAddr).Port + go socks5Server(socks5) + + // net/proxy don't support UDP, so we need to manually send the SOCKS5 UDP request + conn, err := net.Dial("tcp", fmt.Sprintf("localhost:%d", socks5Port)) + if err != nil { + t.Fatal(err) + } + _, err = conn.Write([]byte{0x05, 0x01, 0x00}) // client hello with no auth + if err != nil { + t.Fatal(err) + } + buf := make([]byte, 1024) + n, err := conn.Read(buf) // server hello + if err != nil { + t.Fatal(err) + } + if n != 2 || buf[0] != 0x05 || buf[1] != 0x00 { + t.Fatalf("got: %q want: 0x05 0x00", buf[:n]) + } + + targetAddr := socksAddr{ + addrType: domainName, + addr: "localhost", + port: uint16(backendServerPort), + } + targetAddrPkt, err := targetAddr.marshal() + if err != nil { + t.Fatal(err) + } + _, err = conn.Write(append([]byte{0x05, 0x03, 0x00}, targetAddrPkt...)) // client reqeust + if err != nil { + t.Fatal(err) + } + + n, err = conn.Read(buf) // server response + if err != nil { + t.Fatal(err) + } + if n < 3 || !bytes.Equal(buf[:3], []byte{0x05, 0x00, 0x00}) { + t.Fatalf("got: %q want: 0x05 0x00 0x00", buf[:n]) + } + udpProxySocksAddr, err := parseSocksAddr(bytes.NewReader(buf[3:n])) + if err != nil { + t.Fatal(err) + } + + udpProxyAddr, err := net.ResolveUDPAddr("udp", udpProxySocksAddr.hostPort()) + if err != nil { + t.Fatal(err) + } + udpConn, err := net.DialUDP("udp", nil, udpProxyAddr) + if err != nil { + t.Fatal(err) + } + udpPayload, err := (&udpRequest{addr: targetAddr}).marshal() + if err != nil { + t.Fatal(err) + } + udpPayload = append(udpPayload, []byte("Test")...) + _, err = udpConn.Write(udpPayload) // send udp package + if err != nil { + t.Fatal(err) + } + n, _, err = udpConn.ReadFrom(buf) + if err != nil { + t.Fatal(err) + } + _, responseBody, err := parseUDPRequest(buf[:n]) // read udp response + if err != nil { + t.Fatal(err) + } + if string(responseBody) != "Test" { + t.Fatalf("got: %q want: Test", responseBody) + } + err = udpConn.Close() + if err != nil { + t.Fatal(err) + } + err = conn.Close() + if err != nil { + t.Fatal(err) + } +}