mirror of
https://github.com/tailscale/tailscale.git
synced 2024-11-29 13:05:46 +00:00
wgengine/netstack: add support for incoming UDP connections
Updates #504 Updates #707 Signed-off-by: Naman Sood <mail@nsood.in>
This commit is contained in:
parent
43b30e463c
commit
7325b5a7ba
@ -18,6 +18,7 @@
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"gvisor.dev/gvisor/pkg/tcpip"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/adapters/gonet"
|
||||
@ -60,6 +61,7 @@ type Impl struct {
|
||||
}
|
||||
|
||||
const nicID = 1
|
||||
const mtu = 1500
|
||||
|
||||
// Create creates and populates a new Impl.
|
||||
func Create(logf logger.Logf, tundev *tstun.TUN, e wgengine.Engine, mc *magicsock.Conn) (*Impl, error) {
|
||||
@ -79,7 +81,6 @@ func Create(logf logger.Logf, tundev *tstun.TUN, e wgengine.Engine, mc *magicsoc
|
||||
NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol},
|
||||
TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol, udp.NewProtocol, icmp.NewProtocol4, icmp.NewProtocol6},
|
||||
})
|
||||
const mtu = 1500
|
||||
linkEP := channel.New(512, mtu, "")
|
||||
if tcpipProblem := ipstack.CreateNIC(nicID, linkEP); tcpipProblem != nil {
|
||||
return nil, fmt.Errorf("could not create netstack NIC: %v", tcpipProblem)
|
||||
@ -390,18 +391,75 @@ func (ns *Impl) acceptUDP(r *udp.ForwarderRequest) {
|
||||
ns.logf("Could not create endpoint, exiting")
|
||||
return
|
||||
}
|
||||
localAddr, err := ep.GetLocalAddress()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
remoteAddr, err := ep.GetRemoteAddress()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
c := gonet.NewUDPConn(ns.ipstack, &wq, ep)
|
||||
go echoUDP(c)
|
||||
go ns.forwardUDP(c, &wq, localAddr, remoteAddr)
|
||||
}
|
||||
|
||||
func echoUDP(c *gonet.UDPConn) {
|
||||
buf := make([]byte, 1500)
|
||||
for {
|
||||
n, err := c.Read(buf)
|
||||
func (ns *Impl) forwardUDP(client *gonet.UDPConn, wq *waiter.Queue, clientLocalAddr, clientRemoteAddr tcpip.FullAddress) {
|
||||
port := clientLocalAddr.Port
|
||||
ns.logf("[v2] netstack: forwarding incoming UDP connection on port %v", port)
|
||||
backendLocalAddr := &net.UDPAddr{Port: int(clientRemoteAddr.Port)}
|
||||
backendRemoteAddr := &net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: int(port)}
|
||||
backendConn, err := net.ListenUDP("udp4", backendLocalAddr)
|
||||
if err != nil {
|
||||
ns.logf("netstack: could not bind local port %v: %v, trying again with random port", clientRemoteAddr.Port, err)
|
||||
backendConn, err = net.ListenUDP("udp4", nil)
|
||||
if err != nil {
|
||||
break
|
||||
ns.logf("netstack: could not connect to local UDP server on port %v: %v", port, err)
|
||||
return
|
||||
}
|
||||
c.Write(buf[:n])
|
||||
}
|
||||
c.Close()
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
timer := time.AfterFunc(2*time.Minute, func() {
|
||||
ns.logf("netstack: forwarder UDP connection on port %v closed", port)
|
||||
cancel()
|
||||
client.Close()
|
||||
backendConn.Close()
|
||||
})
|
||||
extend := func() {
|
||||
timer.Reset(2 * time.Minute)
|
||||
}
|
||||
startPacketCopy(ctx, client, &net.UDPAddr{
|
||||
IP: net.ParseIP(clientRemoteAddr.Addr.String()),
|
||||
Port: int(clientRemoteAddr.Port),
|
||||
}, backendConn, ns.logf, extend)
|
||||
startPacketCopy(ctx, backendConn, backendRemoteAddr, client, ns.logf, extend)
|
||||
|
||||
}
|
||||
|
||||
func startPacketCopy(ctx context.Context, dst net.PacketConn, dstAddr net.Addr, src net.PacketConn, logf logger.Logf, extend func()) {
|
||||
go func() {
|
||||
pkt := make([]byte, mtu)
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
default:
|
||||
n, srcAddr, err := src.ReadFrom(pkt)
|
||||
if err != nil {
|
||||
if ctx.Err() == nil {
|
||||
logf("read packet from %s failed: %v", srcAddr, err)
|
||||
}
|
||||
return
|
||||
}
|
||||
_, err = dst.WriteTo(pkt[:n], dstAddr)
|
||||
if err != nil {
|
||||
if ctx.Err() == nil {
|
||||
logf("write packet to %s failed: %v", dstAddr, err)
|
||||
}
|
||||
return
|
||||
}
|
||||
logf("[v2] wrote UDP packet %s -> %s", srcAddr, dstAddr)
|
||||
extend()
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user