mirror of
https://github.com/tailscale/tailscale.git
synced 2024-11-26 03:25:35 +00:00
2fb087891b
Customer reported an issue where the connections were not closing, and would instead just stay open. This commit makes it so that we close out the connection regardless of what error we see. I've verified locally that it fixes the issue; we should add a test for this. Signed-off-by: Maisem Ali <maisem@tailscale.com>
372 lines
9.1 KiB
Go
372 lines
9.1 KiB
Go
// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved.
|
|
// Use of this source code is governed by a BSD-style
|
|
// license that can be found in the LICENSE file.
|
|
|
|
// Package socks5 is a SOCKS5 server implementation.
|
|
//
|
|
// This is used for userspace networking in Tailscale. Specifically,
|
|
// this is used for dialing out of the machine to other nodes, without
|
|
// the host kernel's involvement, so it doesn't require proper routing tables,
|
|
// TUN, IPv6, etc. This package is meant to only handle the SOCKS5 protocol
|
|
// details and not any integration with Tailscale internals itself.
|
|
//
|
|
// The glue between this package and Tailscale is in net/socks5/tssocks.
|
|
package socks5
|
|
|
|
import (
|
|
"context"
|
|
"encoding/binary"
|
|
"fmt"
|
|
"io"
|
|
"log"
|
|
"net"
|
|
"strconv"
|
|
"time"
|
|
|
|
"tailscale.com/types/logger"
|
|
)
|
|
|
|
// Authentication method identifiers and the protocol version byte used
// in the SOCKS5 handshake.
const (
	// noAuthRequired is the method byte meaning "no authentication required".
	noAuthRequired byte = 0x00
	// noAcceptableAuth is the method byte sent back when none of the
	// methods offered by the client is acceptable.
	noAcceptableAuth byte = 0xFF

	// socks5Version is the version byte carried by SOCKS5 messages.
	socks5Version byte = 0x05
)
|
|
|
|
// commandType is the command byte of a SOCKS5 request; it identifies
// the kind of connection the client is asking for.
type commandType byte

// SOCKS5 commands, numbered as in RFC 1928.
const (
	connect commandType = iota + 1
	bind
	udpAssociate
)
|
|
|
|
// addrType is the address-type byte of a SOCKS5 packet; it says how
// the address that follows it is encoded.
type addrType byte

// SOCKS5 address types, numbered as in RFC 1928.
const (
	ipv4       addrType = 0x01
	domainName addrType = 0x03
	ipv6       addrType = 0x04
)
|
|
|
|
// replyCode is the reply byte of a SOCKS5 response; it tells the
// client how the server handled its request.
type replyCode byte

// SOCKS5 reply codes, numbered as in RFC 1928.
const (
	success replyCode = iota
	generalFailure
	connectionNotAllowed
	networkUnreachable
	hostUnreachable
	connectionRefused
	ttlExpired
	commandNotSupported
	addrTypeNotSupported
)
|
|
|
|
// Server is a SOCKS5 proxy server.
|
|
type Server struct {
|
|
// Logf optionally specifies the logger to use.
|
|
// If nil, the standard logger is used.
|
|
Logf logger.Logf
|
|
|
|
// Dialer optionally specifies the dialer to use for outgoing connections.
|
|
// If nil, the net package's standard dialer is used.
|
|
Dialer func(ctx context.Context, network, addr string) (net.Conn, error)
|
|
}
|
|
|
|
func (s *Server) dial(ctx context.Context, network, addr string) (net.Conn, error) {
|
|
dial := s.Dialer
|
|
if dial == nil {
|
|
dialer := &net.Dialer{}
|
|
dial = dialer.DialContext
|
|
}
|
|
return dial(ctx, network, addr)
|
|
}
|
|
|
|
func (s *Server) logf(format string, args ...interface{}) {
|
|
logf := s.Logf
|
|
if logf == nil {
|
|
logf = log.Printf
|
|
}
|
|
logf(format, args...)
|
|
}
|
|
|
|
// Serve accepts and handles incoming connections on the given listener.
|
|
func (s *Server) Serve(l net.Listener) error {
|
|
defer l.Close()
|
|
for {
|
|
c, err := l.Accept()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
go func() {
|
|
defer c.Close()
|
|
conn := &Conn{clientConn: c, srv: s}
|
|
err := conn.Run()
|
|
if err != nil {
|
|
s.logf("client connection failed: %v", err)
|
|
}
|
|
}()
|
|
}
|
|
}
|
|
|
|
// Conn is a SOCKS5 connection for client to reach
|
|
// server.
|
|
type Conn struct {
|
|
// The struct is filled by each of the internal
|
|
// methods in turn as the transaction progresses.
|
|
|
|
srv *Server
|
|
clientConn net.Conn
|
|
request *request
|
|
}
|
|
|
|
// Run starts the new connection.
|
|
func (c *Conn) Run() error {
|
|
err := parseClientGreeting(c.clientConn)
|
|
if err != nil {
|
|
c.clientConn.Write([]byte{socks5Version, noAcceptableAuth})
|
|
return err
|
|
}
|
|
c.clientConn.Write([]byte{socks5Version, noAuthRequired})
|
|
return c.handleRequest()
|
|
}
|
|
|
|
func (c *Conn) handleRequest() error {
|
|
req, err := parseClientRequest(c.clientConn)
|
|
if err != nil {
|
|
res := &response{reply: generalFailure}
|
|
buf, _ := res.marshal()
|
|
c.clientConn.Write(buf)
|
|
return err
|
|
}
|
|
if req.command != connect {
|
|
res := &response{reply: commandNotSupported}
|
|
buf, _ := res.marshal()
|
|
c.clientConn.Write(buf)
|
|
return fmt.Errorf("unsupported command %v", req.command)
|
|
}
|
|
c.request = req
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
|
defer cancel()
|
|
srv, err := c.srv.dial(
|
|
ctx,
|
|
"tcp",
|
|
net.JoinHostPort(c.request.destination, strconv.Itoa(int(c.request.port))),
|
|
)
|
|
if err != nil {
|
|
res := &response{reply: generalFailure}
|
|
buf, _ := res.marshal()
|
|
c.clientConn.Write(buf)
|
|
return err
|
|
}
|
|
defer srv.Close()
|
|
serverAddr, serverPortStr, err := net.SplitHostPort(srv.LocalAddr().String())
|
|
if err != nil {
|
|
return err
|
|
}
|
|
serverPort, _ := strconv.Atoi(serverPortStr)
|
|
|
|
var bindAddrType addrType
|
|
if ip := net.ParseIP(serverAddr); ip != nil {
|
|
if ip.To4() != nil {
|
|
bindAddrType = ipv4
|
|
} else {
|
|
bindAddrType = ipv6
|
|
}
|
|
} else {
|
|
bindAddrType = domainName
|
|
}
|
|
res := &response{
|
|
reply: success,
|
|
bindAddrType: bindAddrType,
|
|
bindAddr: serverAddr,
|
|
bindPort: uint16(serverPort),
|
|
}
|
|
buf, err := res.marshal()
|
|
if err != nil {
|
|
res = &response{reply: generalFailure}
|
|
buf, _ = res.marshal()
|
|
}
|
|
c.clientConn.Write(buf)
|
|
|
|
errc := make(chan error, 2)
|
|
go func() {
|
|
_, err := io.Copy(c.clientConn, srv)
|
|
if err != nil {
|
|
err = fmt.Errorf("from backend to client: %w", err)
|
|
}
|
|
errc <- err
|
|
}()
|
|
go func() {
|
|
_, err := io.Copy(srv, c.clientConn)
|
|
if err != nil {
|
|
err = fmt.Errorf("from client to backend: %w", err)
|
|
}
|
|
errc <- err
|
|
}()
|
|
return <-errc
|
|
}
|
|
|
|
// parseClientGreeting parses a request initiation packet
|
|
// and returns a slice that contains the acceptable auth methods
|
|
// for the client.
|
|
func parseClientGreeting(r io.Reader) error {
|
|
var hdr [2]byte
|
|
_, err := io.ReadFull(r, hdr[:])
|
|
if err != nil {
|
|
return fmt.Errorf("could not read packet header")
|
|
}
|
|
if hdr[0] != socks5Version {
|
|
return fmt.Errorf("incompatible SOCKS version")
|
|
}
|
|
count := int(hdr[1])
|
|
methods := make([]byte, count)
|
|
_, err = io.ReadFull(r, methods)
|
|
if err != nil {
|
|
return fmt.Errorf("could not read methods")
|
|
}
|
|
for _, m := range methods {
|
|
if m == noAuthRequired {
|
|
return nil
|
|
}
|
|
}
|
|
return fmt.Errorf("no acceptable auth methods")
|
|
}
|
|
|
|
// request represents data contained within a SOCKS5
|
|
// connection request packet.
|
|
type request struct {
|
|
command commandType
|
|
destination string
|
|
port uint16
|
|
destAddrType addrType
|
|
}
|
|
|
|
// parseClientRequest converts raw packet bytes into a
|
|
// SOCKS5Request struct.
|
|
func parseClientRequest(r io.Reader) (*request, error) {
|
|
var hdr [4]byte
|
|
_, err := io.ReadFull(r, hdr[:])
|
|
if err != nil {
|
|
return nil, fmt.Errorf("could not read packet header")
|
|
}
|
|
cmd := hdr[1]
|
|
destAddrType := addrType(hdr[3])
|
|
|
|
var destination string
|
|
var port uint16
|
|
|
|
if destAddrType == ipv4 {
|
|
var ip [4]byte
|
|
_, err = io.ReadFull(r, ip[:])
|
|
if err != nil {
|
|
return nil, fmt.Errorf("could not read IPv4 address")
|
|
}
|
|
destination = net.IP(ip[:]).String()
|
|
} else if destAddrType == domainName {
|
|
var dstSizeByte [1]byte
|
|
_, err = io.ReadFull(r, dstSizeByte[:])
|
|
if err != nil {
|
|
return nil, fmt.Errorf("could not read domain name size")
|
|
}
|
|
dstSize := int(dstSizeByte[0])
|
|
domainName := make([]byte, dstSize)
|
|
_, err = io.ReadFull(r, domainName)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("could not read domain name")
|
|
}
|
|
destination = string(domainName)
|
|
} else if destAddrType == ipv6 {
|
|
var ip [16]byte
|
|
_, err = io.ReadFull(r, ip[:])
|
|
if err != nil {
|
|
return nil, fmt.Errorf("could not read IPv6 address")
|
|
}
|
|
destination = net.IP(ip[:]).String()
|
|
} else {
|
|
return nil, fmt.Errorf("unsupported address type")
|
|
}
|
|
var portBytes [2]byte
|
|
_, err = io.ReadFull(r, portBytes[:])
|
|
if err != nil {
|
|
return nil, fmt.Errorf("could not read port")
|
|
}
|
|
port = binary.BigEndian.Uint16(portBytes[:])
|
|
|
|
return &request{
|
|
command: commandType(cmd),
|
|
destination: destination,
|
|
port: port,
|
|
destAddrType: destAddrType,
|
|
}, nil
|
|
}
|
|
|
|
// response contains the contents of
|
|
// a response packet sent from the proxy
|
|
// to the client.
|
|
type response struct {
|
|
reply replyCode
|
|
bindAddrType addrType
|
|
bindAddr string
|
|
bindPort uint16
|
|
}
|
|
|
|
// marshal converts a SOCKS5Response struct into
|
|
// a packet. If res.reply == Success, it may throw an error on
|
|
// receiving an invalid bind address. Otherwise, it will not throw.
|
|
func (res *response) marshal() ([]byte, error) {
|
|
pkt := make([]byte, 4)
|
|
pkt[0] = socks5Version
|
|
pkt[1] = byte(res.reply)
|
|
pkt[2] = 0 // null reserved byte
|
|
pkt[3] = byte(res.bindAddrType)
|
|
|
|
if res.reply != success {
|
|
return pkt, nil
|
|
}
|
|
|
|
var addr []byte
|
|
switch res.bindAddrType {
|
|
case ipv4:
|
|
addr = net.ParseIP(res.bindAddr).To4()
|
|
if addr == nil {
|
|
return nil, fmt.Errorf("invalid IPv4 address for binding")
|
|
}
|
|
case domainName:
|
|
if len(res.bindAddr) > 255 {
|
|
return nil, fmt.Errorf("invalid domain name for binding")
|
|
}
|
|
addr = make([]byte, 0, len(res.bindAddr)+1)
|
|
addr = append(addr, byte(len(res.bindAddr)))
|
|
addr = append(addr, []byte(res.bindAddr)...)
|
|
case ipv6:
|
|
addr = net.ParseIP(res.bindAddr).To16()
|
|
if addr == nil {
|
|
return nil, fmt.Errorf("invalid IPv6 address for binding")
|
|
}
|
|
default:
|
|
return nil, fmt.Errorf("unsupported address type")
|
|
}
|
|
|
|
pkt = append(pkt, addr...)
|
|
port := make([]byte, 2)
|
|
binary.BigEndian.PutUint16(port, uint16(res.bindPort))
|
|
pkt = append(pkt, port...)
|
|
|
|
return pkt, nil
|
|
}
|