diff --git a/wgengine/netstack/netstack.go b/wgengine/netstack/netstack.go index df03a6bae..01fe8b094 100644 --- a/wgengine/netstack/netstack.go +++ b/wgengine/netstack/netstack.go @@ -18,6 +18,7 @@ "strconv" "strings" "sync" + "time" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/adapters/gonet" @@ -60,6 +61,7 @@ type Impl struct { } const nicID = 1 +const mtu = 1500 // Create creates and populates a new Impl. func Create(logf logger.Logf, tundev *tstun.TUN, e wgengine.Engine, mc *magicsock.Conn) (*Impl, error) { @@ -79,7 +81,6 @@ func Create(logf logger.Logf, tundev *tstun.TUN, e wgengine.Engine, mc *magicsoc NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol}, TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol, udp.NewProtocol, icmp.NewProtocol4, icmp.NewProtocol6}, }) - const mtu = 1500 linkEP := channel.New(512, mtu, "") if tcpipProblem := ipstack.CreateNIC(nicID, linkEP); tcpipProblem != nil { return nil, fmt.Errorf("could not create netstack NIC: %v", tcpipProblem) @@ -390,18 +391,75 @@ func (ns *Impl) acceptUDP(r *udp.ForwarderRequest) { ns.logf("Could not create endpoint, exiting") return } + localAddr, err := ep.GetLocalAddress() + if err != nil { + return + } + remoteAddr, err := ep.GetRemoteAddress() + if err != nil { + return + } c := gonet.NewUDPConn(ns.ipstack, &wq, ep) - go echoUDP(c) + go ns.forwardUDP(c, &wq, localAddr, remoteAddr) } -func echoUDP(c *gonet.UDPConn) { - buf := make([]byte, 1500) - for { - n, err := c.Read(buf) +func (ns *Impl) forwardUDP(client *gonet.UDPConn, wq *waiter.Queue, clientLocalAddr, clientRemoteAddr tcpip.FullAddress) { + port := clientLocalAddr.Port + ns.logf("[v2] netstack: forwarding incoming UDP connection on port %v", port) + backendLocalAddr := &net.UDPAddr{Port: int(clientRemoteAddr.Port)} + backendRemoteAddr := &net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: int(port)} + backendConn, err := net.ListenUDP("udp4", backendLocalAddr) + if err != nil { + ns.logf("netstack: could not bind local port %v: %v, trying again with random port", clientRemoteAddr.Port, err) + backendConn, err = net.ListenUDP("udp4", nil) if err != nil { - break + ns.logf("netstack: could not connect to local UDP server on port %v: %v", port, err) + return } - c.Write(buf[:n]) } - c.Close() + ctx, cancel := context.WithCancel(context.Background()) + timer := time.AfterFunc(2*time.Minute, func() { + ns.logf("netstack: forwarder UDP connection on port %v closed", port) + cancel() + client.Close() + backendConn.Close() + }) + extend := func() { + timer.Reset(2 * time.Minute) + } + startPacketCopy(ctx, client, &net.UDPAddr{ + IP: net.ParseIP(clientRemoteAddr.Addr.String()), + Port: int(clientRemoteAddr.Port), + }, backendConn, ns.logf, extend) + startPacketCopy(ctx, backendConn, backendRemoteAddr, client, ns.logf, extend) + +} + +func startPacketCopy(ctx context.Context, dst net.PacketConn, dstAddr net.Addr, src net.PacketConn, logf logger.Logf, extend func()) { + go func() { + pkt := make([]byte, mtu) + for { + select { + case <-ctx.Done(): + return + default: + n, srcAddr, err := src.ReadFrom(pkt) + if err != nil { + if ctx.Err() == nil { + logf("read packet from %s failed: %v", srcAddr, err) + } + return + } + _, err = dst.WriteTo(pkt[:n], dstAddr) + if err != nil { + if ctx.Err() == nil { + logf("write packet to %s failed: %v", dstAddr, err) + } + return + } + logf("[v2] wrote UDP packet %s -> %s", srcAddr, dstAddr) + extend() + } + } + }() }